mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a14791da3 | ||
|
|
711a579945 | ||
|
|
1ac5a24a8a | ||
|
|
282df322d5 | ||
|
|
8523ea88f2 | ||
|
|
cad97d3da3 | ||
|
|
efc5a762fc | ||
|
|
9131c45645 | ||
|
|
75ca44d5f9 | ||
|
|
8b08af3949 | ||
|
|
df9ea8dcc1 | ||
|
|
25a57326b3 | ||
|
|
7f3e8087ba | ||
|
|
dfc7835359 |
8
.github/workflows/build-container.yml
vendored
8
.github/workflows/build-container.yml
vendored
@@ -45,9 +45,6 @@ jobs:
|
||||
steps:
|
||||
- name: Free up more disk space on the runner
|
||||
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
|
||||
# the /mnt dir has 70GBs of free space
|
||||
# /dev/sda1 74G 28K 70G 1% /mnt
|
||||
# According to some online posts the /mnt is not always there, so checking before setting docker to use it
|
||||
run: |
|
||||
echo "----- Free space before cleanup"
|
||||
df -h
|
||||
@@ -55,11 +52,6 @@ jobs:
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
if [ -d /mnt ]; then
|
||||
sudo chmod -R 777 /mnt
|
||||
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
|
||||
sudo systemctl restart docker
|
||||
fi
|
||||
echo "----- Free space after cleanup"
|
||||
df -h
|
||||
|
||||
|
||||
30
.github/workflows/lfs-checks.yml
vendored
30
.github/workflows/lfs-checks.yml
vendored
@@ -1,30 +0,0 @@
|
||||
# Checks that large files and LFS-tracked files are properly checked in with pointer format.
|
||||
# Uses https://github.com/ppremk/lfs-warning to detect LFS issues.
|
||||
|
||||
name: 'lfs checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
lfs-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
# Required to label and comment on the PRs
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check lfs files
|
||||
uses: ppremk/lfs-warning@v3.3
|
||||
12
.github/workflows/typegen-checks.yml
vendored
12
.github/workflows/typegen-checks.yml
vendored
@@ -39,18 +39,6 @@ jobs:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up more disk space on the runner
|
||||
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
|
||||
run: |
|
||||
echo "----- Free space before cleanup"
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
echo "----- Free space after cleanup"
|
||||
df -h
|
||||
|
||||
- name: check for changed files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
|
||||
@@ -22,10 +22,6 @@
|
||||
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
|
||||
# GPU_DRIVER=cuda #| rocm
|
||||
|
||||
## If you are using ROCM, you will need to ensure that the render group within the container and the host system use the same group ID.
|
||||
## To obtain the group ID of the render group on the host system, run `getent group render` and grab the number.
|
||||
# RENDER_GROUP_ID=
|
||||
|
||||
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
|
||||
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
|
||||
# CONTAINER_UID=1000
|
||||
|
||||
@@ -43,6 +43,7 @@ ENV \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
UV_INDEX="https://download.pytorch.org/whl/cu124" \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
@@ -73,18 +74,20 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
ulimit -n 30000 && \
|
||||
uv sync --extra $GPU_DRIVER --frozen
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids" && groupadd render
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
|
||||
fi && \
|
||||
uv sync --frozen
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
@@ -102,6 +105,8 @@ COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
ulimit -n 30000 && \
|
||||
uv pip install -e .[$GPU_DRIVER]
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
|
||||
fi && \
|
||||
uv pip install -e .
|
||||
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
#### Web UI ------------------------------------
|
||||
|
||||
FROM docker.io/node:22-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
COPY invokeai/frontend/web/ ./
|
||||
RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN npx vite build
|
||||
|
||||
## Backend ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:24.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
apt update && apt install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
git \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev \
|
||||
wget
|
||||
|
||||
ENV \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
INVOKEAI_SRC=/opt/invokeai \
|
||||
PYTHON_VERSION=3.12 \
|
||||
UV_PYTHON=3.12 \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
CONTAINER_UID=${CONTAINER_UID:-1000} \
|
||||
CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
|
||||
|
||||
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install ${PYTHON_VERSION} && \
|
||||
# chmod --recursive a+rX /root/.local/share/uv/python
|
||||
chmod 711 /root
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
ulimit -n 30000 && \
|
||||
uv sync --extra $GPU_DRIVER --frozen
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
if [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
wget -O /tmp/amdgpu-install.deb \
|
||||
https://repo.radeon.com/amdgpu-install/6.3.4/ubuntu/noble/amdgpu-install_6.3.60304-1_all.deb && \
|
||||
apt install -y /tmp/amdgpu-install.deb && \
|
||||
apt update && \
|
||||
amdgpu-install --usecase=rocm -y && \
|
||||
apt-get autoclean && \
|
||||
apt clean && \
|
||||
rm -rf /tmp/* /var/tmp/* && \
|
||||
usermod -a -G render ubuntu && \
|
||||
usermod -a -G video ubuntu && \
|
||||
echo "\\n/opt/rocm/lib\\n/opt/rocm/lib64" >> /etc/ld.so.conf.d/rocm.conf && \
|
||||
ldconfig && \
|
||||
update-alternatives --auto rocm; \
|
||||
fi
|
||||
|
||||
## Heathen711: Leaving this for review input, will remove before merge
|
||||
# RUN --mount=type=cache,target=/var/cache/apt \
|
||||
# --mount=type=cache,target=/var/lib/apt \
|
||||
# if [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
# groupadd render && \
|
||||
# usermod -a -G render ubuntu && \
|
||||
# usermod -a -G video ubuntu; \
|
||||
# fi
|
||||
|
||||
## Link amdgpu.ids for ROCm builds
|
||||
## contributed by https://github.com/Rubonnek
|
||||
# RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
# ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web"]
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# add sources last to minimize image changes on code changes
|
||||
COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
|
||||
# this should not increase image size because we've already installed dependencies
|
||||
# in a previous layer
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
ulimit -n 30000 && \
|
||||
uv pip install -e .[$GPU_DRIVER]
|
||||
|
||||
@@ -47,9 +47,8 @@ services:
|
||||
|
||||
invokeai-rocm:
|
||||
<<: *invokeai
|
||||
environment:
|
||||
- AMD_VISIBLE_DEVICES=all
|
||||
- RENDER_GROUP_ID=${RENDER_GROUP_ID}
|
||||
runtime: amd
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
profiles:
|
||||
- rocm
|
||||
|
||||
@@ -21,17 +21,6 @@ _=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
|
||||
# ensure the UID is correct
|
||||
usermod -u ${USER_ID} ${USER} 1>/dev/null
|
||||
|
||||
## ROCM specific configuration
|
||||
# render group within the container must match the host render group
|
||||
# otherwise the container will not be able to access the host GPU.
|
||||
if [[ -v "RENDER_GROUP_ID" ]] && [[ ! -z "${RENDER_GROUP_ID}" ]]; then
|
||||
# ensure the render group exists
|
||||
groupmod -g ${RENDER_GROUP_ID} render
|
||||
usermod -a -G render ${USER}
|
||||
usermod -a -G video ${USER}
|
||||
fi
|
||||
|
||||
|
||||
### Set the $PUBLIC_KEY env var to enable SSH access.
|
||||
# We do not install openssh-server in the image by default to avoid bloat.
|
||||
# but it is useful to have the full SSH server e.g. on Runpod.
|
||||
|
||||
@@ -13,7 +13,7 @@ run() {
|
||||
|
||||
# parse .env file for build args
|
||||
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
|
||||
profile="$(awk -F '=' '/GPU_DRIVER=/ {print $2}' .env)"
|
||||
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
|
||||
|
||||
# default to 'cuda' profile
|
||||
[[ -z "$profile" ]] && profile="cuda"
|
||||
@@ -30,7 +30,7 @@ run() {
|
||||
|
||||
printf "%s\n" "starting service $service_name"
|
||||
docker compose --profile "$profile" up -d "$service_name"
|
||||
docker compose --profile "$profile" logs -f
|
||||
docker compose logs -f
|
||||
}
|
||||
|
||||
run
|
||||
|
||||
@@ -265,7 +265,7 @@ If the key is unrecognized, this call raises an
|
||||
|
||||
#### exists(key) -> AnyModelConfig
|
||||
|
||||
Returns True if a model with the given key exists in the database.
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
|
||||
@@ -718,7 +718,7 @@ When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
retrieved from the HuggingFace and Civitai model repositories.
|
||||
|
||||
The probed values can be overridden by providing a dictionary in the
|
||||
The probed values can be overriden by providing a dictionary in the
|
||||
optional `config` argument passed to `import_model()`. You may provide
|
||||
overriding values for any of the model's configuration
|
||||
attributes. Here is an example of setting the
|
||||
@@ -841,7 +841,7 @@ variable.
|
||||
|
||||
#### installer.start(invoker)
|
||||
|
||||
The `start` method is called by the API initialization routines when
|
||||
The `start` method is called by the API intialization routines when
|
||||
the API starts up. Its effect is to call `sync_to_config()` to
|
||||
synchronize the model record store database with what's currently on
|
||||
disk.
|
||||
|
||||
@@ -16,7 +16,7 @@ We thank [all contributors](https://github.com/invoke-ai/InvokeAI/graphs/contrib
|
||||
- @psychedelicious (Spencer Mabrito) - Web Team Leader
|
||||
- @joshistoast (Josh Corbett) - Web Development
|
||||
- @cheerio (Mary Rogers) - Lead Engineer & Web App Development
|
||||
- @ebr (Eugene Brodsky) - Cloud/DevOps/Software engineer; your friendly neighbourhood cluster-autoscaler
|
||||
- @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
|
||||
- @sunija - Standalone version
|
||||
- @brandon (Brandon Rising) - Platform, Infrastructure, Backend Systems
|
||||
- @ryanjdick (Ryan Dick) - Machine Learning & Training
|
||||
|
||||
@@ -69,34 +69,34 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
|
||||
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
|
||||
|
||||
7. Determine the torch backend to use for installation, if any. This is necessary to get the right version of torch installed. This is acheived by using [UV's built in torch support.](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection)
|
||||
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
|
||||
|
||||
=== "Invoke v5.12 and later"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu128`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.3`.
|
||||
- **In all other cases, do not use a torch backend.**
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v5.10.0 to v5.11.0"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu126`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.2.4`.
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v5.0.0 to v5.9.1"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.1`.
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v4"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
|
||||
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm5.2`.
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
8. Install the `invokeai` package. Substitute the package specifier and version.
|
||||
@@ -105,10 +105,10 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
|
||||
```
|
||||
|
||||
If you determined you needed to use a torch backend in the previous step, you'll need to set the backend like this:
|
||||
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --torch-backend=<VERSION> --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
```
|
||||
|
||||
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
|
||||
|
||||
@@ -33,45 +33,30 @@ Hardware requirements vary significantly depending on model and image output siz
|
||||
|
||||
More detail on system requirements can be found [here](./requirements.md).
|
||||
|
||||
## Step 2: Download and Set Up the Launcher
|
||||
## Step 2: Download
|
||||
|
||||
The Launcher manages your Invoke install. Follow these instructions to download and set up the Launcher.
|
||||
Download the most recent launcher for your operating system:
|
||||
|
||||
!!! info "Instructions for each OS"
|
||||
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
|
||||
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
|
||||
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)
|
||||
|
||||
=== "Windows"
|
||||
## Step 3: Install or Update
|
||||
|
||||
- [Download for Windows](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition.Setup.latest.exe)
|
||||
- Run the `EXE` to install the Launcher and start it.
|
||||
- A desktop shortcut will be created; use this to run the Launcher in the future.
|
||||
- You can delete the `EXE` file you downloaded.
|
||||
|
||||
=== "macOS"
|
||||
|
||||
- [Download for macOS](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest-arm64.dmg)
|
||||
- Open the `DMG` and drag the app into `Applications`.
|
||||
- Run the Launcher using its entry in `Applications`.
|
||||
- You can delete the `DMG` file you downloaded.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
- [Download for Linux](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest.AppImage)
|
||||
- You may need to edit the `AppImage` file properties and make it executable.
|
||||
- Optionally move the file to a location that does not require admin privileges and add a desktop shortcut for it.
|
||||
- Run the Launcher by double-clicking the `AppImage` or the shortcut you made.
|
||||
|
||||
## Step 3: Install Invoke
|
||||
|
||||
Run the Launcher you just set up if you haven't already. Click **Install** and follow the instructions to install (or update) Invoke.
|
||||
Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.
|
||||
|
||||
If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.
|
||||
|
||||
!!! tip "Updating"
|
||||
!!! warning "Problem running the launcher on macOS"
|
||||
|
||||
The Launcher will check for updates for itself _and_ Invoke.
|
||||
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can manually flag the launcher as safe:
|
||||
|
||||
- When the Launcher detects an update is available for itself, you'll get a small popup window. Click through this and the Launcher will update itself.
|
||||
- When the Launcher detects an update for Invoke, you'll see a small green alert in the Launcher. Click that and follow the instructions to update Invoke.
|
||||
- Open the **Invoke Community Edition.dmg** file.
|
||||
- Drag the launcher to **Applications**.
|
||||
- Open a terminal.
|
||||
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
|
||||
|
||||
You should now be able to run the launcher.
|
||||
|
||||
## Step 4: Launch
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ Nodes have a "Use Cache" option in their footer. This allows for performance imp
|
||||
|
||||
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
|
||||
|
||||
### Create Latent Noise
|
||||
### Noise
|
||||
|
||||
An initial noise tensor is necessary for the latent diffusion process. As a result, the Denoising node requires a noise node input.
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards.boards_default import BoardService
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
@@ -152,7 +151,6 @@ class ApiDependencies:
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -183,7 +181,6 @@ class ApiDependencies:
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
client_state_persistence=client_state_persistence,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"])
|
||||
|
||||
|
||||
@client_state_router.get(
|
||||
"/{queue_id}/get_by_key",
|
||||
operation_id="get_client_state_by_key",
|
||||
response_model=str | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> str | None:
|
||||
"""Gets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
"/{queue_id}/set_by_key",
|
||||
operation_id="set_client_state",
|
||||
response_model=str,
|
||||
)
|
||||
async def set_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: str = Body(..., description="Stringified value to set"),
|
||||
) -> str:
|
||||
"""Sets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
"/{queue_id}/delete",
|
||||
operation_id="delete_client_state",
|
||||
responses={204: {"description": "Client state deleted"}},
|
||||
)
|
||||
async def delete_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> None:
|
||||
"""Deletes the client state"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting client state")
|
||||
@@ -19,7 +19,6 @@ from invokeai.app.api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
client_state,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
@@ -132,7 +131,6 @@ app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
app.include_router(client_state.client_state_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -157,12 +155,6 @@ def overridden_redoc() -> HTMLResponse:
|
||||
|
||||
web_root_path = Path(list(web_dir.__path__)[0])
|
||||
|
||||
if app_config.unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
"The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and"
|
||||
"loading models, which may allow malicious code to be executed. Use at your own risk."
|
||||
)
|
||||
|
||||
try:
|
||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||
except RuntimeError:
|
||||
|
||||
158
invokeai/app/invocations/bria_controlnet.py
Normal file
158
invokeai/app/invocations/bria_controlnet.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
|
||||
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.invocation_api import Classification, ImageOutput
|
||||
|
||||
DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
|
||||
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"
|
||||
|
||||
|
||||
class BriaControlNetField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
|
||||
conditioning_scale: float = Field(description="The weight given to the ControlNet")
|
||||
|
||||
|
||||
@invocation_output("bria_controlnet_output")
|
||||
class BriaControlNetOutput(BaseInvocationOutput):
|
||||
"""Bria ControlNet info"""
|
||||
|
||||
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
|
||||
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_controlnet",
|
||||
title="ControlNet - Bria",
|
||||
tags=["controlnet", "bria"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Collect Bria ControlNet info to pass to denoiser node."""
|
||||
|
||||
control_image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
|
||||
)
|
||||
control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
|
||||
control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
|
||||
image_in = resize_img(context.images.get_pil(self.control_image.image_name))
|
||||
if self.control_mode == "canny":
|
||||
control_image = extract_canny(image_in)
|
||||
elif self.control_mode == "depth":
|
||||
control_image = extract_depth(image_in, context)
|
||||
elif self.control_mode == "pose":
|
||||
control_image = extract_openpose(image_in, context)
|
||||
elif self.control_mode == "colorgrid":
|
||||
control_image = tile(64, image_in)
|
||||
elif self.control_mode == "recolor":
|
||||
control_image = convert_to_grayscale(image_in)
|
||||
elif self.control_mode == "tile":
|
||||
control_image = tile(16, image_in)
|
||||
|
||||
control_image = resize_img(control_image)
|
||||
image_dto = context.images.save(image=control_image)
|
||||
image_output = ImageOutput.build(image_dto)
|
||||
return BriaControlNetOutput(
|
||||
preprocessed_images=image_output.image,
|
||||
control=BriaControlNetField(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
model=self.control_model,
|
||||
mode=self.control_mode,
|
||||
conditioning_scale=self.control_weight,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
RATIO_CONFIGS_1024 = {
|
||||
0.6666666666666666: {"width": 832, "height": 1248},
|
||||
0.7432432432432432: {"width": 880, "height": 1184},
|
||||
0.8028169014084507: {"width": 912, "height": 1136},
|
||||
1.0: {"width": 1024, "height": 1024},
|
||||
1.2456140350877194: {"width": 1136, "height": 912},
|
||||
1.3454545454545455: {"width": 1184, "height": 880},
|
||||
1.4339622641509433: {"width": 1216, "height": 848},
|
||||
1.5: {"width": 1248, "height": 832},
|
||||
1.5490196078431373: {"width": 1264, "height": 816},
|
||||
1.62: {"width": 1296, "height": 800},
|
||||
1.7708333333333333: {"width": 1360, "height": 768},
|
||||
}
|
||||
|
||||
|
||||
def extract_depth(image: Image.Image, context: InvocationContext):
|
||||
loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
|
||||
|
||||
with loaded_model as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
return depth_map
|
||||
|
||||
|
||||
def extract_openpose(image: Image.Image, context: InvocationContext):
|
||||
body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
|
||||
hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
|
||||
face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
|
||||
|
||||
with body_model as body_model, hand_model as hand_model, face_model as face_model:
|
||||
open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
|
||||
processed_image_open_pose = open_pose_model(image, hand_and_face=True)
|
||||
|
||||
processed_image_open_pose = processed_image_open_pose.resize(image.size)
|
||||
return processed_image_open_pose
|
||||
|
||||
|
||||
def extract_canny(input_image):
|
||||
image = np.array(input_image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
return canny_image
|
||||
|
||||
|
||||
def convert_to_grayscale(image):
|
||||
gray_image = image.convert("L").convert("RGB")
|
||||
return gray_image
|
||||
|
||||
|
||||
def tile(downscale_factor, input_image):
|
||||
control_image = input_image.resize(
|
||||
(input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
|
||||
).resize(input_image.size, Image.Resampling.NEAREST)
|
||||
return control_image
|
||||
|
||||
|
||||
def resize_img(control_image):
|
||||
image_ratio = control_image.width / control_image.height
|
||||
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
|
||||
to_height = RATIO_CONFIGS_1024[ratio]["height"]
|
||||
to_width = RATIO_CONFIGS_1024[ratio]["width"]
|
||||
resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
|
||||
return resized_image
|
||||
46
invokeai/app/invocations/bria_decoder.py
Normal file
46
invokeai/app/invocations/bria_decoder.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_decoder",
|
||||
title="Decoder - Bria",
|
||||
tags=["image", "bria"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaDecoderInvocation(BaseInvocation):
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
|
||||
|
||||
with context.models.load(self.vae.vae) as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
latents = latents / vae.config.scaling_factor
|
||||
latents = latents.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
decoded_output = vae.decode(latents)
|
||||
image = decoded_output.sample
|
||||
|
||||
# Convert to numpy with proper gradient handling
|
||||
image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
|
||||
img = Image.fromarray(image)
|
||||
image_dto = context.images.save(image=img)
|
||||
return ImageOutput.build(image_dto)
|
||||
180
invokeai/app/invocations/bria_denoiser.py
Normal file
180
invokeai/app/invocations/bria_denoiser.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
|
||||
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
|
||||
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
|
||||
from invokeai.backend.bria.controlnet_utils import prepare_control_images
|
||||
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
|
||||
|
||||
|
||||
@invocation_output("bria_denoise_output")
|
||||
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
|
||||
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_denoise",
|
||||
title="Denoise - Bria",
|
||||
tags=["image", "bria"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaDenoiseInvocation(BaseInvocation):
|
||||
num_steps: int = InputField(
|
||||
default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
|
||||
)
|
||||
guidance_scale: float = InputField(
|
||||
default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
|
||||
)
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description="Bria model (Transformer) to load",
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
title="VAE",
|
||||
)
|
||||
latents: LatentsField = InputField(
|
||||
description="Latents to denoise",
|
||||
input=Input.Connection,
|
||||
title="Latents",
|
||||
)
|
||||
latent_image_ids: LatentsField = InputField(
|
||||
description="Latent Image IDs to denoise",
|
||||
input=Input.Connection,
|
||||
title="Latent Image IDs",
|
||||
)
|
||||
pos_embeds: LatentsField = InputField(
|
||||
description="Positive Prompt Embeds",
|
||||
input=Input.Connection,
|
||||
title="Positive Prompt Embeds",
|
||||
)
|
||||
neg_embeds: LatentsField = InputField(
|
||||
description="Negative Prompt Embeds",
|
||||
input=Input.Connection,
|
||||
title="Negative Prompt Embeds",
|
||||
)
|
||||
text_ids: LatentsField = InputField(
|
||||
description="Text IDs",
|
||||
input=Input.Connection,
|
||||
title="Text IDs",
|
||||
)
|
||||
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
|
||||
description="ControlNet",
|
||||
input=Input.Connection,
|
||||
title="ControlNet",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
|
||||
neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
|
||||
text_ids = context.tensors.load(self.text_ids.latents_name)
|
||||
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
|
||||
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
|
||||
device = None
|
||||
dtype = None
|
||||
with (
|
||||
context.models.load(self.transformer.transformer) as transformer,
|
||||
context.models.load(scheduler_identifier) as scheduler,
|
||||
context.models.load(self.vae.vae) as vae,
|
||||
context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(transformer, BriaTransformer2DModel)
|
||||
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
dtype = transformer.dtype
|
||||
device = transformer.device
|
||||
latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))
|
||||
|
||||
control_model, control_images, control_modes, control_scales = None, None, None, None
|
||||
if self.control is not None:
|
||||
control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
|
||||
context=context,
|
||||
vae=vae,
|
||||
width=1024,
|
||||
height=1024,
|
||||
device=vae.device,
|
||||
)
|
||||
|
||||
pipeline = BriaControlNetPipeline(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=t5_encoder,
|
||||
tokenizer=t5_tokenizer,
|
||||
controlnet=control_model,
|
||||
)
|
||||
pipeline.to(device=transformer.device, dtype=transformer.dtype)
|
||||
|
||||
latents = pipeline(
|
||||
control_image=control_images,
|
||||
control_mode=control_modes,
|
||||
width=1024,
|
||||
height=1024,
|
||||
controlnet_conditioning_scale=control_scales,
|
||||
num_inference_steps=self.num_steps,
|
||||
max_sequence_length=128,
|
||||
guidance_scale=self.guidance_scale,
|
||||
latents=latents,
|
||||
latent_image_ids=latent_image_ids,
|
||||
text_ids=text_ids,
|
||||
prompt_embeds=pos_embeds,
|
||||
negative_prompt_embeds=neg_embeds,
|
||||
output_type="latent",
|
||||
)[0]
|
||||
|
||||
assert isinstance(latents, torch.Tensor)
|
||||
saved_input_latents_tensor = context.tensors.save(latents)
|
||||
latents_output = LatentsField(latents_name=saved_input_latents_tensor)
|
||||
return BriaDenoiseInvocationOutput(latents=latents_output)
|
||||
|
||||
def _prepare_multi_control(
|
||||
self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
|
||||
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
|
||||
control = self.control if isinstance(self.control, list) else [self.control]
|
||||
control_images, control_models, control_modes, control_scales = [], [], [], []
|
||||
for controlnet in control:
|
||||
if controlnet is not None:
|
||||
control_models.append(context.models.load(controlnet.model).model)
|
||||
control_modes.append(BriaControlModes[controlnet.mode].value)
|
||||
control_scales.append(controlnet.conditioning_scale)
|
||||
try:
|
||||
control_images.append(context.images.get_pil(controlnet.image.image_name))
|
||||
except Exception:
|
||||
raise FileNotFoundError(
|
||||
f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
|
||||
)
|
||||
|
||||
control_model = BriaMultiControlNetModel(control_models).to(device)
|
||||
tensored_control_images, tensored_control_modes = prepare_control_images(
|
||||
vae=vae,
|
||||
control_images=control_images,
|
||||
control_modes=control_modes,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
)
|
||||
return control_model, tensored_control_images, tensored_control_modes, control_scales
|
||||
76
invokeai/app/invocations/bria_latent_sampler.py
Normal file
76
invokeai/app/invocations/bria_latent_sampler.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.fields import Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import (
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("bria_latent_sampler_output")
|
||||
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a CogView text conditioning tensor."""
|
||||
|
||||
latents: LatentsField = OutputField(description=FieldDescriptions.cond)
|
||||
latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_latent_sampler",
|
||||
title="Latent Sampler - Bria",
|
||||
tags=["image", "bria"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaLatentSamplerInvocation(BaseInvocation):
|
||||
seed: int = InputField(
|
||||
default=42,
|
||||
title="Seed",
|
||||
description="The seed to use for the latent sampler",
|
||||
)
|
||||
transformer: TransformerField = InputField(
|
||||
description="Bria model (Transformer) to load",
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
|
||||
with context.models.load(self.transformer.transformer) as transformer:
|
||||
device = transformer.device
|
||||
dtype = transformer.dtype
|
||||
|
||||
height, width = 1024, 1024
|
||||
generator = torch.Generator(device=device).manual_seed(self.seed)
|
||||
|
||||
num_channels_latents = 4
|
||||
latents, latent_image_ids = prepare_latents(
|
||||
batch_size=1,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
saved_latents_tensor = context.tensors.save(latents)
|
||||
saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
|
||||
latents_output = LatentsField(latents_name=saved_latents_tensor)
|
||||
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
|
||||
|
||||
return BriaLatentSamplerInvocationOutput(
|
||||
latents=latents_output,
|
||||
latent_image_ids=latent_image_ids_output,
|
||||
)
|
||||
58
invokeai/app/invocations/bria_model_loader.py
Normal file
58
invokeai/app/invocations/bria_model_loader.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
SubModelType,
|
||||
T5EncoderField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("bria_model_loader_output")
|
||||
class BriaModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Bria base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_model_loader",
|
||||
title="Main Model - Bria",
|
||||
tags=["model", "bria"],
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a bria base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description="Bria model (Transformer) to load",
|
||||
ui_type=UIType.BriaMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
|
||||
for key in [self.model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return BriaModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
93
invokeai/app/invocations/bria_text_encoder.py
Normal file
93
invokeai/app/invocations/bria_text_encoder.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.model import T5EncoderField
|
||||
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
InputField,
|
||||
LatentsField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("bria_text_encoder_output")
|
||||
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a CogView text conditioning tensor."""
|
||||
|
||||
pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
|
||||
neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
|
||||
text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
|
||||
@invocation(
|
||||
"bria_text_encoder",
|
||||
title="Prompt - Bria",
|
||||
tags=["prompt", "conditioning", "bria"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class BriaTextEncoderInvocation(BaseInvocation):
|
||||
prompt: str = InputField(
|
||||
title="Prompt",
|
||||
description="The prompt to encode",
|
||||
)
|
||||
negative_prompt: Optional[str] = InputField(
|
||||
title="Negative Prompt",
|
||||
description="The negative prompt to encode",
|
||||
default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
|
||||
)
|
||||
max_length: int = InputField(
|
||||
default=128,
|
||||
title="Max Length",
|
||||
description="The maximum length of the prompt",
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
|
||||
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
with (
|
||||
t5_encoder_info as text_encoder,
|
||||
t5_tokenizer_info as tokenizer,
|
||||
):
|
||||
assert isinstance(tokenizer, T5TokenizerFast)
|
||||
assert isinstance(text_encoder, T5EncoderModel)
|
||||
|
||||
(prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
|
||||
prompt=self.prompt,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
negative_prompt=self.negative_prompt,
|
||||
device=text_encoder.device,
|
||||
num_images_per_prompt=1,
|
||||
max_sequence_length=self.max_length,
|
||||
lora_scale=1.0,
|
||||
)
|
||||
|
||||
saved_pos_tensor = context.tensors.save(prompt_embeds)
|
||||
saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
|
||||
saved_text_ids_tensor = context.tensors.save(text_ids)
|
||||
pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
|
||||
neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
|
||||
text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
|
||||
return BriaTextEncoderInvocationOutput(
|
||||
pos_embeds=pos_embeds_output,
|
||||
neg_embeds=neg_embeds_output,
|
||||
text_ids=text_ids_output,
|
||||
)
|
||||
@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
@@ -39,11 +38,7 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
@@ -67,8 +62,6 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -6,6 +6,7 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
@@ -39,15 +39,22 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
@@ -42,6 +42,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
MainModel = "MainModelField"
|
||||
CogView4MainModel = "CogView4MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
BriaMainModel = "BriaMainModelField"
|
||||
BriaControlNetModel = "BriaControlNetModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
@@ -64,7 +66,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
Imagen3Model = "Imagen3ModelField"
|
||||
Imagen4Model = "Imagen4ModelField"
|
||||
ChatGPT4oModel = "ChatGPT4oModelField"
|
||||
Gemini2_5Model = "Gemini2_5ModelField"
|
||||
FluxKontextModel = "FluxKontextModelField"
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="4.1.0",
|
||||
version="4.0.0",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
@@ -153,7 +153,7 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
|
||||
)
|
||||
|
||||
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
|
||||
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference image).",
|
||||
input=Input.Connection,
|
||||
@@ -328,21 +328,6 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
cfg_scale_end_step=self.cfg_scale_end_step,
|
||||
)
|
||||
|
||||
kontext_extension = None
|
||||
if self.kontext_conditioning:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
|
||||
|
||||
kontext_extension = KontextExtension(
|
||||
context=context,
|
||||
kontext_conditioning=self.kontext_conditioning
|
||||
if isinstance(self.kontext_conditioning, list)
|
||||
else [self.kontext_conditioning],
|
||||
vae_field=self.controlnet_vae,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
# Prepare ControlNet extensions.
|
||||
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
|
||||
@@ -400,6 +385,19 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
kontext_extension = None
|
||||
if self.kontext_conditioning is not None:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
|
||||
|
||||
kontext_extension = KontextExtension(
|
||||
context=context,
|
||||
kontext_conditioning=self.kontext_conditioning,
|
||||
vae_field=self.controlnet_vae,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
# Prepare Kontext conditioning if provided
|
||||
img_cond_seq = None
|
||||
img_cond_seq_ids = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -17,7 +18,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -39,11 +39,17 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
|
||||
@@ -15,7 +15,6 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -42,12 +41,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
@@ -1347,96 +1347,3 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
|
||||
|
||||
image_dto = context.images.save(image=target_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_kontext_image_prep",
|
||||
title="FLUX Kontext Image Prep",
|
||||
tags=["image", "concatenate", "flux", "kontext"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
|
||||
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
description="The images to concatenate",
|
||||
min_length=1,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
use_preferred_resolution: bool = InputField(
|
||||
default=True, description="Use FLUX preferred resolutions for the first image"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
|
||||
# Step 1: Load all images
|
||||
pil_images = []
|
||||
for image_field in self.images:
|
||||
image = context.images.get_pil(image_field.image_name, mode="RGBA")
|
||||
pil_images.append(image)
|
||||
|
||||
# Step 2: Determine target resolution for the first image
|
||||
first_image = pil_images[0]
|
||||
width, height = first_image.size
|
||||
|
||||
if self.use_preferred_resolution:
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Find the closest preferred resolution for the first image
|
||||
_, target_width, target_height = min(
|
||||
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
# Apply BFL's scaling formula
|
||||
scaled_height = 2 * int(target_height / 16)
|
||||
final_height = 8 * scaled_height # This will be consistent for all images
|
||||
scaled_width = 2 * int(target_width / 16)
|
||||
first_width = 8 * scaled_width
|
||||
else:
|
||||
# Use original dimensions of first image, ensuring divisibility by 16
|
||||
final_height = 16 * (height // 16)
|
||||
first_width = 16 * (width // 16)
|
||||
# Ensure minimum dimensions
|
||||
if final_height < 16:
|
||||
final_height = 16
|
||||
if first_width < 16:
|
||||
first_width = 16
|
||||
|
||||
# Step 3: Process and resize all images with consistent height
|
||||
processed_images = []
|
||||
total_width = 0
|
||||
|
||||
for i, image in enumerate(pil_images):
|
||||
if i == 0:
|
||||
# First image uses the calculated dimensions
|
||||
final_width = first_width
|
||||
else:
|
||||
# Subsequent images maintain aspect ratio with the same height
|
||||
img_aspect_ratio = image.width / image.height
|
||||
# Calculate width that maintains aspect ratio at the target height
|
||||
calculated_width = int(final_height * img_aspect_ratio)
|
||||
# Ensure width is divisible by 16 for proper VAE encoding
|
||||
final_width = 16 * (calculated_width // 16)
|
||||
# Ensure minimum width
|
||||
if final_width < 16:
|
||||
final_width = 16
|
||||
|
||||
# Resize image to calculated dimensions
|
||||
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
|
||||
processed_images.append(resized_image)
|
||||
total_width += final_width
|
||||
|
||||
# Step 4: Concatenate images horizontally
|
||||
concatenated_image = Image.new("RGB", (total_width, final_height))
|
||||
x_offset = 0
|
||||
for img in processed_images:
|
||||
concatenated_image.paste(img, (x_offset, 0))
|
||||
x_offset += img.width
|
||||
|
||||
# Save the concatenated image
|
||||
image_dto = context.images.save(image=concatenated_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -27,7 +27,6 @@ from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -53,24 +52,11 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
cls,
|
||||
vae_info: LoadedModel,
|
||||
upcast: bool,
|
||||
tiled: bool,
|
||||
image_tensor: torch.Tensor,
|
||||
tile_size: int = 0,
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="encode",
|
||||
image_tensor=image_tensor,
|
||||
vae=vae_info.model,
|
||||
tile_size=tile_size if tiled else None,
|
||||
fp32=upcast,
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
@@ -127,7 +113,6 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@@ -135,11 +120,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info,
|
||||
upcast=self.fp32,
|
||||
tiled=self.tiled or context.config.get().force_tiled_decode,
|
||||
image_tensor=image_tensor,
|
||||
tile_size=self.tile_size,
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -27,7 +27,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -54,6 +53,39 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(
|
||||
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
out_h = tile_size
|
||||
out_w = tile_size
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
@@ -62,13 +94,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="decode",
|
||||
image_tensor=latents,
|
||||
vae=vae_info.model,
|
||||
tile_size=self.tile_size if use_tiling else None,
|
||||
fp32=self.fp32,
|
||||
)
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -35,11 +34,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
estimated_working_memory = estimate_vae_working_memory_sd3(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
@@ -63,8 +58,6 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -6,6 +6,7 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -41,15 +41,22 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = estimate_vae_working_memory_sd3(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ClientStatePersistenceABC(ABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
|
||||
"""
|
||||
Set a key-value pair for the client.
|
||||
|
||||
Args:
|
||||
key (str): The key to set.
|
||||
value (str): The value to set for the key.
|
||||
|
||||
Returns:
|
||||
str: The value that was set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_key(self, queue_id: str, key: str) -> str | None:
|
||||
"""
|
||||
Get the value for a specific key of the client.
|
||||
|
||||
Args:
|
||||
key (str): The key to retrieve the value for.
|
||||
|
||||
Returns:
|
||||
str | None: The value associated with the key, or None if the key does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, queue_id: str) -> None:
|
||||
"""
|
||||
Delete all client state.
|
||||
"""
|
||||
pass
|
||||
@@ -1,65 +0,0 @@
|
||||
import json
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
|
||||
"""
|
||||
Base class for client persistence implementations.
|
||||
This class defines the interface for persisting client data.
|
||||
"""
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._default_row_id = 1
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def _get(self) -> dict[str, str] | None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT data FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return json.loads(row[0])
|
||||
|
||||
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
|
||||
state = self._get() or {}
|
||||
state.update({key: value})
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO client_state (id, data)
|
||||
VALUES ({self._default_row_id}, ?)
|
||||
ON CONFLICT(id) DO UPDATE
|
||||
SET data = excluded.data;
|
||||
""",
|
||||
(json.dumps(state),),
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def get_by_key(self, queue_id: str, key: str) -> str | None:
|
||||
state = self._get()
|
||||
if state is None:
|
||||
return None
|
||||
return state.get(key, None)
|
||||
|
||||
def delete(self, queue_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
DELETE FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
@@ -107,7 +107,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
hashing_algorithm: Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `blake3_multi`, `blake3_single`, `random`, `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`
|
||||
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
|
||||
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
|
||||
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
|
||||
"""
|
||||
|
||||
_root: Optional[Path] = PrivateAttr(default=None)
|
||||
@@ -197,7 +196,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3_single", description="Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
|
||||
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
|
||||
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
|
||||
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
from invokeai.app.services.boards.boards_base import BoardServiceABC
|
||||
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@@ -74,7 +73,6 @@ class InvocationServices:
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
workflow_thumbnails: "WorkflowThumbnailServiceBase",
|
||||
client_state_persistence: "ClientStatePersistenceABC",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -104,4 +102,3 @@ class InvocationServices:
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
self.workflow_thumbnails = workflow_thumbnails
|
||||
self.client_state_persistence = client_state_persistence
|
||||
|
||||
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import move, rmtree
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
@@ -186,15 +186,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
|
||||
|
||||
if preferred_name := config.name:
|
||||
if Path(model_path).is_file():
|
||||
# Careful! Don't use pathlib.Path(...).with_suffix - it can will strip everything after the first dot.
|
||||
preferred_name = f"{preferred_name}{model_path.suffix}"
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
|
||||
)
|
||||
try:
|
||||
new_path = self._move_model(model_path, dest_path)
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
except FileExistsError as excp:
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
@@ -619,17 +617,30 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if old_path.is_dir():
|
||||
copytree(old_path, new_path)
|
||||
else:
|
||||
copyfile(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
if new_path.exists():
|
||||
raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists")
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
move(old_path, new_path)
|
||||
|
||||
return new_path
|
||||
|
||||
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
|
||||
|
||||
@@ -87,21 +87,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if self._app_config.unsafe_disable_picklescan:
|
||||
self._logger.warning(
|
||||
f"Model at {checkpoint} is potentially infected by malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
if scan_result.scan_err:
|
||||
if self._app_config.unsafe_disable_picklescan:
|
||||
self._logger.warning(
|
||||
f"Error scanning model at {checkpoint} for malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
|
||||
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
|
||||
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
@@ -23,7 +23,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -64,7 +63,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_18())
|
||||
migrator.register_migration(build_migration_19(app_config=config))
|
||||
migrator.register_migration(build_migration_20())
|
||||
migrator.register_migration(build_migration_21())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration21Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE client_state (
|
||||
id INTEGER PRIMARY KEY CHECK(id = 1),
|
||||
data TEXT NOT NULL, -- Frontend will handle the shape of this data
|
||||
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TRIGGER tg_client_state_updated_at
|
||||
AFTER UPDATE ON client_state
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE client_state
|
||||
SET updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = OLD.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_21() -> Migration:
|
||||
"""Builds the migration object for migrating from version 20 to version 21. This includes:
|
||||
- Creating the `client_state` table.
|
||||
- Adding a trigger to update the `updated_at` field on updates.
|
||||
"""
|
||||
return Migration(
|
||||
from_version=20,
|
||||
to_version=21,
|
||||
callback=Migration21Callback(),
|
||||
)
|
||||
0
invokeai/backend/bria/__init__.py
Normal file
0
invokeai/backend/bria/__init__.py
Normal file
314
invokeai/backend/bria/bria_utils.py
Normal file
314
invokeai/backend/bria/bria_utils.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from diffusers.utils import logging
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_t5_prompt_embeds(
|
||||
tokenizer: T5TokenizerFast,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str], None] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or text_encoder.device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
# padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
# Concat zeros to max_sequence
|
||||
b, seq_len, dim = prompt_embeds.shape
|
||||
if seq_len < max_sequence_length:
|
||||
padding = torch.zeros(
|
||||
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
# in order the get the same sigmas as in training and sample from them
|
||||
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
|
||||
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
|
||||
new_sigmas = sigmas[inds]
|
||||
return new_sigmas
|
||||
|
||||
|
||||
def is_ng_none(negative_prompt):
|
||||
return (
|
||||
negative_prompt is None
|
||||
or negative_prompt == ""
|
||||
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
|
||||
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
|
||||
)
|
||||
|
||||
|
||||
class CudaTimerContext:
|
||||
def __init__(self, times_arr):
|
||||
self.times_arr = times_arr
|
||||
|
||||
def __enter__(self):
|
||||
self.before_event = torch.cuda.Event(enable_timing=True)
|
||||
self.after_event = torch.cuda.Event(enable_timing=True)
|
||||
self.before_event.record()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.after_event.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
|
||||
self.times_arr.append(elapsed_time)
|
||||
|
||||
|
||||
def get_env_prefix():
|
||||
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
|
||||
if env == "AWS":
|
||||
return "SM_CHANNEL"
|
||||
elif env == "AZURE":
|
||||
return "AZUREML_DATAREFERENCE"
|
||||
|
||||
raise Exception(f"Env {env} not supported")
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
# Initialize the process group for distributed training
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
# Get the current process's rank (ID) and the total number of processes (world size)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
print(f"Initialized distributed training: Rank {rank}/{world_size}")
|
||||
|
||||
|
||||
def get_clip_prompt_embeds(
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 77,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or text_encoder.device
|
||||
assert max_sequence_length == tokenizer.model_max_length
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [tokenizer, tokenizer_2]
|
||||
text_encoders = [text_encoder, text_encoder_2]
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt]
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
||||
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
||||
data type.
|
||||
|
||||
Args:
|
||||
dim (`int`): Dimension of the frequency tensor.
|
||||
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (`float`, *optional*, defaults to 10000.0):
|
||||
Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (`bool`, *optional*):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the context extrapolation. Defaults to 1.0.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
||||
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
||||
Otherwise, they are concateanted with themselves.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
the dtype of the frequency tensor.
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
# lumina
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class FluxPosEmbed(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
6
invokeai/backend/bria/controlnet_aux/__init__.py
Normal file
6
invokeai/backend/bria/controlnet_aux/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
__version__ = "0.0.9"
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector
|
||||
|
||||
__all__ = ["CannyDetector", "OpenposeDetector"]
|
||||
48
invokeai/backend/bria/controlnet_aux/canny/__init__.py
Normal file
48
invokeai/backend/bria/controlnet_aux/canny/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import warnings
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
|
||||
|
||||
|
||||
class CannyDetector:
|
||||
def __call__(
|
||||
self,
|
||||
input_image=None,
|
||||
low_threshold=100,
|
||||
high_threshold=200,
|
||||
detect_resolution=512,
|
||||
image_resolution=512,
|
||||
output_type=None,
|
||||
**kwargs,
|
||||
):
|
||||
if "img" in kwargs:
|
||||
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
|
||||
input_image = kwargs.pop("img")
|
||||
|
||||
if input_image is None:
|
||||
raise ValueError("input_image must be defined.")
|
||||
|
||||
if not isinstance(input_image, np.ndarray):
|
||||
input_image = np.array(input_image, dtype=np.uint8)
|
||||
output_type = output_type or "pil"
|
||||
else:
|
||||
output_type = output_type or "np"
|
||||
|
||||
input_image = HWC3(input_image)
|
||||
input_image = resize_image(input_image, detect_resolution)
|
||||
|
||||
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
|
||||
detected_map = HWC3(detected_map)
|
||||
|
||||
img = resize_image(input_image, image_resolution)
|
||||
H, W, C = img.shape
|
||||
|
||||
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if output_type == "pil":
|
||||
detected_map = Image.fromarray(detected_map)
|
||||
|
||||
return detected_map
|
||||
108
invokeai/backend/bria/controlnet_aux/open_pose/LICENSE
Normal file
108
invokeai/backend/bria/controlnet_aux/open_pose/LICENSE
Normal file
@@ -0,0 +1,108 @@
|
||||
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
|
||||
SOFTWARE LICENSE AGREEMENT
|
||||
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
|
||||
|
||||
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
|
||||
|
||||
This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
|
||||
|
||||
RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
|
||||
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
|
||||
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
|
||||
|
||||
CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
|
||||
|
||||
COPYRIGHT: The Software is owned by Licensor and is protected by United
|
||||
States copyright laws and applicable international treaties and/or conventions.
|
||||
|
||||
PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
|
||||
|
||||
DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
|
||||
|
||||
BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
|
||||
|
||||
USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
|
||||
|
||||
You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
|
||||
|
||||
ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
|
||||
|
||||
TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
|
||||
|
||||
The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
|
||||
|
||||
FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
|
||||
|
||||
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
|
||||
|
||||
SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
|
||||
|
||||
EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
|
||||
|
||||
EXPORT REGULATION: Licensee agrees to comply with any and all applicable
|
||||
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
|
||||
|
||||
SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
|
||||
|
||||
NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
|
||||
|
||||
GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
|
||||
|
||||
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
|
||||
|
||||
|
||||
|
||||
************************************************************************
|
||||
|
||||
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
|
||||
|
||||
This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
|
||||
|
||||
1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
|
||||
|
||||
COPYRIGHT
|
||||
|
||||
All contributions by the University of California:
|
||||
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright (c) 2014-2017, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe uses a shared copyright model: each contributor holds copyright over
|
||||
their contributions to Caffe. The project versioning records all such
|
||||
contribution and copyright details. If a contributor wants to further mark
|
||||
their specific copyright on a particular contribution, they should indicate
|
||||
their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
LICENSE
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
CONTRIBUTION AGREEMENT
|
||||
|
||||
By contributing to the BVLC/caffe repository through pull-request, comment,
|
||||
or otherwise, the contributor releases their content to the
|
||||
license and copyright terms herein.
|
||||
|
||||
************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
|
||||
267
invokeai/backend/bria/controlnet_aux/open_pose/__init__.py
Normal file
267
invokeai/backend/bria/controlnet_aux/open_pose/__init__.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Openpose
|
||||
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
||||
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
||||
# 3rd Edited by ControlNet
|
||||
# 4th Edited by ControlNet (added face and correct hands)
|
||||
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
|
||||
# This preprocessor is licensed by CMU for non-commercial use only.
|
||||
|
||||
|
||||
import os
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
import warnings
|
||||
from typing import List, NamedTuple, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import util
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
|
||||
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
|
||||
|
||||
HandResult = List[Keypoint]
|
||||
FaceResult = List[Keypoint]
|
||||
|
||||
|
||||
class PoseResult(NamedTuple):
|
||||
body: BodyResult
|
||||
left_hand: Union[HandResult, None]
|
||||
right_hand: Union[HandResult, None]
|
||||
face: Union[FaceResult, None]
|
||||
|
||||
|
||||
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
|
||||
"""
|
||||
Draw the detected poses on an empty canvas.
|
||||
|
||||
Args:
|
||||
poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
|
||||
H (int): The height of the canvas.
|
||||
W (int): The width of the canvas.
|
||||
draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
|
||||
draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
|
||||
draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
|
||||
"""
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
for pose in poses:
|
||||
if draw_body:
|
||||
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
|
||||
|
||||
if draw_hand:
|
||||
canvas = util.draw_handpose(canvas, pose.left_hand)
|
||||
canvas = util.draw_handpose(canvas, pose.right_hand)
|
||||
|
||||
if draw_face:
|
||||
canvas = util.draw_facepose(canvas, pose.face)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
class OpenposeDetector:
|
||||
"""
|
||||
A class for detecting human poses in images using the Openpose model.
|
||||
|
||||
Attributes:
|
||||
model_dir (str): Path to the directory where the pose models are stored.
|
||||
"""
|
||||
|
||||
def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
|
||||
self.body_estimation = body_estimation
|
||||
self.hand_estimation = hand_estimation
|
||||
self.face_estimation = face_estimation
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_or_path,
|
||||
filename=None,
|
||||
hand_filename=None,
|
||||
face_filename=None,
|
||||
cache_dir=None,
|
||||
local_files_only=False,
|
||||
):
|
||||
if pretrained_model_or_path == "lllyasviel/ControlNet":
|
||||
filename = filename or "annotator/ckpts/body_pose_model.pth"
|
||||
hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
|
||||
face_filename = face_filename or "facenet.pth"
|
||||
|
||||
face_pretrained_model_or_path = "lllyasviel/Annotators"
|
||||
else:
|
||||
filename = filename or "body_pose_model.pth"
|
||||
hand_filename = hand_filename or "hand_pose_model.pth"
|
||||
face_filename = face_filename or "facenet.pth"
|
||||
|
||||
face_pretrained_model_or_path = pretrained_model_or_path
|
||||
|
||||
if os.path.isdir(pretrained_model_or_path):
|
||||
body_model_path = os.path.join(pretrained_model_or_path, filename)
|
||||
hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
|
||||
face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
|
||||
else:
|
||||
body_model_path = hf_hub_download(
|
||||
pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
hand_model_path = hf_hub_download(
|
||||
pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
face_model_path = hf_hub_download(
|
||||
face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
body_estimation = Body(body_model_path)
|
||||
hand_estimation = Hand(hand_model_path)
|
||||
face_estimation = Face(face_model_path)
|
||||
|
||||
return cls(body_estimation, hand_estimation, face_estimation)
|
||||
|
||||
def to(self, device):
|
||||
self.body_estimation.to(device)
|
||||
self.hand_estimation.to(device)
|
||||
self.face_estimation.to(device)
|
||||
return self
|
||||
|
||||
def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
|
||||
left_hand = None
|
||||
right_hand = None
|
||||
H, W, _ = oriImg.shape
|
||||
for x, y, w, is_left in util.handDetect(body, oriImg):
|
||||
peaks = self.hand_estimation(oriImg[y : y + w, x : x + w, :]).astype(np.float32)
|
||||
if peaks.ndim == 2 and peaks.shape[1] == 2:
|
||||
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
|
||||
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
|
||||
|
||||
hand_result = [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
|
||||
|
||||
if is_left:
|
||||
left_hand = hand_result
|
||||
else:
|
||||
right_hand = hand_result
|
||||
|
||||
return left_hand, right_hand
|
||||
|
||||
def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
|
||||
face = util.faceDetect(body, oriImg)
|
||||
if face is None:
|
||||
return None
|
||||
|
||||
x, y, w = face
|
||||
H, W, _ = oriImg.shape
|
||||
heatmaps = self.face_estimation(oriImg[y : y + w, x : x + w, :])
|
||||
peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
|
||||
if peaks.ndim == 2 and peaks.shape[1] == 2:
|
||||
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
|
||||
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
|
||||
return [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
|
||||
|
||||
return None
|
||||
|
||||
def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
|
||||
"""
|
||||
Detect poses in the given image.
|
||||
Args:
|
||||
oriImg (numpy.ndarray): The input image for pose detection.
|
||||
include_hand (bool, optional): Whether to include hand detection. Defaults to False.
|
||||
include_face (bool, optional): Whether to include face detection. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[PoseResult]: A list of PoseResult objects containing the detected poses.
|
||||
"""
|
||||
oriImg = oriImg[:, :, ::-1].copy()
|
||||
H, W, C = oriImg.shape
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.body_estimation(oriImg)
|
||||
bodies = self.body_estimation.format_body_result(candidate, subset)
|
||||
|
||||
results = []
|
||||
for body in bodies:
|
||||
left_hand, right_hand, face = (None,) * 3
|
||||
if include_hand:
|
||||
left_hand, right_hand = self.detect_hands(body, oriImg)
|
||||
if include_face:
|
||||
face = self.detect_face(body, oriImg)
|
||||
|
||||
results.append(
|
||||
PoseResult(
|
||||
BodyResult(
|
||||
keypoints=[
|
||||
Keypoint(x=keypoint.x / float(W), y=keypoint.y / float(H))
|
||||
if keypoint is not None
|
||||
else None
|
||||
for keypoint in body.keypoints
|
||||
],
|
||||
total_score=body.total_score,
|
||||
total_parts=body.total_parts,
|
||||
),
|
||||
left_hand,
|
||||
right_hand,
|
||||
face,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_image,
|
||||
detect_resolution=512,
|
||||
image_resolution=512,
|
||||
include_body=True,
|
||||
include_hand=False,
|
||||
include_face=False,
|
||||
hand_and_face=None,
|
||||
output_type="pil",
|
||||
**kwargs,
|
||||
):
|
||||
if hand_and_face is not None:
|
||||
warnings.warn(
|
||||
"hand_and_face is deprecated. Use include_hand and include_face instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
include_hand = hand_and_face
|
||||
include_face = hand_and_face
|
||||
|
||||
if "return_pil" in kwargs:
|
||||
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning, stacklevel=2)
|
||||
output_type = "pil" if kwargs["return_pil"] else "np"
|
||||
if type(output_type) is bool:
|
||||
warnings.warn(
|
||||
"Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions",
|
||||
stacklevel=2,
|
||||
)
|
||||
if output_type:
|
||||
output_type = "pil"
|
||||
|
||||
if not isinstance(input_image, np.ndarray):
|
||||
input_image = np.array(input_image, dtype=np.uint8)
|
||||
|
||||
input_image = HWC3(input_image)
|
||||
input_image = resize_image(input_image, detect_resolution)
|
||||
H, W, C = input_image.shape
|
||||
|
||||
poses = self.detect_poses(input_image, include_hand, include_face)
|
||||
canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
|
||||
|
||||
detected_map = canvas
|
||||
detected_map = HWC3(detected_map)
|
||||
|
||||
img = resize_image(input_image, image_resolution)
|
||||
H, W, C = img.shape
|
||||
|
||||
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if output_type == "pil":
|
||||
detected_map = Image.fromarray(detected_map)
|
||||
|
||||
return detected_map
|
||||
319
invokeai/backend/bria/controlnet_aux/open_pose/body.py
Normal file
319
invokeai/backend/bria/controlnet_aux/open_pose/body.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import math
|
||||
from typing import List, NamedTuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import util
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model
|
||||
|
||||
|
||||
class Keypoint(NamedTuple):
|
||||
x: float
|
||||
y: float
|
||||
score: float = 1.0
|
||||
id: int = -1
|
||||
|
||||
|
||||
class BodyResult(NamedTuple):
|
||||
# Note: Using `Union` instead of `|` operator as the ladder is a Python
|
||||
# 3.10 feature.
|
||||
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
|
||||
# Python 3.8 environment.
|
||||
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
|
||||
keypoints: List[Union[Keypoint, None]]
|
||||
total_score: float
|
||||
total_parts: int
|
||||
|
||||
|
||||
class Body(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = bodypose_model()
|
||||
model_dict = util.transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, oriImg):
|
||||
device = next(iter(self.model.parameters())).device
|
||||
# scale_search = [0.5, 1.0, 1.5, 2.0]
|
||||
scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre1 = 0.1
|
||||
thre2 = 0.05
|
||||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
|
||||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
|
||||
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
|
||||
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
|
||||
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
data = data.to(device)
|
||||
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
|
||||
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
|
||||
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
|
||||
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
|
||||
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
|
||||
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
|
||||
heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
|
||||
|
||||
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
|
||||
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
|
||||
paf = util.smart_resize_k(paf, fx=stride, fy=stride)
|
||||
paf = paf[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
|
||||
paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
|
||||
|
||||
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
|
||||
paf_avg += +paf / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
peak_counter = 0
|
||||
|
||||
for part in range(18):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
|
||||
map_left = np.zeros(one_heatmap.shape)
|
||||
map_left[1:, :] = one_heatmap[:-1, :]
|
||||
map_right = np.zeros(one_heatmap.shape)
|
||||
map_right[:-1, :] = one_heatmap[1:, :]
|
||||
map_up = np.zeros(one_heatmap.shape)
|
||||
map_up[:, 1:] = one_heatmap[:, :-1]
|
||||
map_down = np.zeros(one_heatmap.shape)
|
||||
map_down[:, :-1] = one_heatmap[:, 1:]
|
||||
|
||||
peaks_binary = np.logical_and.reduce(
|
||||
(
|
||||
one_heatmap >= map_left,
|
||||
one_heatmap >= map_right,
|
||||
one_heatmap >= map_up,
|
||||
one_heatmap >= map_down,
|
||||
one_heatmap > thre1,
|
||||
)
|
||||
)
|
||||
peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0], strict=False)) # note reverse
|
||||
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
|
||||
peak_id = range(peak_counter, peak_counter + len(peaks))
|
||||
peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
|
||||
|
||||
all_peaks.append(peaks_with_score_and_id)
|
||||
peak_counter += len(peaks)
|
||||
|
||||
# find connection in the specified sequence, center 29 is in the position 15
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
# the middle joints heatmap correpondence
|
||||
mapIdx = [
|
||||
[31, 32],
|
||||
[39, 40],
|
||||
[33, 34],
|
||||
[35, 36],
|
||||
[41, 42],
|
||||
[43, 44],
|
||||
[19, 20],
|
||||
[21, 22],
|
||||
[23, 24],
|
||||
[25, 26],
|
||||
[27, 28],
|
||||
[29, 30],
|
||||
[47, 48],
|
||||
[49, 50],
|
||||
[53, 54],
|
||||
[51, 52],
|
||||
[55, 56],
|
||||
[37, 38],
|
||||
[45, 46],
|
||||
]
|
||||
|
||||
connection_all = []
|
||||
special_k = []
|
||||
mid_num = 10
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
|
||||
candA = all_peaks[limbSeq[k][0] - 1]
|
||||
candB = all_peaks[limbSeq[k][1] - 1]
|
||||
nA = len(candA)
|
||||
nB = len(candB)
|
||||
indexA, indexB = limbSeq[k]
|
||||
if nA != 0 and nB != 0:
|
||||
connection_candidate = []
|
||||
for i in range(nA):
|
||||
for j in range(nB):
|
||||
vec = np.subtract(candB[j][:2], candA[i][:2])
|
||||
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
|
||||
norm = max(0.001, norm)
|
||||
vec = np.divide(vec, norm)
|
||||
|
||||
startend = list(
|
||||
zip(
|
||||
np.linspace(candA[i][0], candB[j][0], num=mid_num),
|
||||
np.linspace(candA[i][1], candB[j][1], num=mid_num),
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
|
||||
vec_x = np.array(
|
||||
[
|
||||
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0]
|
||||
for i in range(len(startend))
|
||||
]
|
||||
)
|
||||
vec_y = np.array(
|
||||
[
|
||||
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1]
|
||||
for i in range(len(startend))
|
||||
]
|
||||
)
|
||||
|
||||
score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
|
||||
score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
|
||||
0.5 * oriImg.shape[0] / norm - 1, 0
|
||||
)
|
||||
criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
|
||||
criterion2 = score_with_dist_prior > 0
|
||||
if criterion1 and criterion2:
|
||||
connection_candidate.append(
|
||||
[i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]
|
||||
)
|
||||
|
||||
connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
|
||||
connection = np.zeros((0, 5))
|
||||
for c in range(len(connection_candidate)):
|
||||
i, j, s = connection_candidate[c][0:3]
|
||||
if i not in connection[:, 3] and j not in connection[:, 4]:
|
||||
connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
|
||||
if len(connection) >= min(nA, nB):
|
||||
break
|
||||
|
||||
connection_all.append(connection)
|
||||
else:
|
||||
special_k.append(k)
|
||||
connection_all.append([])
|
||||
|
||||
# last number in each row is the total parts number of that person
|
||||
# the second last number in each row is the score of the overall configuration
|
||||
subset = -1 * np.ones((0, 20))
|
||||
candidate = np.array([item for sublist in all_peaks for item in sublist])
|
||||
|
||||
for k in range(len(mapIdx)):
|
||||
if k not in special_k:
|
||||
partAs = connection_all[k][:, 0]
|
||||
partBs = connection_all[k][:, 1]
|
||||
indexA, indexB = np.array(limbSeq[k]) - 1
|
||||
|
||||
for i in range(len(connection_all[k])): # = 1:size(temp,1)
|
||||
found = 0
|
||||
subset_idx = [-1, -1]
|
||||
for j in range(len(subset)): # 1:size(subset,1):
|
||||
if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
|
||||
subset_idx[found] = j
|
||||
found += 1
|
||||
|
||||
if found == 1:
|
||||
j = subset_idx[0]
|
||||
if subset[j][indexB] != partBs[i]:
|
||||
subset[j][indexB] = partBs[i]
|
||||
subset[j][-1] += 1
|
||||
subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
|
||||
elif found == 2: # if found 2 and disjoint, merge them
|
||||
j1, j2 = subset_idx
|
||||
membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
|
||||
if len(np.nonzero(membership == 2)[0]) == 0: # merge
|
||||
subset[j1][:-2] += subset[j2][:-2] + 1
|
||||
subset[j1][-2:] += subset[j2][-2:]
|
||||
subset[j1][-2] += connection_all[k][i][2]
|
||||
subset = np.delete(subset, j2, 0)
|
||||
else: # as like found == 1
|
||||
subset[j1][indexB] = partBs[i]
|
||||
subset[j1][-1] += 1
|
||||
subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
|
||||
|
||||
# if find no partA in the subset, create a new subset
|
||||
elif not found and k < 17:
|
||||
row = -1 * np.ones(20)
|
||||
row[indexA] = partAs[i]
|
||||
row[indexB] = partBs[i]
|
||||
row[-1] = 2
|
||||
row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
|
||||
subset = np.vstack([subset, row])
|
||||
# delete some rows of subset which has few parts occur
|
||||
deleteIdx = []
|
||||
for i in range(len(subset)):
|
||||
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
|
||||
deleteIdx.append(i)
|
||||
subset = np.delete(subset, deleteIdx, axis=0)
|
||||
|
||||
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
|
||||
# candidate: x, y, score, id
|
||||
return candidate, subset
|
||||
|
||||
@staticmethod
|
||||
def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
|
||||
"""
|
||||
Format the body results from the candidate and subset arrays into a list of BodyResult objects.
|
||||
|
||||
Args:
|
||||
candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
|
||||
for each body part.
|
||||
subset (np.ndarray): An array of subsets containing indices to the candidate array for each
|
||||
person detected. The last two columns of each row hold the total score and total parts
|
||||
of the person.
|
||||
|
||||
Returns:
|
||||
List[BodyResult]: A list of BodyResult objects, where each object represents a person with
|
||||
detected keypoints, total score, and total parts.
|
||||
"""
|
||||
return [
|
||||
BodyResult(
|
||||
keypoints=[
|
||||
Keypoint(
|
||||
x=candidate[candidate_index][0],
|
||||
y=candidate[candidate_index][1],
|
||||
score=candidate[candidate_index][2],
|
||||
id=candidate[candidate_index][3],
|
||||
)
|
||||
if candidate_index != -1
|
||||
else None
|
||||
for candidate_index in person[:18].astype(int)
|
||||
],
|
||||
total_score=person[18],
|
||||
total_parts=person[19],
|
||||
)
|
||||
for person in subset
|
||||
]
|
||||
307
invokeai/backend/bria/controlnet_aux/open_pose/face.py
Normal file
307
invokeai/backend/bria/controlnet_aux/open_pose/face.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import util
|
||||
|
||||
|
||||
class FaceNet(Module):
|
||||
"""Model the cascading heatmaps."""
|
||||
|
||||
def __init__(self):
|
||||
super(FaceNet, self).__init__()
|
||||
# cnn to make feature map
|
||||
self.relu = ReLU()
|
||||
self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
|
||||
self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_4 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_1 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3_CPM = Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# stage1
|
||||
self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0)
|
||||
self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# stage2
|
||||
self.Mconv1_stage2 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv2_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv3_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv4_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv5_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv6_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
|
||||
self.Mconv7_stage2 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# stage3
|
||||
self.Mconv1_stage3 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv2_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv3_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv4_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv5_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv6_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
|
||||
self.Mconv7_stage3 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# stage4
|
||||
self.Mconv1_stage4 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv2_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv3_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv4_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv5_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv6_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
|
||||
self.Mconv7_stage4 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# stage5
|
||||
self.Mconv1_stage5 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv2_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv3_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv4_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv5_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv6_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
|
||||
self.Mconv7_stage5 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# stage6
|
||||
self.Mconv1_stage6 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv2_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv3_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv4_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv5_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
|
||||
self.Mconv6_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
|
||||
self.Mconv7_stage6 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, Conv2d):
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Return a list of heatmaps."""
|
||||
heatmaps = []
|
||||
|
||||
h = self.relu(self.conv1_1(x))
|
||||
h = self.relu(self.conv1_2(h))
|
||||
h = self.max_pooling_2d(h)
|
||||
h = self.relu(self.conv2_1(h))
|
||||
h = self.relu(self.conv2_2(h))
|
||||
h = self.max_pooling_2d(h)
|
||||
h = self.relu(self.conv3_1(h))
|
||||
h = self.relu(self.conv3_2(h))
|
||||
h = self.relu(self.conv3_3(h))
|
||||
h = self.relu(self.conv3_4(h))
|
||||
h = self.max_pooling_2d(h)
|
||||
h = self.relu(self.conv4_1(h))
|
||||
h = self.relu(self.conv4_2(h))
|
||||
h = self.relu(self.conv4_3(h))
|
||||
h = self.relu(self.conv4_4(h))
|
||||
h = self.relu(self.conv5_1(h))
|
||||
h = self.relu(self.conv5_2(h))
|
||||
h = self.relu(self.conv5_3_CPM(h))
|
||||
feature_map = h
|
||||
|
||||
# stage1
|
||||
h = self.relu(self.conv6_1_CPM(h))
|
||||
h = self.conv6_2_CPM(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
# stage2
|
||||
h = torch.cat([h, feature_map], dim=1) # channel concat
|
||||
h = self.relu(self.Mconv1_stage2(h))
|
||||
h = self.relu(self.Mconv2_stage2(h))
|
||||
h = self.relu(self.Mconv3_stage2(h))
|
||||
h = self.relu(self.Mconv4_stage2(h))
|
||||
h = self.relu(self.Mconv5_stage2(h))
|
||||
h = self.relu(self.Mconv6_stage2(h))
|
||||
h = self.Mconv7_stage2(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
# stage3
|
||||
h = torch.cat([h, feature_map], dim=1) # channel concat
|
||||
h = self.relu(self.Mconv1_stage3(h))
|
||||
h = self.relu(self.Mconv2_stage3(h))
|
||||
h = self.relu(self.Mconv3_stage3(h))
|
||||
h = self.relu(self.Mconv4_stage3(h))
|
||||
h = self.relu(self.Mconv5_stage3(h))
|
||||
h = self.relu(self.Mconv6_stage3(h))
|
||||
h = self.Mconv7_stage3(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
# stage4
|
||||
h = torch.cat([h, feature_map], dim=1) # channel concat
|
||||
h = self.relu(self.Mconv1_stage4(h))
|
||||
h = self.relu(self.Mconv2_stage4(h))
|
||||
h = self.relu(self.Mconv3_stage4(h))
|
||||
h = self.relu(self.Mconv4_stage4(h))
|
||||
h = self.relu(self.Mconv5_stage4(h))
|
||||
h = self.relu(self.Mconv6_stage4(h))
|
||||
h = self.Mconv7_stage4(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
# stage5
|
||||
h = torch.cat([h, feature_map], dim=1) # channel concat
|
||||
h = self.relu(self.Mconv1_stage5(h))
|
||||
h = self.relu(self.Mconv2_stage5(h))
|
||||
h = self.relu(self.Mconv3_stage5(h))
|
||||
h = self.relu(self.Mconv4_stage5(h))
|
||||
h = self.relu(self.Mconv5_stage5(h))
|
||||
h = self.relu(self.Mconv6_stage5(h))
|
||||
h = self.Mconv7_stage5(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
# stage6
|
||||
h = torch.cat([h, feature_map], dim=1) # channel concat
|
||||
h = self.relu(self.Mconv1_stage6(h))
|
||||
h = self.relu(self.Mconv2_stage6(h))
|
||||
h = self.relu(self.Mconv3_stage6(h))
|
||||
h = self.relu(self.Mconv4_stage6(h))
|
||||
h = self.relu(self.Mconv5_stage6(h))
|
||||
h = self.relu(self.Mconv6_stage6(h))
|
||||
h = self.Mconv7_stage6(h)
|
||||
heatmaps.append(h)
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
TOTEN = ToTensor()
|
||||
TOPIL = ToPILImage()
|
||||
|
||||
|
||||
params = {
|
||||
"gaussian_sigma": 2.5,
|
||||
"inference_img_size": 736, # 368, 736, 1312
|
||||
"heatmap_peak_thresh": 0.1,
|
||||
"crop_scale": 1.5,
|
||||
"line_indices": [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[8, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[11, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[14, 15],
|
||||
[15, 16],
|
||||
[17, 18],
|
||||
[18, 19],
|
||||
[19, 20],
|
||||
[20, 21],
|
||||
[22, 23],
|
||||
[23, 24],
|
||||
[24, 25],
|
||||
[25, 26],
|
||||
[27, 28],
|
||||
[28, 29],
|
||||
[29, 30],
|
||||
[31, 32],
|
||||
[32, 33],
|
||||
[33, 34],
|
||||
[34, 35],
|
||||
[36, 37],
|
||||
[37, 38],
|
||||
[38, 39],
|
||||
[39, 40],
|
||||
[40, 41],
|
||||
[41, 36],
|
||||
[42, 43],
|
||||
[43, 44],
|
||||
[44, 45],
|
||||
[45, 46],
|
||||
[46, 47],
|
||||
[47, 42],
|
||||
[48, 49],
|
||||
[49, 50],
|
||||
[50, 51],
|
||||
[51, 52],
|
||||
[52, 53],
|
||||
[53, 54],
|
||||
[54, 55],
|
||||
[55, 56],
|
||||
[56, 57],
|
||||
[57, 58],
|
||||
[58, 59],
|
||||
[59, 48],
|
||||
[60, 61],
|
||||
[61, 62],
|
||||
[62, 63],
|
||||
[63, 64],
|
||||
[64, 65],
|
||||
[65, 66],
|
||||
[66, 67],
|
||||
[67, 60],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Face(object):
|
||||
"""
|
||||
The OpenPose face landmark detector model.
|
||||
|
||||
Args:
|
||||
inference_size: set the size of the inference image size, suggested:
|
||||
368, 736, 1312, default 736
|
||||
gaussian_sigma: blur the heatmaps, default 2.5
|
||||
heatmap_peak_thresh: return landmark if over threshold, default 0.1
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, face_model_path, inference_size=None, gaussian_sigma=None, heatmap_peak_thresh=None):
|
||||
self.inference_size = inference_size or params["inference_img_size"]
|
||||
self.sigma = gaussian_sigma or params["gaussian_sigma"]
|
||||
self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
|
||||
self.model = FaceNet()
|
||||
self.model.load_state_dict(torch.load(face_model_path))
|
||||
self.model.eval()
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, face_img):
|
||||
device = next(iter(self.model.parameters())).device
|
||||
H, W, C = face_img.shape
|
||||
|
||||
w_size = 384
|
||||
x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
|
||||
|
||||
x_data = x_data.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
hs = self.model(x_data[None, ...])
|
||||
heatmaps = F.interpolate(hs[-1], (H, W), mode="bilinear", align_corners=True).cpu().numpy()[0]
|
||||
return heatmaps
|
||||
|
||||
def compute_peaks_from_heatmaps(self, heatmaps):
|
||||
all_peaks = []
|
||||
for part in range(heatmaps.shape[0]):
|
||||
map_ori = heatmaps[part].copy()
|
||||
binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
|
||||
|
||||
if np.sum(binary) == 0:
|
||||
continue
|
||||
|
||||
positions = np.where(binary > 0.5)
|
||||
intensities = map_ori[positions]
|
||||
mi = np.argmax(intensities)
|
||||
y, x = positions[0][mi], positions[1][mi]
|
||||
all_peaks.append([x, y])
|
||||
|
||||
return np.array(all_peaks)
|
||||
91
invokeai/backend/bria/controlnet_aux/open_pose/hand.py
Normal file
91
invokeai/backend/bria/controlnet_aux/open_pose/hand.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
from skimage.measure import label
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import util
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model
|
||||
|
||||
|
||||
class Hand(object):
|
||||
def __init__(self, model_path):
|
||||
self.model = handpose_model()
|
||||
model_dict = util.transfer(self.model, torch.load(model_path))
|
||||
self.model.load_state_dict(model_dict)
|
||||
self.model.eval()
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, oriImgRaw):
|
||||
device = next(iter(self.model.parameters())).device
|
||||
scale_search = [0.5, 1.0, 1.5, 2.0]
|
||||
# scale_search = [0.5]
|
||||
boxsize = 368
|
||||
stride = 8
|
||||
padValue = 128
|
||||
thre = 0.05
|
||||
multiplier = [x * boxsize for x in scale_search]
|
||||
|
||||
wsize = 128
|
||||
heatmap_avg = np.zeros((wsize, wsize, 22))
|
||||
|
||||
Hr, Wr, Cr = oriImgRaw.shape
|
||||
|
||||
oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
|
||||
|
||||
for m in range(len(multiplier)):
|
||||
scale = multiplier[m]
|
||||
imageToTest = util.smart_resize(oriImg, (scale, scale))
|
||||
|
||||
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
|
||||
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
|
||||
im = np.ascontiguousarray(im)
|
||||
|
||||
data = torch.from_numpy(im).float()
|
||||
data = data.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.model(data).cpu().numpy()
|
||||
|
||||
# extract outputs, resize, and remove padding
|
||||
heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
|
||||
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
|
||||
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
|
||||
heatmap = util.smart_resize(heatmap, (wsize, wsize))
|
||||
|
||||
heatmap_avg += heatmap / len(multiplier)
|
||||
|
||||
all_peaks = []
|
||||
for part in range(21):
|
||||
map_ori = heatmap_avg[:, :, part]
|
||||
one_heatmap = gaussian_filter(map_ori, sigma=3)
|
||||
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
|
||||
|
||||
if np.sum(binary) == 0:
|
||||
all_peaks.append([0, 0])
|
||||
continue
|
||||
label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
|
||||
max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
|
||||
label_img[label_img != max_index] = 0
|
||||
map_ori[label_img == 0] = 0
|
||||
|
||||
y, x = util.npmax(map_ori)
|
||||
y = int(float(y) * float(Hr) / float(wsize))
|
||||
x = int(float(x) * float(Wr) / float(wsize))
|
||||
all_peaks.append([x, y])
|
||||
return np.array(all_peaks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
hand_estimation = Hand("../model/hand_pose_model.pth")
|
||||
|
||||
# test_image = '../images/hand.jpg'
|
||||
test_image = "../images/hand.jpg"
|
||||
oriImg = cv2.imread(test_image) # B,G,R order
|
||||
peaks = hand_estimation(oriImg)
|
||||
canvas = util.draw_handpose(oriImg, peaks, True)
|
||||
cv2.imshow("", canvas)
|
||||
cv2.waitKey(0)
|
||||
240
invokeai/backend/bria/controlnet_aux/open_pose/model.py
Normal file
240
invokeai/backend/bria/controlnet_aux/open_pose/model.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_layers(block, no_relu_layers):
|
||||
layers = []
|
||||
for layer_name, v in block.items():
|
||||
if "pool" in layer_name:
|
||||
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
|
||||
layers.append((layer_name, layer))
|
||||
else:
|
||||
conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
|
||||
layers.append((layer_name, conv2d))
|
||||
if layer_name not in no_relu_layers:
|
||||
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
|
||||
|
||||
return nn.Sequential(OrderedDict(layers))
|
||||
|
||||
|
||||
class bodypose_model(nn.Module):
|
||||
def __init__(self):
|
||||
super(bodypose_model, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv5_5_CPM_L1",
|
||||
"conv5_5_CPM_L2",
|
||||
"Mconv7_stage2_L1",
|
||||
"Mconv7_stage2_L2",
|
||||
"Mconv7_stage3_L1",
|
||||
"Mconv7_stage3_L2",
|
||||
"Mconv7_stage4_L1",
|
||||
"Mconv7_stage4_L2",
|
||||
"Mconv7_stage5_L1",
|
||||
"Mconv7_stage5_L2",
|
||||
"Mconv7_stage6_L1",
|
||||
"Mconv7_stage6_L1",
|
||||
]
|
||||
blocks = {}
|
||||
block0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3_CPM", [512, 256, 3, 1, 1]),
|
||||
("conv4_4_CPM", [256, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
# Stage 1
|
||||
block1_1 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_2 = OrderedDict(
|
||||
[
|
||||
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
|
||||
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
|
||||
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
blocks["block1_1"] = block1_1
|
||||
blocks["block1_2"] = block1_2
|
||||
|
||||
self.model0 = make_layers(block0, no_relu_layers)
|
||||
|
||||
# Stages 2 - 6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d_1" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
blocks["block%d_2" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2_1 = blocks["block2_1"]
|
||||
self.model3_1 = blocks["block3_1"]
|
||||
self.model4_1 = blocks["block4_1"]
|
||||
self.model5_1 = blocks["block5_1"]
|
||||
self.model6_1 = blocks["block6_1"]
|
||||
|
||||
self.model1_2 = blocks["block1_2"]
|
||||
self.model2_2 = blocks["block2_2"]
|
||||
self.model3_2 = blocks["block3_2"]
|
||||
self.model4_2 = blocks["block4_2"]
|
||||
self.model5_2 = blocks["block5_2"]
|
||||
self.model6_2 = blocks["block6_2"]
|
||||
|
||||
def forward(self, x):
|
||||
out1 = self.model0(x)
|
||||
|
||||
out1_1 = self.model1_1(out1)
|
||||
out1_2 = self.model1_2(out1)
|
||||
out2 = torch.cat([out1_1, out1_2, out1], 1)
|
||||
|
||||
out2_1 = self.model2_1(out2)
|
||||
out2_2 = self.model2_2(out2)
|
||||
out3 = torch.cat([out2_1, out2_2, out1], 1)
|
||||
|
||||
out3_1 = self.model3_1(out3)
|
||||
out3_2 = self.model3_2(out3)
|
||||
out4 = torch.cat([out3_1, out3_2, out1], 1)
|
||||
|
||||
out4_1 = self.model4_1(out4)
|
||||
out4_2 = self.model4_2(out4)
|
||||
out5 = torch.cat([out4_1, out4_2, out1], 1)
|
||||
|
||||
out5_1 = self.model5_1(out5)
|
||||
out5_2 = self.model5_2(out5)
|
||||
out6 = torch.cat([out5_1, out5_2, out1], 1)
|
||||
|
||||
out6_1 = self.model6_1(out6)
|
||||
out6_2 = self.model6_2(out6)
|
||||
|
||||
return out6_1, out6_2
|
||||
|
||||
|
||||
class handpose_model(nn.Module):
|
||||
def __init__(self):
|
||||
super(handpose_model, self).__init__()
|
||||
|
||||
# these layers have no relu layer
|
||||
no_relu_layers = [
|
||||
"conv6_2_CPM",
|
||||
"Mconv7_stage2",
|
||||
"Mconv7_stage3",
|
||||
"Mconv7_stage4",
|
||||
"Mconv7_stage5",
|
||||
"Mconv7_stage6",
|
||||
]
|
||||
# stage 1
|
||||
block1_0 = OrderedDict(
|
||||
[
|
||||
("conv1_1", [3, 64, 3, 1, 1]),
|
||||
("conv1_2", [64, 64, 3, 1, 1]),
|
||||
("pool1_stage1", [2, 2, 0]),
|
||||
("conv2_1", [64, 128, 3, 1, 1]),
|
||||
("conv2_2", [128, 128, 3, 1, 1]),
|
||||
("pool2_stage1", [2, 2, 0]),
|
||||
("conv3_1", [128, 256, 3, 1, 1]),
|
||||
("conv3_2", [256, 256, 3, 1, 1]),
|
||||
("conv3_3", [256, 256, 3, 1, 1]),
|
||||
("conv3_4", [256, 256, 3, 1, 1]),
|
||||
("pool3_stage1", [2, 2, 0]),
|
||||
("conv4_1", [256, 512, 3, 1, 1]),
|
||||
("conv4_2", [512, 512, 3, 1, 1]),
|
||||
("conv4_3", [512, 512, 3, 1, 1]),
|
||||
("conv4_4", [512, 512, 3, 1, 1]),
|
||||
("conv5_1", [512, 512, 3, 1, 1]),
|
||||
("conv5_2", [512, 512, 3, 1, 1]),
|
||||
("conv5_3_CPM", [512, 128, 3, 1, 1]),
|
||||
]
|
||||
)
|
||||
|
||||
block1_1 = OrderedDict([("conv6_1_CPM", [128, 512, 1, 1, 0]), ("conv6_2_CPM", [512, 22, 1, 1, 0])])
|
||||
|
||||
blocks = {}
|
||||
blocks["block1_0"] = block1_0
|
||||
blocks["block1_1"] = block1_1
|
||||
|
||||
# stage 2-6
|
||||
for i in range(2, 7):
|
||||
blocks["block%d" % i] = OrderedDict(
|
||||
[
|
||||
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
|
||||
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
|
||||
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
|
||||
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
|
||||
]
|
||||
)
|
||||
|
||||
for k in blocks.keys():
|
||||
blocks[k] = make_layers(blocks[k], no_relu_layers)
|
||||
|
||||
self.model1_0 = blocks["block1_0"]
|
||||
self.model1_1 = blocks["block1_1"]
|
||||
self.model2 = blocks["block2"]
|
||||
self.model3 = blocks["block3"]
|
||||
self.model4 = blocks["block4"]
|
||||
self.model5 = blocks["block5"]
|
||||
self.model6 = blocks["block6"]
|
||||
|
||||
def forward(self, x):
|
||||
out1_0 = self.model1_0(x)
|
||||
out1_1 = self.model1_1(out1_0)
|
||||
concat_stage2 = torch.cat([out1_1, out1_0], 1)
|
||||
out_stage2 = self.model2(concat_stage2)
|
||||
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
|
||||
out_stage3 = self.model3(concat_stage3)
|
||||
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
|
||||
out_stage4 = self.model4(concat_stage4)
|
||||
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
|
||||
out_stage5 = self.model5(concat_stage5)
|
||||
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
|
||||
out_stage6 = self.model6(concat_stage6)
|
||||
return out_stage6
|
||||
436
invokeai/backend/bria/controlnet_aux/open_pose/util.py
Normal file
436
invokeai/backend/bria/controlnet_aux/open_pose/util.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import math
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint
|
||||
|
||||
eps = 0.01
|
||||
|
||||
|
||||
def smart_resize(x, s):
|
||||
Ht, Wt = s
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def smart_resize_k(x, fx, fy):
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
Ht, Wt = Ho * fy, Wo * fx
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def padRightDownCorner(img, stride, padValue):
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
pad = 4 * [None]
|
||||
pad[0] = 0 # up
|
||||
pad[1] = 0 # left
|
||||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
||||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
||||
|
||||
img_padded = img
|
||||
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
|
||||
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
||||
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
|
||||
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
||||
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
|
||||
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
||||
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
|
||||
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
||||
|
||||
return img_padded, pad
|
||||
|
||||
|
||||
def transfer(model, model_weights):
|
||||
transfered_model_weights = {}
|
||||
for weights_name in model.state_dict().keys():
|
||||
transfered_model_weights[weights_name] = model_weights[".".join(weights_name.split(".")[1:])]
|
||||
return transfered_model_weights
|
||||
|
||||
|
||||
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
|
||||
"""
|
||||
Draw keypoints and limbs representing body pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
|
||||
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
H, W, C = canvas.shape
|
||||
stickwidth = 4
|
||||
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
]
|
||||
|
||||
colors = [
|
||||
[255, 0, 0],
|
||||
[255, 85, 0],
|
||||
[255, 170, 0],
|
||||
[255, 255, 0],
|
||||
[170, 255, 0],
|
||||
[85, 255, 0],
|
||||
[0, 255, 0],
|
||||
[0, 255, 85],
|
||||
[0, 255, 170],
|
||||
[0, 255, 255],
|
||||
[0, 170, 255],
|
||||
[0, 85, 255],
|
||||
[0, 0, 255],
|
||||
[85, 0, 255],
|
||||
[170, 0, 255],
|
||||
[255, 0, 255],
|
||||
[255, 0, 170],
|
||||
[255, 0, 85],
|
||||
]
|
||||
|
||||
for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False):
|
||||
keypoint1 = keypoints[k1_index - 1]
|
||||
keypoint2 = keypoints[k2_index - 1]
|
||||
|
||||
if keypoint1 is None or keypoint2 is None:
|
||||
continue
|
||||
|
||||
Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
|
||||
X = np.array([keypoint1.y, keypoint2.y]) * float(H)
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
||||
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
|
||||
|
||||
for keypoint, color in zip(keypoints, colors, strict=False):
|
||||
if keypoint is None:
|
||||
continue
|
||||
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
|
||||
import matplotlib
|
||||
|
||||
"""
|
||||
Draw keypoints and connections representing hand pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
|
||||
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
|
||||
or None if no keypoints are present.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
if not keypoints:
|
||||
return canvas
|
||||
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[0, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[0, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[11, 12],
|
||||
[0, 13],
|
||||
[13, 14],
|
||||
[14, 15],
|
||||
[15, 16],
|
||||
[0, 17],
|
||||
[17, 18],
|
||||
[18, 19],
|
||||
[19, 20],
|
||||
]
|
||||
|
||||
for ie, (e1, e2) in enumerate(edges):
|
||||
k1 = keypoints[e1]
|
||||
k2 = keypoints[e2]
|
||||
if k1 is None or k2 is None:
|
||||
continue
|
||||
|
||||
x1 = int(k1.x * W)
|
||||
y1 = int(k1.y * H)
|
||||
x2 = int(k2.x * W)
|
||||
y2 = int(k2.y * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
for keypoint in keypoints:
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
|
||||
"""
|
||||
Draw keypoints representing face pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
|
||||
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
|
||||
or None if no keypoints are present.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
if not keypoints:
|
||||
return canvas
|
||||
|
||||
H, W, C = canvas.shape
|
||||
for keypoint in keypoints:
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
# detect hand according to body pose keypoints
|
||||
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
||||
def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
|
||||
"""
|
||||
Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
|
||||
|
||||
Args:
|
||||
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
|
||||
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
|
||||
corner of the bounding box, the width (height) of the bounding box, and
|
||||
a boolean flag indicating whether the hand is a left hand (True) or a
|
||||
right hand (False).
|
||||
|
||||
Notes:
|
||||
- The width and height of the bounding boxes are equal since the network requires squared input.
|
||||
- The minimum bounding box size is 20 pixels.
|
||||
"""
|
||||
ratioWristElbow = 0.33
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
|
||||
keypoints = body.keypoints
|
||||
# right hand: wrist 4, elbow 3, shoulder 2
|
||||
# left hand: wrist 7, elbow 6, shoulder 5
|
||||
left_shoulder = keypoints[5]
|
||||
left_elbow = keypoints[6]
|
||||
left_wrist = keypoints[7]
|
||||
right_shoulder = keypoints[2]
|
||||
right_elbow = keypoints[3]
|
||||
right_wrist = keypoints[4]
|
||||
|
||||
# if any of three not detected
|
||||
has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
|
||||
has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
|
||||
if not (has_left or has_right):
|
||||
return []
|
||||
|
||||
hands = []
|
||||
# left hand
|
||||
if has_left:
|
||||
hands.append([left_shoulder.x, left_shoulder.y, left_elbow.x, left_elbow.y, left_wrist.x, left_wrist.y, True])
|
||||
# right hand
|
||||
if has_right:
|
||||
hands.append(
|
||||
[right_shoulder.x, right_shoulder.y, right_elbow.x, right_elbow.y, right_wrist.x, right_wrist.y, False]
|
||||
)
|
||||
|
||||
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
||||
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
||||
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
||||
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
||||
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
||||
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
||||
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
||||
x = x3 + ratioWristElbow * (x3 - x2)
|
||||
y = y3 + ratioWristElbow * (y3 - y2)
|
||||
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
||||
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
||||
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
||||
# x-y refers to the center --> offset to topLeft point
|
||||
# handRectangle.x -= handRectangle.width / 2.f;
|
||||
# handRectangle.y -= handRectangle.height / 2.f;
|
||||
x -= width / 2
|
||||
y -= width / 2 # width = height
|
||||
# overflow the image
|
||||
if x < 0:
|
||||
x = 0
|
||||
if y < 0:
|
||||
y = 0
|
||||
width1 = width
|
||||
width2 = width
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
width = min(width1, width2)
|
||||
# the max hand box value is 20 pixels
|
||||
if width >= 20:
|
||||
detect_result.append((int(x), int(y), int(width), is_left))
|
||||
|
||||
"""
|
||||
return value: [[x, y, w, True if left hand else False]].
|
||||
width=height since the network require squared input.
|
||||
x, y is the coordinate of top left.
|
||||
"""
|
||||
return detect_result
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
|
||||
"""
|
||||
Detect the face in the input body pose keypoints and calculate the bounding box for the face.
|
||||
|
||||
Args:
|
||||
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
|
||||
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
|
||||
bounding box and the width (height) of the bounding box, or None if the
|
||||
face is not detected or the bounding box width is less than 20 pixels.
|
||||
|
||||
Notes:
|
||||
- The width and height of the bounding box are equal.
|
||||
- The minimum bounding box size is 20 pixels.
|
||||
"""
|
||||
# left right eye ear 14 15 16 17
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
|
||||
keypoints = body.keypoints
|
||||
head = keypoints[0]
|
||||
left_eye = keypoints[14]
|
||||
right_eye = keypoints[15]
|
||||
left_ear = keypoints[16]
|
||||
right_ear = keypoints[17]
|
||||
|
||||
if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
|
||||
return None
|
||||
|
||||
width = 0.0
|
||||
x0, y0 = head.x, head.y
|
||||
|
||||
if left_eye is not None:
|
||||
x1, y1 = left_eye.x, left_eye.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if right_eye is not None:
|
||||
x1, y1 = right_eye.x, right_eye.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if left_ear is not None:
|
||||
x1, y1 = left_ear.x, left_ear.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
if right_ear is not None:
|
||||
x1, y1 = right_ear.x, right_ear.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
x, y = x0, y0
|
||||
|
||||
x -= width
|
||||
y -= width
|
||||
|
||||
if x < 0:
|
||||
x = 0
|
||||
|
||||
if y < 0:
|
||||
y = 0
|
||||
|
||||
width1 = width * 2
|
||||
width2 = width * 2
|
||||
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
|
||||
width = min(width1, width2)
|
||||
|
||||
if width >= 20:
|
||||
return int(x), int(y), int(width)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# get max index of 2d array
|
||||
def npmax(array):
|
||||
arrayindex = array.argmax(1)
|
||||
arrayvalue = array.max(1)
|
||||
i = arrayvalue.argmax()
|
||||
j = arrayindex[i]
|
||||
return i, j
|
||||
260
invokeai/backend/bria/controlnet_aux/util.py
Normal file
260
invokeai/backend/bria/controlnet_aux/util.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
x = x[:, :, None]
|
||||
assert x.ndim == 3
|
||||
H, W, C = x.shape
|
||||
assert C == 1 or C == 3 or C == 4
|
||||
if C == 3:
|
||||
return x
|
||||
if C == 1:
|
||||
return np.concatenate([x, x, x], axis=2)
|
||||
if C == 4:
|
||||
color = x[:, :, 0:3].astype(np.float32)
|
||||
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
||||
y = color * alpha + 255.0 * (1.0 - alpha)
|
||||
y = y.clip(0, 255).astype(np.uint8)
|
||||
return y
|
||||
|
||||
|
||||
def make_noise_disk(H, W, C, F):
|
||||
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
|
||||
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
|
||||
noise = noise[F : F + H, F : F + W]
|
||||
noise -= np.min(noise)
|
||||
noise /= np.max(noise)
|
||||
if C == 1:
|
||||
noise = noise[:, :, None]
|
||||
return noise
|
||||
|
||||
|
||||
def nms(x, t, s):
|
||||
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
||||
|
||||
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||
|
||||
y = np.zeros_like(x)
|
||||
|
||||
for f in [f1, f2, f3, f4]:
|
||||
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
||||
|
||||
z = np.zeros_like(y, dtype=np.uint8)
|
||||
z[y > t] = 255
|
||||
return z
|
||||
|
||||
|
||||
def min_max_norm(x):
|
||||
x -= np.min(x)
|
||||
x /= np.maximum(np.max(x), 1e-5)
|
||||
return x
|
||||
|
||||
|
||||
def safe_step(x, step=2):
|
||||
y = x.astype(np.float32) * float(step + 1)
|
||||
y = y.astype(np.int32).astype(np.float32) / float(step)
|
||||
return y
|
||||
|
||||
|
||||
def img2mask(img, H, W, low=10, high=90):
|
||||
assert img.ndim == 3 or img.ndim == 2
|
||||
assert img.dtype == np.uint8
|
||||
|
||||
if img.ndim == 3:
|
||||
y = img[:, :, random.randrange(0, img.shape[2])]
|
||||
else:
|
||||
y = img
|
||||
|
||||
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
if random.uniform(0, 1) < 0.5:
|
||||
y = 255 - y
|
||||
|
||||
return y < np.percentile(y, random.randrange(low, high))
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
||||
return img
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def ade_palette():
|
||||
"""ADE20K palette that maps each class to RGB values."""
|
||||
return [
|
||||
[120, 120, 120],
|
||||
[180, 120, 120],
|
||||
[6, 230, 230],
|
||||
[80, 50, 50],
|
||||
[4, 200, 3],
|
||||
[120, 120, 80],
|
||||
[140, 140, 140],
|
||||
[204, 5, 255],
|
||||
[230, 230, 230],
|
||||
[4, 250, 7],
|
||||
[224, 5, 255],
|
||||
[235, 255, 7],
|
||||
[150, 5, 61],
|
||||
[120, 120, 70],
|
||||
[8, 255, 51],
|
||||
[255, 6, 82],
|
||||
[143, 255, 140],
|
||||
[204, 255, 4],
|
||||
[255, 51, 7],
|
||||
[204, 70, 3],
|
||||
[0, 102, 200],
|
||||
[61, 230, 250],
|
||||
[255, 6, 51],
|
||||
[11, 102, 255],
|
||||
[255, 7, 71],
|
||||
[255, 9, 224],
|
||||
[9, 7, 230],
|
||||
[220, 220, 220],
|
||||
[255, 9, 92],
|
||||
[112, 9, 255],
|
||||
[8, 255, 214],
|
||||
[7, 255, 224],
|
||||
[255, 184, 6],
|
||||
[10, 255, 71],
|
||||
[255, 41, 10],
|
||||
[7, 255, 255],
|
||||
[224, 255, 8],
|
||||
[102, 8, 255],
|
||||
[255, 61, 6],
|
||||
[255, 194, 7],
|
||||
[255, 122, 8],
|
||||
[0, 255, 20],
|
||||
[255, 8, 41],
|
||||
[255, 5, 153],
|
||||
[6, 51, 255],
|
||||
[235, 12, 255],
|
||||
[160, 150, 20],
|
||||
[0, 163, 255],
|
||||
[140, 140, 140],
|
||||
[250, 10, 15],
|
||||
[20, 255, 0],
|
||||
[31, 255, 0],
|
||||
[255, 31, 0],
|
||||
[255, 224, 0],
|
||||
[153, 255, 0],
|
||||
[0, 0, 255],
|
||||
[255, 71, 0],
|
||||
[0, 235, 255],
|
||||
[0, 173, 255],
|
||||
[31, 0, 255],
|
||||
[11, 200, 200],
|
||||
[255, 82, 0],
|
||||
[0, 255, 245],
|
||||
[0, 61, 255],
|
||||
[0, 255, 112],
|
||||
[0, 255, 133],
|
||||
[255, 0, 0],
|
||||
[255, 163, 0],
|
||||
[255, 102, 0],
|
||||
[194, 255, 0],
|
||||
[0, 143, 255],
|
||||
[51, 255, 0],
|
||||
[0, 82, 255],
|
||||
[0, 255, 41],
|
||||
[0, 255, 173],
|
||||
[10, 0, 255],
|
||||
[173, 255, 0],
|
||||
[0, 255, 153],
|
||||
[255, 92, 0],
|
||||
[255, 0, 255],
|
||||
[255, 0, 245],
|
||||
[255, 0, 102],
|
||||
[255, 173, 0],
|
||||
[255, 0, 20],
|
||||
[255, 184, 184],
|
||||
[0, 31, 255],
|
||||
[0, 255, 61],
|
||||
[0, 71, 255],
|
||||
[255, 0, 204],
|
||||
[0, 255, 194],
|
||||
[0, 255, 82],
|
||||
[0, 10, 255],
|
||||
[0, 112, 255],
|
||||
[51, 0, 255],
|
||||
[0, 194, 255],
|
||||
[0, 122, 255],
|
||||
[0, 255, 163],
|
||||
[255, 153, 0],
|
||||
[0, 255, 10],
|
||||
[255, 112, 0],
|
||||
[143, 255, 0],
|
||||
[82, 0, 255],
|
||||
[163, 255, 0],
|
||||
[255, 235, 0],
|
||||
[8, 184, 170],
|
||||
[133, 0, 255],
|
||||
[0, 255, 92],
|
||||
[184, 0, 255],
|
||||
[255, 0, 31],
|
||||
[0, 184, 255],
|
||||
[0, 214, 255],
|
||||
[255, 0, 112],
|
||||
[92, 255, 0],
|
||||
[0, 224, 255],
|
||||
[112, 224, 255],
|
||||
[70, 184, 160],
|
||||
[163, 0, 255],
|
||||
[153, 0, 255],
|
||||
[71, 255, 0],
|
||||
[255, 0, 163],
|
||||
[255, 204, 0],
|
||||
[255, 0, 143],
|
||||
[0, 255, 235],
|
||||
[133, 255, 0],
|
||||
[255, 0, 235],
|
||||
[245, 0, 255],
|
||||
[255, 0, 122],
|
||||
[255, 245, 0],
|
||||
[10, 190, 212],
|
||||
[214, 255, 0],
|
||||
[0, 204, 255],
|
||||
[20, 0, 255],
|
||||
[255, 255, 0],
|
||||
[0, 153, 255],
|
||||
[0, 41, 255],
|
||||
[0, 255, 204],
|
||||
[41, 0, 255],
|
||||
[41, 255, 0],
|
||||
[173, 0, 255],
|
||||
[0, 245, 255],
|
||||
[71, 0, 255],
|
||||
[122, 0, 255],
|
||||
[0, 255, 184],
|
||||
[0, 92, 255],
|
||||
[184, 255, 0],
|
||||
[0, 133, 255],
|
||||
[255, 214, 0],
|
||||
[25, 194, 194],
|
||||
[102, 255, 0],
|
||||
[92, 0, 255],
|
||||
]
|
||||
559
invokeai/backend/bria/controlnet_bria.py
Normal file
559
invokeai/backend/bria/controlnet_bria.py
Normal file
@@ -0,0 +1,559 @@
|
||||
# type: ignore
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from diffusers.models.controlnet import zero_module
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
|
||||
from invokeai.backend.bria.transformer_bria import (
|
||||
EmbedND,
|
||||
FluxSingleTransformerBlock,
|
||||
FluxTransformerBlock,
|
||||
TimestepProjEmbeddings,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"]
|
||||
|
||||
|
||||
class BriaControlModes(Enum):
|
||||
depth = 0
|
||||
canny = 1
|
||||
colorgrid = 2
|
||||
recolor = 3
|
||||
tile = 4
|
||||
pose = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class BriaControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Optional[List[int]] = None,
|
||||
num_mode: int = None,
|
||||
rope_theta: int = 10000,
|
||||
time_theta: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
# text_time_guidance_cls = (
|
||||
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
# )
|
||||
# self.time_text_embed = text_time_guidance_cls(
|
||||
# embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
# )
|
||||
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.union = num_mode is not None and num_mode > 0
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers: int = 4,
|
||||
num_single_layers: int = 10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
controlnet_mode: torch.Tensor = None,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_mode (`torch.Tensor`):
|
||||
The mode tensor of shape `(batch_size, 1)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if guidance is not None:
|
||||
print("guidance is not supported in BriaControlNetModel")
|
||||
if pooled_projections is not None:
|
||||
print("pooled_projections is not supported in BriaControlNetModel")
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# Convert controlnet_cond to the same dtype as the model weights
|
||||
controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
|
||||
# Validate controlnet_mode values are within the valid range
|
||||
if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
|
||||
raise ValueError(
|
||||
f"`controlnet_mode` values must be in range [0, {self.num_mode - 1}], but got values outside this range"
|
||||
)
|
||||
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch
|
||||
controlnet_mode_emb = controlnet_mode_emb.expand(
|
||||
encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]
|
||||
)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
|
||||
txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for _, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for _, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(
|
||||
single_block_samples, self.controlnet_single_blocks, strict=False
|
||||
):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return BriaControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
|
||||
|
||||
class BriaMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
|
||||
This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `BriaControlNetModel`.
|
||||
Args:
|
||||
controlnets (`List[BriaControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`BriaControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[BriaControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(
|
||||
control_block_samples, block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(
|
||||
control_block_samples, block_samples, strict=False
|
||||
)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
68
invokeai/backend/bria/controlnet_utils.py
Normal file
68
invokeai/backend/bria/controlnet_utils.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_control_images(
|
||||
vae: AutoencoderKL,
|
||||
control_images: list[Image.Image],
|
||||
control_modes: list[int],
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
tensored_control_images = []
|
||||
tensored_control_modes = []
|
||||
for idx, control_image_ in enumerate(control_images):
|
||||
tensored_control_image = _prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
dtype=vae.dtype,
|
||||
)
|
||||
height, width = tensored_control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
|
||||
tensored_control_image = (tensored_control_image) * vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = tensored_control_image.shape[2:]
|
||||
tensored_control_image = _pack_latents(
|
||||
tensored_control_image,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
tensored_control_images.append(tensored_control_image)
|
||||
tensored_control_modes.append(
|
||||
torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long)
|
||||
)
|
||||
|
||||
return tensored_control_images, tensored_control_modes
|
||||
|
||||
|
||||
def _prepare_image(
|
||||
image: Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
image = image.convert("RGB")
|
||||
image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
|
||||
image = image.repeat_interleave(1, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
return image
|
||||
|
||||
|
||||
def _pack_latents(latents, height, width):
|
||||
latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(1, (height // 2) * (width // 2), 16)
|
||||
|
||||
return latents
|
||||
636
invokeai/backend/bria/pipeline_bria.py
Normal file
636
invokeai/backend/bria/pipeline_bria.py
Normal file
@@ -0,0 +1,636 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("sd3.png")
|
||||
```
|
||||
"""
|
||||
|
||||
T5_PRECISION = torch.float16
|
||||
|
||||
"""
|
||||
Based on FluxPipeline with several changes:
|
||||
- no pooled embeddings
|
||||
- We use zero padding for prompts
|
||||
- No guidance embedding since this is not a distilled version
|
||||
"""
|
||||
|
||||
|
||||
class BriaPipeline(FluxPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: BriaTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
):
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# TODO - why different than offical flux (-1)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
|
||||
|
||||
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
|
||||
|
||||
if self.vae.config.shift_factor is None:
|
||||
self.vae.config.shift_factor = 0
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=self.transformer.dtype)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if not is_ng_none(negative_prompt):
|
||||
negative_prompt = (
|
||||
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=self.transformer.dtype)
|
||||
else:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, text_ids
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
max_sequence_length: int = 128,
|
||||
clip_value: Union[None, float] = None,
|
||||
normalize: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
callback_on_step_end_tensor_inputs = (
|
||||
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
|
||||
)
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
|
||||
(prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
|
||||
and self.scheduler.config["use_dynamic_shifting"]
|
||||
):
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = latents.shape[1] # Shift by height - Why just height?
|
||||
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
else:
|
||||
# 4. Prepare timesteps
|
||||
# Sample from training sigmas
|
||||
if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, None, None
|
||||
)
|
||||
else:
|
||||
sigmas = get_original_sigmas(
|
||||
num_train_timesteps=self.scheduler.config.num_train_timesteps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Supprot different diffusers versions
|
||||
if diffusers.__version__ >= "0.32.0":
|
||||
latent_image_ids = latent_image_ids[0]
|
||||
text_ids = text_ids[0]
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# This is predicts "v" from flow-matching or eps from diffusion
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
cfg_noise_pred_text = noise_pred_text.std()
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if normalize:
|
||||
noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
|
||||
|
||||
if clip_value:
|
||||
assert clip_value > 0
|
||||
noise_pred = noise_pred.clip(-clip_value, clip_value)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
DiffusionPipeline.to(self, *args, **kwargs)
|
||||
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
|
||||
self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
|
||||
for block in self.text_encoder.encoder.block:
|
||||
block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
|
||||
|
||||
if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32:
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
return self
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
671
invokeai/backend/bria/pipeline_bria_controlnet.py
Normal file
671
invokeai/backend/bria/pipeline_bria_controlnet.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKL # Waiting for diffusers udpdate
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND, logging
|
||||
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
|
||||
from invokeai.backend.bria.controlnet_bria import BriaControlNetModel
|
||||
from invokeai.backend.bria.pipeline_bria import BriaPipeline
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BriaControlNetPipeline(BriaPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__( # EYAL - removed clip text encoder + tokenizer
|
||||
self,
|
||||
transformer: BriaTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
controlnet: BriaControlNetModel,
|
||||
):
|
||||
super().__init__(
|
||||
transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
|
||||
)
|
||||
self.register_modules(controlnet=controlnet)
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if control_mode is not None:
|
||||
if not isinstance(control_mode, int):
|
||||
raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
|
||||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
|
||||
|
||||
return control_image, control_mode
|
||||
|
||||
def prepare_multi_control(
|
||||
self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode
|
||||
):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
control_images = []
|
||||
for _, control_image_ in enumerate(control_image):
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
control_images.append(control_image_)
|
||||
|
||||
control_image = control_images
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
|
||||
raise ValueError(
|
||||
"For Multi-ControlNet, `control_mode` must be a list of the same "
|
||||
+ " length as the number of controlnets (control images) specified"
|
||||
)
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode] * len(control_image)
|
||||
# set control mode
|
||||
control_modes = []
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
cmode = -1
|
||||
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
||||
control_modes.append(control_mode)
|
||||
control_mode = control_modes
|
||||
|
||||
return control_image, control_mode
|
||||
|
||||
def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
|
||||
return controlnet_keep
|
||||
|
||||
def get_control_start_end(self, control_guidance_start, control_guidance_end):
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = 1 # TODO - why is this 1?
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
return control_guidance_start, control_guidance_end
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3.5,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_image: Optional[PipelineImageInput] = None,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
latent_image_ids: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
text_ids: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
max_sequence_length: int = 128,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
control_guidance_start, control_guidance_end = self.get_control_start_end(
|
||||
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
callback_on_step_end_tensor_inputs = (
|
||||
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
|
||||
)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if (
|
||||
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
|
||||
and self.scheduler.config["use_dynamic_shifting"]
|
||||
):
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
|
||||
# Determine image sequence length
|
||||
if control_image is not None:
|
||||
if isinstance(control_image, list):
|
||||
image_seq_len = control_image[0].shape[1]
|
||||
else:
|
||||
image_seq_len = control_image.shape[1]
|
||||
else:
|
||||
# Use latents sequence length when no control image is provided
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps=None,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
else:
|
||||
# 5. Prepare timesteps
|
||||
sigmas = get_original_sigmas(
|
||||
num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Create tensor stating which controlnets to keep
|
||||
if control_image is not None:
|
||||
controlnet_keep = self.get_controlnet_keep(
|
||||
timesteps=timesteps,
|
||||
control_guidance_start=control_guidance_start,
|
||||
control_guidance_end=control_guidance_end,
|
||||
)
|
||||
|
||||
if diffusers.__version__ >= "0.32.0":
|
||||
latent_image_ids = latent_image_ids[0]
|
||||
text_ids = text_ids[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# EYAL - added the CFG loop
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
|
||||
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# Handling ControlNet
|
||||
if control_image is not None:
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
if isinstance(controlnet_conditioning_scale, list):
|
||||
cond_scale = controlnet_conditioning_scale
|
||||
else:
|
||||
cond_scale = [
|
||||
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)
|
||||
]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
controlnet_mode=control_mode,
|
||||
conditioning_scale=cond_scale,
|
||||
timestep=timestep,
|
||||
# guidance=guidance,
|
||||
# pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
controlnet_block_samples, controlnet_single_block_samples = None, None
|
||||
|
||||
# This is predicts "v" from flow-matching
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
prompt: Union[str, List[str]],
|
||||
tokenizer: T5TokenizerFast,
|
||||
text_encoder: T5EncoderModel,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or torch.device("cuda")
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
# dynamically adjust the LoRA scale
|
||||
if text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
dtype = text_encoder.dtype if text_encoder is not None else torch.float32
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=dtype)
|
||||
|
||||
if negative_prompt_embeds is None:
|
||||
if not is_ng_none(negative_prompt):
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=dtype)
|
||||
else:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
|
||||
if text_encoder is not None:
|
||||
if USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(text_encoder, lora_scale)
|
||||
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, text_ids
|
||||
|
||||
|
||||
def prepare_latents(
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
vae_scale_factor = 16
|
||||
height = 2 * (int(height) // vae_scale_factor)
|
||||
width = 2 * (int(width) // vae_scale_factor)
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
322
invokeai/backend/bria/transformer_bria.py
Normal file
322
invokeai/backend/bria/transformer_bria.py
Normal file
@@ -0,0 +1,322 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import AdaLayerNormContinuous
|
||||
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
|
||||
from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
self.time_theta = time_theta
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
max_period=self.time_theta,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
"""
|
||||
Based on FluxPipeline with several changes:
|
||||
- no pooled embeddings
|
||||
- We use zero padding for prompts
|
||||
- No guidance embedding since this is not a distilled version
|
||||
"""
|
||||
|
||||
|
||||
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Parameters:
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
||||
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = None,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Optional[List[int]] = None,
|
||||
rope_theta=10000,
|
||||
time_theta=10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
# if pooled_projection_dim:
|
||||
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# temb = (
|
||||
# self.time_text_embed(timestep, pooled_projections)
|
||||
# if guidance is None
|
||||
# else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
# )
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
# if pooled_projections:
|
||||
# temb+=self.pooled_text_embed(pooled_projections)
|
||||
|
||||
if guidance:
|
||||
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if len(txt_ids.shape) == 2:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
else:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -112,7 +112,7 @@ def denoise(
|
||||
)
|
||||
|
||||
# Slice prediction to only include the main image tokens
|
||||
if img_cond_seq is not None:
|
||||
if img_input_ids is not None:
|
||||
pred = pred[:, :original_seq_len]
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
@@ -125,26 +125,9 @@ def denoise(
|
||||
if neg_regional_prompting_extension is None:
|
||||
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
|
||||
|
||||
# For negative prediction with Kontext, we need to include the reference images
|
||||
# to maintain consistency between positive and negative passes. Without this,
|
||||
# CFG would create artifacts as the attention mechanism would see different
|
||||
# spatial structures in each pass
|
||||
neg_img_input = img
|
||||
neg_img_input_ids = img_ids
|
||||
|
||||
# Add channel-wise conditioning for negative pass if present
|
||||
if img_cond is not None:
|
||||
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
|
||||
|
||||
# Add sequence-wise conditioning (Kontext) for negative pass
|
||||
# This ensures reference images are processed consistently
|
||||
if img_cond_seq is not None:
|
||||
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
|
||||
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
|
||||
|
||||
neg_pred = model(
|
||||
img=neg_img_input,
|
||||
img_ids=neg_img_input_ids,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
@@ -157,10 +140,6 @@ def denoise(
|
||||
ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
regional_prompting_extension=neg_regional_prompting_extension,
|
||||
)
|
||||
|
||||
# Slice negative prediction to match main image tokens
|
||||
if img_cond_seq is not None:
|
||||
neg_pred = neg_pred[:, :original_seq_len]
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
preview_img = img - t_curr * pred
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
|
||||
|
||||
def generate_img_ids_with_offset(
|
||||
@@ -18,10 +19,8 @@ def generate_img_ids_with_offset(
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
idx_offset: int = 0,
|
||||
h_offset: int = 0,
|
||||
w_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids with optional index and spatial offsets.
|
||||
"""Generate tensor of image position ids with an optional offset.
|
||||
|
||||
Args:
|
||||
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
|
||||
@@ -29,9 +28,7 @@ def generate_img_ids_with_offset(
|
||||
batch_size (int): Number of images in the batch.
|
||||
device (torch.device): Device to create tensors on.
|
||||
dtype (torch.dtype): Data type for the tensors.
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
|
||||
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
|
||||
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
|
||||
@@ -45,10 +42,6 @@ def generate_img_ids_with_offset(
|
||||
packed_height = latent_height // 2
|
||||
packed_width = latent_width // 2
|
||||
|
||||
# Convert spatial offsets from latent space to packed space
|
||||
packed_h_offset = h_offset // 2
|
||||
packed_w_offset = w_offset // 2
|
||||
|
||||
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
|
||||
# The 3 channels represent: [batch_offset, y_position, x_position]
|
||||
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
|
||||
@@ -56,13 +49,13 @@ def generate_img_ids_with_offset(
|
||||
# Set the batch offset for all positions
|
||||
img_ids[..., 0] = idx_offset
|
||||
|
||||
# Create y-coordinate indices (vertical positions) with spatial offset
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
|
||||
# Create y-coordinate indices (vertical positions)
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
|
||||
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
|
||||
img_ids[..., 1] = y_indices[:, None]
|
||||
|
||||
# Create x-coordinate indices (horizontal positions) with spatial offset
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
|
||||
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
|
||||
img_ids[..., 2] = x_indices[None, :]
|
||||
|
||||
@@ -80,14 +73,14 @@ class KontextExtension:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kontext_conditioning: list[FluxKontextConditioningField],
|
||||
kontext_conditioning: FluxKontextConditioningField,
|
||||
context: InvocationContext,
|
||||
vae_field: VAEField,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Initializes the KontextExtension, pre-processing the reference images
|
||||
Initializes the KontextExtension, pre-processing the reference image
|
||||
into latents and positional IDs.
|
||||
"""
|
||||
self._context = context
|
||||
@@ -100,116 +93,54 @@ class KontextExtension:
|
||||
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
|
||||
|
||||
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
|
||||
all_latents = []
|
||||
all_ids = []
|
||||
"""Encodes the reference image and prepares its latents and IDs."""
|
||||
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
|
||||
|
||||
# Track cumulative dimensions for spatial tiling
|
||||
# These track the running extent of the virtual canvas in latent space
|
||||
canvas_h = 0 # Running canvas height
|
||||
canvas_w = 0 # Running canvas width
|
||||
# Calculate aspect ratio of input image
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Find the closest preferred resolution by aspect ratio
|
||||
_, target_width, target_height = min(
|
||||
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
# Apply BFL's scaling formula
|
||||
# This ensures compatibility with the model's training
|
||||
scaled_width = 2 * int(target_width / 16)
|
||||
scaled_height = 2 * int(target_height / 16)
|
||||
|
||||
# Resize to the exact resolution used during training
|
||||
image = image.convert("RGB")
|
||||
final_width = 8 * scaled_width
|
||||
final_height = 8 * scaled_height
|
||||
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert to tensor with same normalization as BFL
|
||||
image_np = np.array(image)
|
||||
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
|
||||
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
|
||||
image_tensor = image_tensor.to(self._device)
|
||||
|
||||
# Continue with VAE encoding
|
||||
vae_info = self._context.models.load(self._vae_field.vae)
|
||||
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
for idx, kontext_field in enumerate(self.kontext_conditioning):
|
||||
image = self._context.images.get_pil(kontext_field.image.image_name)
|
||||
# Extract tensor dimensions
|
||||
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
|
||||
|
||||
# Convert to RGB
|
||||
image = image.convert("RGB")
|
||||
# Pack the latents and generate IDs
|
||||
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
||||
kontext_ids = generate_img_ids_with_offset(
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
batch_size=batch_size,
|
||||
device=self._device,
|
||||
dtype=self._dtype,
|
||||
idx_offset=1,
|
||||
)
|
||||
|
||||
# Convert to tensor using torchvision transforms for consistency
|
||||
transformation = T.Compose(
|
||||
[
|
||||
T.ToTensor(), # Converts PIL image to tensor and scales to [0, 1]
|
||||
]
|
||||
)
|
||||
image_tensor = transformation(image)
|
||||
# Convert from [0, 1] to [-1, 1] range expected by VAE
|
||||
image_tensor = image_tensor * 2.0 - 1.0
|
||||
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
|
||||
image_tensor = image_tensor.to(self._device)
|
||||
|
||||
# Continue with VAE encoding
|
||||
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
|
||||
# Estimate working memory for encode operation (50% of decode memory requirements)
|
||||
img_h = image_tensor.shape[-2]
|
||||
img_w = image_tensor.shape[-1]
|
||||
element_size = next(vae_info.model.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)
|
||||
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
# Use sample=False to get the distribution mean without noise
|
||||
kontext_latents_unpacked = vae.encode(image_tensor, sample=False)
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
# Extract tensor dimensions
|
||||
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
|
||||
|
||||
# Pad latents to be compatible with patch_size=2
|
||||
# This ensures dimensions are even for the pack() function
|
||||
pad_h = (2 - latent_height % 2) % 2
|
||||
pad_w = (2 - latent_width % 2) % 2
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
|
||||
# Update dimensions after padding
|
||||
_, _, latent_height, latent_width = kontext_latents_unpacked.shape
|
||||
|
||||
# Pack the latents
|
||||
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
||||
|
||||
# Determine spatial offsets for this reference image
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
|
||||
if idx > 0: # First image starts at (0, 0)
|
||||
# Calculate potential canvas dimensions for each tiling option
|
||||
# Option 1: Tile vertically (below existing content)
|
||||
potential_h_vertical = canvas_h + latent_height
|
||||
|
||||
# Option 2: Tile horizontally (to the right of existing content)
|
||||
potential_w_horizontal = canvas_w + latent_width
|
||||
|
||||
# Choose arrangement that minimizes the maximum dimension
|
||||
# This keeps the canvas closer to square, optimizing attention computation
|
||||
if potential_h_vertical > potential_w_horizontal:
|
||||
# Tile horizontally (to the right of existing images)
|
||||
w_offset = canvas_w
|
||||
canvas_w = canvas_w + latent_width
|
||||
canvas_h = max(canvas_h, latent_height)
|
||||
else:
|
||||
# Tile vertically (below existing images)
|
||||
h_offset = canvas_h
|
||||
canvas_h = canvas_h + latent_height
|
||||
canvas_w = max(canvas_w, latent_width)
|
||||
else:
|
||||
# First image - just set canvas dimensions
|
||||
canvas_h = latent_height
|
||||
canvas_w = latent_width
|
||||
|
||||
# Generate IDs with both index offset and spatial offsets
|
||||
kontext_ids = generate_img_ids_with_offset(
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
batch_size=batch_size,
|
||||
device=self._device,
|
||||
dtype=self._dtype,
|
||||
idx_offset=1, # All reference images use index=1 (matching ComfyUI implementation)
|
||||
h_offset=h_offset,
|
||||
w_offset=w_offset,
|
||||
)
|
||||
|
||||
all_latents.append(kontext_latents_packed)
|
||||
all_ids.append(kontext_ids)
|
||||
|
||||
# Concatenate all latents and IDs along the sequence dimension
|
||||
concatenated_latents = torch.cat(all_latents, dim=1) # Concatenate along sequence dimension
|
||||
concatenated_ids = torch.cat(all_ids, dim=1) # Concatenate along sequence dimension
|
||||
|
||||
return concatenated_latents, concatenated_ids
|
||||
return kontext_latents_packed, kontext_ids
|
||||
|
||||
def ensure_batch_size(self, target_batch_size: int) -> None:
|
||||
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
# This file is vendored from https://github.com/ShieldMnt/invisible-watermark
|
||||
#
|
||||
# `invisible-watermark` is MIT licensed as of August 23, 2025, when the code was copied into this repo.
|
||||
#
|
||||
# Why we vendored it in:
|
||||
# `invisible-watermark` has a dependency on `opencv-python`, which conflicts with Invoke's dependency on
|
||||
# `opencv-contrib-python`. It's easier to copy the code over than complicate the installation process by
|
||||
# requiring an extra post-install step of removing `opencv-python` and installing `opencv-contrib-python`.
|
||||
|
||||
import struct
|
||||
import uuid
|
||||
import base64
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
|
||||
class WatermarkEncoder(object):
|
||||
def __init__(self, content=b""):
|
||||
seq = np.array([n for n in content], dtype=np.uint8)
|
||||
self._watermarks = list(np.unpackbits(seq))
|
||||
self._wmLen = len(self._watermarks)
|
||||
self._wmType = "bytes"
|
||||
|
||||
def set_by_ipv4(self, addr):
|
||||
bits = []
|
||||
ips = addr.split(".")
|
||||
for ip in ips:
|
||||
bits += list(np.unpackbits(np.array([ip % 255], dtype=np.uint8)))
|
||||
self._watermarks = bits
|
||||
self._wmLen = len(self._watermarks)
|
||||
self._wmType = "ipv4"
|
||||
assert self._wmLen == 32
|
||||
|
||||
def set_by_uuid(self, uid):
|
||||
u = uuid.UUID(uid)
|
||||
self._wmType = "uuid"
|
||||
seq = np.array([n for n in u.bytes], dtype=np.uint8)
|
||||
self._watermarks = list(np.unpackbits(seq))
|
||||
self._wmLen = len(self._watermarks)
|
||||
|
||||
def set_by_bytes(self, content):
|
||||
self._wmType = "bytes"
|
||||
seq = np.array([n for n in content], dtype=np.uint8)
|
||||
self._watermarks = list(np.unpackbits(seq))
|
||||
self._wmLen = len(self._watermarks)
|
||||
|
||||
def set_by_b16(self, b16):
|
||||
content = base64.b16decode(b16)
|
||||
self.set_by_bytes(content)
|
||||
self._wmType = "b16"
|
||||
|
||||
def set_by_bits(self, bits=[]):
|
||||
self._watermarks = [int(bit) % 2 for bit in bits]
|
||||
self._wmLen = len(self._watermarks)
|
||||
self._wmType = "bits"
|
||||
|
||||
def set_watermark(self, wmType="bytes", content=""):
|
||||
if wmType == "ipv4":
|
||||
self.set_by_ipv4(content)
|
||||
elif wmType == "uuid":
|
||||
self.set_by_uuid(content)
|
||||
elif wmType == "bits":
|
||||
self.set_by_bits(content)
|
||||
elif wmType == "bytes":
|
||||
self.set_by_bytes(content)
|
||||
elif wmType == "b16":
|
||||
self.set_by_b16(content)
|
||||
else:
|
||||
raise NameError("%s is not supported" % wmType)
|
||||
|
||||
def get_length(self):
|
||||
return self._wmLen
|
||||
|
||||
# @classmethod
|
||||
# def loadModel(cls):
|
||||
# RivaWatermark.loadModel()
|
||||
|
||||
def encode(self, cv2Image, method="dwtDct", **configs):
|
||||
(r, c, channels) = cv2Image.shape
|
||||
if r * c < 256 * 256:
|
||||
raise RuntimeError("image too small, should be larger than 256x256")
|
||||
|
||||
if method == "dwtDct":
|
||||
embed = EmbedMaxDct(self._watermarks, wmLen=self._wmLen, **configs)
|
||||
return embed.encode(cv2Image)
|
||||
# elif method == 'dwtDctSvd':
|
||||
# embed = EmbedDwtDctSvd(self._watermarks, wmLen=self._wmLen, **configs)
|
||||
# return embed.encode(cv2Image)
|
||||
# elif method == 'rivaGan':
|
||||
# embed = RivaWatermark(self._watermarks, self._wmLen)
|
||||
# return embed.encode(cv2Image)
|
||||
else:
|
||||
raise NameError("%s is not supported" % method)
|
||||
|
||||
|
||||
class WatermarkDecoder(object):
|
||||
def __init__(self, wm_type="bytes", length=0):
|
||||
self._wmType = wm_type
|
||||
if wm_type == "ipv4":
|
||||
self._wmLen = 32
|
||||
elif wm_type == "uuid":
|
||||
self._wmLen = 128
|
||||
elif wm_type == "bytes":
|
||||
self._wmLen = length
|
||||
elif wm_type == "bits":
|
||||
self._wmLen = length
|
||||
elif wm_type == "b16":
|
||||
self._wmLen = length
|
||||
else:
|
||||
raise NameError("%s is unsupported" % wm_type)
|
||||
|
||||
def reconstruct_ipv4(self, bits):
|
||||
ips = [str(ip) for ip in list(np.packbits(bits))]
|
||||
return ".".join(ips)
|
||||
|
||||
def reconstruct_uuid(self, bits):
|
||||
nums = np.packbits(bits)
|
||||
bstr = b""
|
||||
for i in range(16):
|
||||
bstr += struct.pack(">B", nums[i])
|
||||
|
||||
return str(uuid.UUID(bytes=bstr))
|
||||
|
||||
def reconstruct_bits(self, bits):
|
||||
# return ''.join([str(b) for b in bits])
|
||||
return bits
|
||||
|
||||
def reconstruct_b16(self, bits):
|
||||
bstr = self.reconstruct_bytes(bits)
|
||||
return base64.b16encode(bstr)
|
||||
|
||||
def reconstruct_bytes(self, bits):
|
||||
nums = np.packbits(bits)
|
||||
bstr = b""
|
||||
for i in range(self._wmLen // 8):
|
||||
bstr += struct.pack(">B", nums[i])
|
||||
return bstr
|
||||
|
||||
def reconstruct(self, bits):
|
||||
if len(bits) != self._wmLen:
|
||||
raise RuntimeError("bits are not matched with watermark length")
|
||||
|
||||
if self._wmType == "ipv4":
|
||||
return self.reconstruct_ipv4(bits)
|
||||
elif self._wmType == "uuid":
|
||||
return self.reconstruct_uuid(bits)
|
||||
elif self._wmType == "bits":
|
||||
return self.reconstruct_bits(bits)
|
||||
elif self._wmType == "b16":
|
||||
return self.reconstruct_b16(bits)
|
||||
else:
|
||||
return self.reconstruct_bytes(bits)
|
||||
|
||||
def decode(self, cv2Image, method="dwtDct", **configs):
|
||||
(r, c, channels) = cv2Image.shape
|
||||
if r * c < 256 * 256:
|
||||
raise RuntimeError("image too small, should be larger than 256x256")
|
||||
|
||||
bits = []
|
||||
if method == "dwtDct":
|
||||
embed = EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
|
||||
bits = embed.decode(cv2Image)
|
||||
# elif method == 'dwtDctSvd':
|
||||
# embed = EmbedDwtDctSvd(watermarks=[], wmLen=self._wmLen, **configs)
|
||||
# bits = embed.decode(cv2Image)
|
||||
# elif method == 'rivaGan':
|
||||
# embed = RivaWatermark(watermarks=[], wmLen=self._wmLen, **configs)
|
||||
# bits = embed.decode(cv2Image)
|
||||
else:
|
||||
raise NameError("%s is not supported" % method)
|
||||
return self.reconstruct(bits)
|
||||
|
||||
# @classmethod
|
||||
# def loadModel(cls):
|
||||
# RivaWatermark.loadModel()
|
||||
|
||||
|
||||
class EmbedMaxDct(object):
|
||||
def __init__(self, watermarks=[], wmLen=8, scales=[0, 36, 36], block=4):
|
||||
self._watermarks = watermarks
|
||||
self._wmLen = wmLen
|
||||
self._scales = scales
|
||||
self._block = block
|
||||
|
||||
def encode(self, bgr):
|
||||
(row, col, channels) = bgr.shape
|
||||
|
||||
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
|
||||
|
||||
for channel in range(2):
|
||||
if self._scales[channel] <= 0:
|
||||
continue
|
||||
|
||||
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
|
||||
self.encode_frame(ca1, self._scales[channel])
|
||||
|
||||
yuv[: row // 4 * 4, : col // 4 * 4, channel] = pywt.idwt2((ca1, (v1, h1, d1)), "haar")
|
||||
|
||||
bgr_encoded = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
|
||||
return bgr_encoded
|
||||
|
||||
def decode(self, bgr):
|
||||
(row, col, channels) = bgr.shape
|
||||
|
||||
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
|
||||
|
||||
scores = [[] for i in range(self._wmLen)]
|
||||
for channel in range(2):
|
||||
if self._scales[channel] <= 0:
|
||||
continue
|
||||
|
||||
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
|
||||
|
||||
scores = self.decode_frame(ca1, self._scales[channel], scores)
|
||||
|
||||
avgScores = list(map(lambda l: np.array(l).mean(), scores))
|
||||
|
||||
bits = np.array(avgScores) * 255 > 127
|
||||
return bits
|
||||
|
||||
def decode_frame(self, frame, scale, scores):
|
||||
(row, col) = frame.shape
|
||||
num = 0
|
||||
|
||||
for i in range(row // self._block):
|
||||
for j in range(col // self._block):
|
||||
block = frame[
|
||||
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
|
||||
]
|
||||
|
||||
score = self.infer_dct_matrix(block, scale)
|
||||
# score = self.infer_dct_svd(block, scale)
|
||||
wmBit = num % self._wmLen
|
||||
scores[wmBit].append(score)
|
||||
num = num + 1
|
||||
|
||||
return scores
|
||||
|
||||
def diffuse_dct_svd(self, block, wmBit, scale):
|
||||
u, s, v = np.linalg.svd(cv2.dct(block))
|
||||
|
||||
s[0] = (s[0] // scale + 0.25 + 0.5 * wmBit) * scale
|
||||
return cv2.idct(np.dot(u, np.dot(np.diag(s), v)))
|
||||
|
||||
def infer_dct_svd(self, block, scale):
|
||||
u, s, v = np.linalg.svd(cv2.dct(block))
|
||||
|
||||
score = 0
|
||||
score = int((s[0] % scale) > scale * 0.5)
|
||||
return score
|
||||
if score >= 0.5:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def diffuse_dct_matrix(self, block, wmBit, scale):
|
||||
pos = np.argmax(abs(block.flatten()[1:])) + 1
|
||||
i, j = pos // self._block, pos % self._block
|
||||
val = block[i][j]
|
||||
if val >= 0.0:
|
||||
block[i][j] = (val // scale + 0.25 + 0.5 * wmBit) * scale
|
||||
else:
|
||||
val = abs(val)
|
||||
block[i][j] = -1.0 * (val // scale + 0.25 + 0.5 * wmBit) * scale
|
||||
return block
|
||||
|
||||
def infer_dct_matrix(self, block, scale):
|
||||
pos = np.argmax(abs(block.flatten()[1:])) + 1
|
||||
i, j = pos // self._block, pos % self._block
|
||||
|
||||
val = block[i][j]
|
||||
if val < 0:
|
||||
val = abs(val)
|
||||
|
||||
if (val % scale) > 0.5 * scale:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def encode_frame(self, frame, scale):
|
||||
"""
|
||||
frame is a matrix (M, N)
|
||||
|
||||
we get K (watermark bits size) blocks (self._block x self._block)
|
||||
|
||||
For i-th block, we encode watermark[i] bit into it
|
||||
"""
|
||||
(row, col) = frame.shape
|
||||
num = 0
|
||||
for i in range(row // self._block):
|
||||
for j in range(col // self._block):
|
||||
block = frame[
|
||||
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
|
||||
]
|
||||
wmBit = self._watermarks[(num % self._wmLen)]
|
||||
|
||||
diffusedBlock = self.diffuse_dct_matrix(block, wmBit, scale)
|
||||
# diffusedBlock = self.diffuse_dct_svd(block, wmBit, scale)
|
||||
frame[
|
||||
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
|
||||
] = diffusedBlock
|
||||
|
||||
num = num + 1
|
||||
@@ -6,10 +6,13 @@ configuration variable, that allows the watermarking to be supressed.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from imwatermark import WatermarkEncoder
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.image_util.imwatermark.vendor import WatermarkEncoder
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class InvisibleWatermark:
|
||||
|
||||
@@ -9,7 +9,6 @@ import spandrel
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
@@ -126,6 +125,8 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"BriaPipeline": ModelType.Main,
|
||||
"BriaTransformer2DModel": ModelType.ControlNet,
|
||||
"FluxPipeline": ModelType.Main,
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
@@ -494,21 +495,9 @@ class ModelProbe(object):
|
||||
# scan model
|
||||
scan_result = pscan.scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"The model {model_name} is potentially infected by malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
if scan_result.scan_err:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"Error scanning the model at {model_name} for malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.")
|
||||
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
|
||||
|
||||
|
||||
# Probing utilities
|
||||
@@ -874,6 +863,8 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.StableDiffusion3
|
||||
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
|
||||
return BaseModelType.CogView4
|
||||
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
|
||||
return BaseModelType.Bria
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
@@ -1023,6 +1014,9 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
if config.get("_class_name", None) == "FluxControlNetModel":
|
||||
return BaseModelType.Flux
|
||||
|
||||
if config.get("_class_name", None) == "BriaTransformer2DModel":
|
||||
return BaseModelType.Bria
|
||||
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
if dimension == 768:
|
||||
|
||||
96
invokeai/backend/model_manager/load/model_loaders/bria.py
Normal file
96
invokeai/backend/model_manager/load/model_loaders/bria.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
DiffusersConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load Bria control net models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, ControlNetCheckpointConfig):
|
||||
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path)
|
||||
repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path
|
||||
|
||||
dtype = self._torch_dtype
|
||||
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
use_safetensors=False,
|
||||
)
|
||||
except OSError as e:
|
||||
if variant and "no file named" in str(
|
||||
e
|
||||
): # try without the variant, just in case user's preferences changed
|
||||
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
|
||||
else:
|
||||
raise e
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
class BriaDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load Bria main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, CheckpointConfigBase):
|
||||
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
|
||||
|
||||
if submodel_type is None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
|
||||
dtype = self._torch_dtype
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
)
|
||||
except OSError as e:
|
||||
if variant and "no file named" in str(
|
||||
e
|
||||
): # try without the variant, just in case user's preferences changed
|
||||
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
|
||||
else:
|
||||
raise e
|
||||
|
||||
return result
|
||||
@@ -80,7 +80,13 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
"transformers",
|
||||
"invokeai.backend.quantization.fast_quantized_transformers_model",
|
||||
"invokeai.backend.quantization.fast_quantized_diffusion_model",
|
||||
"transformer_bria",
|
||||
]:
|
||||
if module == "transformer_bria":
|
||||
module = "invokeai.backend.bria.transformer_bria"
|
||||
elif class_name == "BriaTransformer2DModel":
|
||||
class_name = "BriaControlNetModel"
|
||||
module = "invokeai.backend.bria.controlnet_bria"
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
|
||||
@@ -12,6 +12,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
@@ -62,6 +65,8 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
else:
|
||||
# If neither is available, return 0
|
||||
return 0
|
||||
elif isinstance(model, (Body, Hand, Face)):
|
||||
return calc_module_size(model.model)
|
||||
elif isinstance(
|
||||
model,
|
||||
(
|
||||
|
||||
@@ -6,17 +6,13 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors import safe_open
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
@@ -83,24 +79,8 @@ class ModelOnDisk:
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"The model {path.stem} is potentially infected by malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"The model {path.stem} is potentially infected by malware. Aborting import."
|
||||
)
|
||||
if scan_result.scan_err:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"Error scanning the model at {path.stem} for malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Error scanning the model at {path.stem} for malware. Aborting import.")
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
assert isinstance(checkpoint, dict)
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
|
||||
@@ -149,29 +149,13 @@ flux_kontext = StarterModel(
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext_quantized = StarterModel(
|
||||
name="FLUX.1 Kontext dev (quantized)",
|
||||
name="FLUX.1 Kontext dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_krea = StarterModel(
|
||||
name="FLUX.1 Krea dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
|
||||
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_krea_quantized = StarterModel(
|
||||
name="FLUX.1 Krea dev (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
@@ -596,14 +580,13 @@ t2i_sketch_sdxl = StarterModel(
|
||||
)
|
||||
# endregion
|
||||
# region SpandrelImageToImage
|
||||
animesharp_v4_rcan = StarterModel(
|
||||
name="2x-AnimeSharpV4_RCAN",
|
||||
realesrgan_anime = StarterModel(
|
||||
name="RealESRGAN_x4plus_anime_6B",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN.safetensors",
|
||||
description="A 2x upscaling model (optimized for anime images).",
|
||||
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
)
|
||||
|
||||
realesrgan_x4 = StarterModel(
|
||||
name="RealESRGAN_x4plus",
|
||||
base=BaseModelType.Any,
|
||||
@@ -749,7 +732,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
t2i_lineart_sdxl,
|
||||
t2i_sketch_sdxl,
|
||||
realesrgan_x4,
|
||||
animesharp_v4_rcan,
|
||||
realesrgan_anime,
|
||||
realesrgan_x2,
|
||||
swinir,
|
||||
t5_base_encoder,
|
||||
@@ -760,8 +743,6 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
llava_onevision,
|
||||
flux_fill,
|
||||
cogview4,
|
||||
flux_krea,
|
||||
flux_krea_quantized,
|
||||
]
|
||||
|
||||
sd1_bundle: list[StarterModel] = [
|
||||
@@ -813,7 +794,6 @@ flux_bundle: list[StarterModel] = [
|
||||
flux_redux,
|
||||
flux_fill,
|
||||
flux_kontext_quantized,
|
||||
flux_krea_quantized,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
|
||||
|
||||
@@ -28,9 +28,9 @@ class BaseModelType(str, Enum):
|
||||
CogView4 = "cogview4"
|
||||
Imagen3 = "imagen3"
|
||||
Imagen4 = "imagen4"
|
||||
Gemini2_5 = "gemini-2.5"
|
||||
ChatGPT4o = "chatgpt-4o"
|
||||
FluxKontext = "flux-kontext"
|
||||
Bria = "bria"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
|
||||
@@ -8,12 +8,8 @@ import picklescan.scanner as pscan
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_manager.taxonomy import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
|
||||
@@ -63,21 +59,9 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
|
||||
if scan:
|
||||
scan_result = pscan.scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"The model {path} is potentially infected by malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"The model {path} is potentially infected by malware. Aborting import.")
|
||||
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
|
||||
if scan_result.scan_err:
|
||||
if get_config().unsafe_disable_picklescan:
|
||||
logger.warning(
|
||||
f"Error scanning the model at {path} for malware, but picklescan is disabled. "
|
||||
"Proceeding with caution."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Error scanning the model at {path} for malware. Aborting import.")
|
||||
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
|
||||
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
@@ -18,25 +18,16 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
|
||||
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
|
||||
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
|
||||
|
||||
# Check if keys use transformer prefix
|
||||
transformer_prefix_keys = [
|
||||
# Next, check that this is likely a FLUX model by spot-checking a few keys.
|
||||
expected_keys = [
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
|
||||
]
|
||||
transformer_keys_present = all(k in state_dict for k in transformer_prefix_keys)
|
||||
all_expected_keys_present = all(k in state_dict for k in expected_keys)
|
||||
|
||||
# Check if keys use base_model.model prefix
|
||||
base_model_prefix_keys = [
|
||||
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight",
|
||||
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight",
|
||||
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
|
||||
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
|
||||
]
|
||||
base_model_keys_present = all(k in state_dict for k in base_model_prefix_keys)
|
||||
|
||||
return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
|
||||
return all_keys_in_peft_format and all_expected_keys_present
|
||||
|
||||
|
||||
def lora_model_from_flux_diffusers_state_dict(
|
||||
@@ -58,16 +49,8 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
|
||||
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
|
||||
"""
|
||||
|
||||
# Determine which prefix is used and remove it from all keys.
|
||||
# Check if any key starts with "base_model.model." prefix
|
||||
has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict.keys())
|
||||
|
||||
if has_base_model_prefix:
|
||||
# Remove the "base_model.model." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("base_model.model.", ""): v for k, v in grouped_state_dict.items()}
|
||||
else:
|
||||
# Remove the "transformer." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
|
||||
# Remove the "transformer." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
|
||||
|
||||
# Constants for FLUX.1
|
||||
num_double_layers = 19
|
||||
|
||||
@@ -20,7 +20,7 @@ def main():
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def main():
|
||||
)
|
||||
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def main():
|
||||
"""
|
||||
model_path = Path("/data/misc/text_encoder_2")
|
||||
|
||||
with log_time("Initialize T5 on meta device"):
|
||||
with log_time("Intialize T5 on meta device"):
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForTextEncoding.from_config(model_config)
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
|
||||
def estimate_vae_working_memory_sd15_sdxl(
|
||||
operation: Literal["encode", "decode"],
|
||||
image_tensor: torch.Tensor,
|
||||
vae: AutoencoderKL | AutoencoderTiny,
|
||||
tile_size: int | None,
|
||||
fp32: bool,
|
||||
) -> int:
|
||||
"""Estimate the working memory required to encode or decode the given tensor."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if fp32 else 2
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
if tile_size is not None:
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
h = tile_size
|
||||
w = tile_size
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
if fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_cogview4(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_flux(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_flux: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_sd3(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
3
invokeai/frontend/web/.gitignore
vendored
3
invokeai/frontend/web/.gitignore
vendored
@@ -44,5 +44,4 @@ yalc.lock
|
||||
|
||||
# vitest
|
||||
tsconfig.vitest-temp.json
|
||||
coverage/
|
||||
*.tgz
|
||||
coverage/
|
||||
@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
|
||||
returnNull: false,
|
||||
});
|
||||
|
||||
const store = createStore();
|
||||
const store = createStore(undefined, false);
|
||||
$store.set(store);
|
||||
$baseUrl.set('http://localhost:9090');
|
||||
|
||||
|
||||
@@ -197,10 +197,6 @@ export default [
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'zod/v3',
|
||||
message: 'Import from zod instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
|
||||
@@ -17,7 +17,6 @@ const config: KnipConfig = {
|
||||
'src/app/store/use-debounced-app-selector.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
ignoreDependencies: ['magic-string'],
|
||||
paths: {
|
||||
'public/*': ['public/*'],
|
||||
},
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.3.2",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "6.2.1",
|
||||
"idb-keyval": "6.2.2",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.22",
|
||||
"linkify-react": "^4.3.1",
|
||||
@@ -103,7 +103,7 @@
|
||||
"use-debounce": "^10.0.5",
|
||||
"use-device-pixel-ratio": "^1.1.2",
|
||||
"uuid": "^11.1.0",
|
||||
"zod": "^4.0.10",
|
||||
"zod": "^4.0.5",
|
||||
"zod-validation-error": "^3.5.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
@@ -139,7 +139,6 @@
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"globals": "^16.3.0",
|
||||
"knip": "^5.61.3",
|
||||
"magic-string": "^0.30.17",
|
||||
"openapi-types": "^12.1.3",
|
||||
"openapi-typescript": "^7.6.1",
|
||||
"prettier": "^3.5.3",
|
||||
|
||||
37
invokeai/frontend/web/pnpm-lock.yaml
generated
37
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -81,8 +81,8 @@ importers:
|
||||
specifier: ^3.0.2
|
||||
version: 3.0.2
|
||||
idb-keyval:
|
||||
specifier: 6.2.1
|
||||
version: 6.2.1
|
||||
specifier: 6.2.2
|
||||
version: 6.2.2
|
||||
jsondiffpatch:
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3
|
||||
@@ -201,11 +201,11 @@ importers:
|
||||
specifier: ^11.1.0
|
||||
version: 11.1.0
|
||||
zod:
|
||||
specifier: ^4.0.10
|
||||
version: 4.0.10
|
||||
specifier: ^4.0.5
|
||||
version: 4.0.5
|
||||
zod-validation-error:
|
||||
specifier: ^3.5.2
|
||||
version: 3.5.3(zod@4.0.10)
|
||||
version: 3.5.3(zod@4.0.5)
|
||||
devDependencies:
|
||||
'@eslint/js':
|
||||
specifier: ^9.31.0
|
||||
@@ -291,9 +291,6 @@ importers:
|
||||
knip:
|
||||
specifier: ^5.61.3
|
||||
version: 5.61.3(@types/node@22.16.0)(typescript@5.8.3)
|
||||
magic-string:
|
||||
specifier: ^0.30.17
|
||||
version: 0.30.17
|
||||
openapi-types:
|
||||
specifier: ^12.1.3
|
||||
version: 12.1.3
|
||||
@@ -414,10 +411,6 @@ packages:
|
||||
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/runtime@7.28.2':
|
||||
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/template@7.27.2':
|
||||
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
@@ -2778,8 +2771,8 @@ packages:
|
||||
typescript:
|
||||
optional: true
|
||||
|
||||
idb-keyval@6.2.1:
|
||||
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
|
||||
idb-keyval@6.2.2:
|
||||
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
|
||||
|
||||
ieee754@1.2.1:
|
||||
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
|
||||
@@ -4518,8 +4511,8 @@ packages:
|
||||
zod@3.25.76:
|
||||
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
|
||||
|
||||
zod@4.0.10:
|
||||
resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
|
||||
zod@4.0.5:
|
||||
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==}
|
||||
|
||||
zustand@4.5.7:
|
||||
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
|
||||
@@ -4640,8 +4633,6 @@ snapshots:
|
||||
|
||||
'@babel/runtime@7.27.6': {}
|
||||
|
||||
'@babel/runtime@7.28.2': {}
|
||||
|
||||
'@babel/template@7.27.2':
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.27.1
|
||||
@@ -5745,7 +5736,7 @@ snapshots:
|
||||
'@testing-library/dom@10.4.0':
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.27.1
|
||||
'@babel/runtime': 7.28.2
|
||||
'@babel/runtime': 7.27.6
|
||||
'@types/aria-query': 5.0.4
|
||||
aria-query: 5.3.0
|
||||
chalk: 4.1.2
|
||||
@@ -7275,7 +7266,7 @@ snapshots:
|
||||
optionalDependencies:
|
||||
typescript: 5.8.3
|
||||
|
||||
idb-keyval@6.2.1: {}
|
||||
idb-keyval@6.2.2: {}
|
||||
|
||||
ieee754@1.2.1: {}
|
||||
|
||||
@@ -9071,13 +9062,13 @@ snapshots:
|
||||
dependencies:
|
||||
zod: 3.25.76
|
||||
|
||||
zod-validation-error@3.5.3(zod@4.0.10):
|
||||
zod-validation-error@3.5.3(zod@4.0.5):
|
||||
dependencies:
|
||||
zod: 4.0.10
|
||||
zod: 4.0.5
|
||||
|
||||
zod@3.25.76: {}
|
||||
|
||||
zod@4.0.10: {}
|
||||
zod@4.0.5: {}
|
||||
|
||||
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
|
||||
dependencies:
|
||||
|
||||
@@ -1470,6 +1470,7 @@
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"queue": "Warteschlange",
|
||||
"generation": "Erzeugung",
|
||||
"gallery": "Galerie",
|
||||
"models": "Modelle",
|
||||
"upscaling": "Hochskalierung",
|
||||
|
||||
@@ -38,7 +38,6 @@
|
||||
"deletedImagesCannotBeRestored": "Deleted images cannot be restored.",
|
||||
"hideBoards": "Hide Boards",
|
||||
"loading": "Loading...",
|
||||
"locateInGalery": "Locate in Gallery",
|
||||
"menuItemAutoAdd": "Auto-add to this Board",
|
||||
"move": "Move",
|
||||
"movingImagesToBoard_one": "Moving {{count}} image to board:",
|
||||
@@ -115,9 +114,6 @@
|
||||
"t2iAdapter": "T2I Adapter",
|
||||
"positivePrompt": "Positive Prompt",
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"removeNegativePrompt": "Remove Negative Prompt",
|
||||
"addNegativePrompt": "Add Negative Prompt",
|
||||
"selectYourModel": "Select Your Model",
|
||||
"discordLabel": "Discord",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"dontShowMeThese": "Don't show me these",
|
||||
@@ -614,18 +610,10 @@
|
||||
"title": "Toggle Non-Raster Layers",
|
||||
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
|
||||
},
|
||||
"fitBboxToLayers": {
|
||||
"title": "Fit Bbox To Layers",
|
||||
"desc": "Automatically adjust the generation bounding box to fit visible layers"
|
||||
},
|
||||
"fitBboxToMasks": {
|
||||
"title": "Fit Bbox To Masks",
|
||||
"desc": "Automatically adjust the generation bounding box to fit visible inpaint masks"
|
||||
},
|
||||
"toggleBbox": {
|
||||
"title": "Toggle Bbox Visibility",
|
||||
"desc": "Hide or show the generation bounding box"
|
||||
},
|
||||
"applySegmentAnything": {
|
||||
"title": "Apply Segment Anything",
|
||||
"desc": "Apply the current Segment Anything mask.",
|
||||
@@ -775,7 +763,6 @@
|
||||
"allPrompts": "All Prompts",
|
||||
"cfgScale": "CFG scale",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"clipSkip": "$t(parameters.clipSkip)",
|
||||
"createdBy": "Created By",
|
||||
"generationMode": "Generation Mode",
|
||||
"guidance": "Guidance",
|
||||
@@ -878,9 +865,6 @@
|
||||
"install": "Install",
|
||||
"installAll": "Install All",
|
||||
"installRepo": "Install Repo",
|
||||
"installBundle": "Install Bundle",
|
||||
"installBundleMsg1": "Are you sure you want to install the {{bundleName}} bundle?",
|
||||
"installBundleMsg2": "This bundle will install the following {{count}} models:",
|
||||
"ipAdapters": "IP Adapters",
|
||||
"learnMoreAboutSupportedModels": "Learn more about the models we support",
|
||||
"load": "Load",
|
||||
@@ -1250,8 +1234,10 @@
|
||||
"modelIncompatibleBboxHeight": "Bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
|
||||
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
|
||||
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
|
||||
"briaRequiresExactDimensions": "Bria requires exact {{size}}x{{size}} dimensions",
|
||||
"briaRequiresExactScaledDimensions": "Bria requires exact {{size}}x{{size}} scaled dimensions",
|
||||
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
|
||||
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with FLUX Kontext via BFL API",
|
||||
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
|
||||
"canvasIsFiltering": "Canvas is busy (filtering)",
|
||||
"canvasIsTransforming": "Canvas is busy (transforming)",
|
||||
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
|
||||
@@ -1299,7 +1285,6 @@
|
||||
"remixImage": "Remix Image",
|
||||
"usePrompt": "Use Prompt",
|
||||
"useSeed": "Use Seed",
|
||||
"useClipSkip": "Use CLIP Skip",
|
||||
"width": "Width",
|
||||
"gaussianBlur": "Gaussian Blur",
|
||||
"boxBlur": "Box Blur",
|
||||
@@ -1381,8 +1366,8 @@
|
||||
"addedToBoard": "Added to board {{name}}'s assets",
|
||||
"addedToUncategorized": "Added to board $t(boards.uncategorized)'s assets",
|
||||
"baseModelChanged": "Base Model Changed",
|
||||
"baseModelChangedCleared_one": "Updated, cleared or disabled {{count}} incompatible submodel",
|
||||
"baseModelChangedCleared_other": "Updated, cleared or disabled {{count}} incompatible submodels",
|
||||
"baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
|
||||
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
|
||||
"canceled": "Processing Canceled",
|
||||
"connected": "Connected to Server",
|
||||
"imageCopied": "Image Copied",
|
||||
@@ -1950,11 +1935,8 @@
|
||||
"zoomToNode": "Zoom to Node",
|
||||
"nodeFieldTooltip": "To add a node field, click the small plus sign button on the field in the Workflow Editor, or drag the field by its name into the form.",
|
||||
"addToForm": "Add to Form",
|
||||
"removeFromForm": "Remove from Form",
|
||||
"label": "Label",
|
||||
"showDescription": "Show Description",
|
||||
"showShuffle": "Show Shuffle",
|
||||
"shuffle": "Shuffle",
|
||||
"component": "Component",
|
||||
"numberInput": "Number Input",
|
||||
"singleLine": "Single Line",
|
||||
@@ -2086,8 +2068,6 @@
|
||||
"asControlLayer": "As $t(controlLayers.controlLayer)",
|
||||
"asControlLayerResize": "As $t(controlLayers.controlLayer) (Resize)",
|
||||
"referenceImage": "Reference Image",
|
||||
"maxRefImages": "Max Ref Images",
|
||||
"useAsReferenceImage": "Use as Reference Image",
|
||||
"regionalReferenceImage": "Regional Reference Image",
|
||||
"globalReferenceImage": "Global Reference Image",
|
||||
"sendingToCanvas": "Staging Generations on Canvas",
|
||||
@@ -2196,8 +2176,7 @@
|
||||
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
|
||||
"rgNoRegion": "no region drawn",
|
||||
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill",
|
||||
"bboxHidden": "Bounding box is hidden (shift+o to toggle)"
|
||||
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill"
|
||||
},
|
||||
"errors": {
|
||||
"unableToFindImage": "Unable to find image",
|
||||
@@ -2208,7 +2187,13 @@
|
||||
"balanced": "Balanced (recommended)",
|
||||
"prompt": "Prompt",
|
||||
"control": "Control",
|
||||
"megaControl": "Mega Control"
|
||||
"megaControl": "Mega Control",
|
||||
"depth": "Depth",
|
||||
"canny": "Canny",
|
||||
"colorgrid": "Color Grid",
|
||||
"recolor": "Recolor",
|
||||
"tile": "Tile",
|
||||
"pose": "Pose"
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"ipAdapterMethod": "Mode",
|
||||
@@ -2556,7 +2541,7 @@
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generate": "Generate",
|
||||
"generation": "Generation",
|
||||
"canvas": "Canvas",
|
||||
"workflows": "Workflows",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
@@ -2567,12 +2552,6 @@
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
|
||||
"gallery": "Gallery"
|
||||
},
|
||||
"panels": {
|
||||
"launchpad": "Launchpad",
|
||||
"workflowEditor": "Workflow Editor",
|
||||
"imageViewer": "Image Viewer",
|
||||
"canvas": "Canvas"
|
||||
},
|
||||
"launchpad": {
|
||||
"workflowsTitle": "Go deep with Workflows.",
|
||||
"upscalingTitle": "Upscale and add detail.",
|
||||
@@ -2580,28 +2559,6 @@
|
||||
"generateTitle": "Generate images from text prompts.",
|
||||
"modelGuideText": "Want to learn what prompts work best for each model?",
|
||||
"modelGuideLink": "Check out our Model Guide.",
|
||||
"createNewWorkflowFromScratch": "Create a new Workflow from scratch",
|
||||
"browseAndLoadWorkflows": "Browse and load existing workflows",
|
||||
"addStyleRef": {
|
||||
"title": "Add a Style Reference",
|
||||
"description": "Add an image to transfer its look."
|
||||
},
|
||||
"editImage": {
|
||||
"title": "Edit Image",
|
||||
"description": "Add an image to refine."
|
||||
},
|
||||
"generateFromText": {
|
||||
"title": "Generate from Text",
|
||||
"description": "Enter a prompt and Invoke."
|
||||
},
|
||||
"useALayoutImage": {
|
||||
"title": "Use a Layout Image",
|
||||
"description": "Add an image to control composition."
|
||||
},
|
||||
"generate": {
|
||||
"canvasCalloutTitle": "Looking to get more control, edit, and iterate on your images?",
|
||||
"canvasCalloutLink": "Navigate to Canvas for more capabilities."
|
||||
},
|
||||
"workflows": {
|
||||
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
|
||||
"learnMoreLink": "Learn more about creating workflows",
|
||||
@@ -2638,13 +2595,6 @@
|
||||
"upscaleModel": "Upscale Model",
|
||||
"model": "Model",
|
||||
"scale": "Scale",
|
||||
"creativityAndStructure": {
|
||||
"title": "Creativity & Structure Defaults",
|
||||
"conservative": "Conservative",
|
||||
"balanced": "Balanced",
|
||||
"creative": "Creative",
|
||||
"artistic": "Artistic"
|
||||
},
|
||||
"helpText": {
|
||||
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
|
||||
"styleAdvice": "Upscaling works best with the general style of your image."
|
||||
@@ -2689,8 +2639,10 @@
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": [
|
||||
"Canvas: Color Picker does not sample alpha, bbox respects aspect ratio lock when resizing shuffle button for number fields in Workflow Builder, hide pixel dimension sliders when using a model that doesn't support them",
|
||||
"Workflows: Add a Shuffle button to number input fields"
|
||||
"New setting to send all Canvas generations directly to the Gallery.",
|
||||
"New Invert Mask (Shift+V) and Fit BBox to Mask (Shift+B) capabilities.",
|
||||
"Expanded support for Model Thumbnails and configurations.",
|
||||
"Various other quality of life updates and fixes"
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
|
||||
@@ -399,6 +399,7 @@
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"canvas": "Lienzo",
|
||||
"generation": "Generación",
|
||||
"queue": "Cola",
|
||||
"workflows": "Flujos de trabajo",
|
||||
"models": "Modelos",
|
||||
|
||||
@@ -1820,6 +1820,7 @@
|
||||
"upscaling": "Agrandissement",
|
||||
"gallery": "Galerie",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
|
||||
"generation": "Génération",
|
||||
"workflows": "Workflows",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
"models": "Modèles",
|
||||
|
||||
@@ -128,10 +128,7 @@
|
||||
"search": "Cerca",
|
||||
"clear": "Cancella",
|
||||
"compactView": "Vista compatta",
|
||||
"fullView": "Vista completa",
|
||||
"removeNegativePrompt": "Rimuovi prompt negativo",
|
||||
"addNegativePrompt": "Aggiungi prompt negativo",
|
||||
"selectYourModel": "Seleziona il modello"
|
||||
"fullView": "Vista completa"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -257,8 +254,8 @@
|
||||
"desc": "Attiva/disattiva il pannello destro."
|
||||
},
|
||||
"resetPanelLayout": {
|
||||
"title": "Ripristina lo schema del pannello",
|
||||
"desc": "Ripristina le dimensioni e lo schema predefiniti dei pannelli sinistro e destro."
|
||||
"title": "Ripristina il layout del pannello",
|
||||
"desc": "Ripristina le dimensioni e il layout predefiniti dei pannelli sinistro e destro."
|
||||
},
|
||||
"togglePanels": {
|
||||
"title": "Attiva/disattiva i pannelli",
|
||||
@@ -413,14 +410,6 @@
|
||||
"cancelSegmentAnything": {
|
||||
"title": "Annulla Segment Anything",
|
||||
"desc": "Annulla l'operazione Segment Anything corrente."
|
||||
},
|
||||
"fitBboxToLayers": {
|
||||
"title": "Adatta il riquadro di delimitazione ai livelli",
|
||||
"desc": "Regola automaticamente il riquadro di delimitazione della generazione per adattarlo ai livelli visibili"
|
||||
},
|
||||
"toggleBbox": {
|
||||
"title": "Attiva/disattiva la visibilità del riquadro di delimitazione",
|
||||
"desc": "Nascondi o mostra il riquadro di delimitazione della generazione"
|
||||
}
|
||||
},
|
||||
"workflows": {
|
||||
@@ -550,10 +539,6 @@
|
||||
"galleryNavUpAlt": {
|
||||
"desc": "Uguale a Naviga verso l'alto, ma seleziona l'immagine da confrontare, aprendo la modalità di confronto se non è già aperta.",
|
||||
"title": "Naviga verso l'alto (Confronta immagine)"
|
||||
},
|
||||
"starImage": {
|
||||
"desc": "Aggiungi/Rimuovi contrassegno all'immagine selezionata.",
|
||||
"title": "Aggiungi / Rimuovi contrassegno immagine"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -722,10 +707,7 @@
|
||||
"bundleDescription": "Ogni pacchetto include modelli essenziali per ogni famiglia di modelli e modelli base selezionati per iniziare.",
|
||||
"browseAll": "Oppure scopri tutti i modelli disponibili:"
|
||||
},
|
||||
"launchpadTab": "Rampa di lancio",
|
||||
"installBundle": "Installa pacchetto",
|
||||
"installBundleMsg1": "Vuoi davvero installare il pacchetto {{bundleName}}?",
|
||||
"installBundleMsg2": "Questo pacchetto installerà i seguenti {{count}} modelli:"
|
||||
"launchpadTab": "Rampa di lancio"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -812,7 +794,7 @@
|
||||
"modelIncompatibleScaledBboxWidth": "La larghezza scalata del riquadro è {{width}} ma {{model}} richiede multipli di {{multiple}}",
|
||||
"modelIncompatibleScaledBboxHeight": "L'altezza scalata del riquadro è {{height}} ma {{model}} richiede multipli di {{multiple}}",
|
||||
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Accedi alle impostazioni del tuo account per effettuare l'upgrade.",
|
||||
"fluxKontextMultipleReferenceImages": "È possibile utilizzare solo 1 immagine di riferimento alla volta con FLUX Kontext tramite BFL API",
|
||||
"fluxKontextMultipleReferenceImages": "È possibile utilizzare solo 1 immagine di riferimento alla volta con Flux Kontext",
|
||||
"promptExpansionResultPending": "Accetta o ignora il risultato dell'espansione del prompt",
|
||||
"promptExpansionPending": "Espansione del prompt in corso"
|
||||
},
|
||||
@@ -842,8 +824,7 @@
|
||||
"coherenceMinDenoise": "Min rid. rumore",
|
||||
"recallMetadata": "Richiama i metadati",
|
||||
"disabledNoRasterContent": "Disabilitato (nessun contenuto Raster)",
|
||||
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Visita le <LinkComponent>impostazioni account</LinkComponent> per effettuare l'upgrade.",
|
||||
"useClipSkip": "Usa CLIP Skip"
|
||||
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Visita le <LinkComponent>impostazioni account</LinkComponent> per effettuare l'upgrade."
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@@ -1181,19 +1162,7 @@
|
||||
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato",
|
||||
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante",
|
||||
"layout": {
|
||||
"alignmentDR": "In basso a destra",
|
||||
"autoLayout": "Schema automatico",
|
||||
"nodeSpacing": "Spaziatura nodi",
|
||||
"layerSpacing": "Spaziatura livelli",
|
||||
"layeringStrategy": "Strategia livelli",
|
||||
"longestPath": "Percorso più lungo",
|
||||
"layoutDirection": "Direzione schema",
|
||||
"layoutDirectionRight": "A destra",
|
||||
"layoutDirectionDown": "In basso",
|
||||
"alignment": "Allineamento nodi",
|
||||
"alignmentUL": "In alto a sinistra",
|
||||
"alignmentDL": "In basso a sinistra",
|
||||
"alignmentUR": "In alto a destra"
|
||||
"alignmentDR": "In basso a destra"
|
||||
}
|
||||
},
|
||||
"boards": {
|
||||
@@ -1242,8 +1211,7 @@
|
||||
"updateBoardError": "Errore durante l'aggiornamento della bacheca",
|
||||
"uncategorizedImages": "Immagini non categorizzate",
|
||||
"deleteAllUncategorizedImages": "Elimina tutte le immagini non categorizzate",
|
||||
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate.",
|
||||
"locateInGalery": "Trova nella Galleria"
|
||||
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate."
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
@@ -1272,7 +1240,7 @@
|
||||
"batchQueuedDesc_other": "Aggiunte {{count}} sessioni a {{direction}} della coda",
|
||||
"graphQueued": "Grafico in coda",
|
||||
"batch": "Lotto",
|
||||
"clearQueueAlertDialog": "La cancellazione della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati e l'area di lavoro della Tela verrà reimpostata.",
|
||||
"clearQueueAlertDialog": "Lo svuotamento della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati.",
|
||||
"pending": "In attesa",
|
||||
"completedIn": "Completato in",
|
||||
"resumeFailed": "Problema nel riavvio dell'elaborazione",
|
||||
@@ -1328,8 +1296,7 @@
|
||||
"retrySucceeded": "Elemento rieseguito",
|
||||
"retryItem": "Riesegui elemento",
|
||||
"retryFailed": "Problema riesecuzione elemento",
|
||||
"credits": "Crediti",
|
||||
"cancelAllExceptCurrent": "Annulla tutto tranne quello corrente"
|
||||
"credits": "Crediti"
|
||||
},
|
||||
"models": {
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
@@ -1744,7 +1711,7 @@
|
||||
"structure": {
|
||||
"heading": "Struttura",
|
||||
"paragraphs": [
|
||||
"La struttura determina quanto l'immagine finale rispecchierà lo schema dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
|
||||
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Una struttura bassa permette cambiamenti significativi, mentre una struttura alta conserva la composizione e il layout originali."
|
||||
]
|
||||
},
|
||||
"fluxDevLicense": {
|
||||
@@ -1910,7 +1877,7 @@
|
||||
"opened": "Aperto",
|
||||
"convertGraph": "Converti grafico",
|
||||
"loadWorkflow": "$t(common.load) Flusso di lavoro",
|
||||
"autoLayout": "Schema automatico",
|
||||
"autoLayout": "Disposizione automatica",
|
||||
"loadFromGraph": "Carica il flusso di lavoro dal grafico",
|
||||
"userWorkflows": "Flussi di lavoro utente",
|
||||
"projectWorkflows": "Flussi di lavoro del progetto",
|
||||
@@ -1990,10 +1957,7 @@
|
||||
"publishInProgress": "Pubblicazione in corso",
|
||||
"selectingOutputNode": "Selezione del nodo di uscita",
|
||||
"selectingOutputNodeDesc": "Fare clic su un nodo per selezionarlo come nodo di uscita del flusso di lavoro.",
|
||||
"errorWorkflowHasUnpublishableNodes": "Il flusso di lavoro ha nodi di estrazione lotto, generatore o metadati",
|
||||
"showShuffle": "Mostra Mescola",
|
||||
"shuffle": "Mescola",
|
||||
"removeFromForm": "Rimuovi dal modulo"
|
||||
"errorWorkflowHasUnpublishableNodes": "Il flusso di lavoro ha nodi di estrazione lotto, generatore o metadati"
|
||||
},
|
||||
"loadMore": "Carica altro",
|
||||
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
|
||||
@@ -2474,8 +2438,7 @@
|
||||
"ipAdapterIncompatibleBaseModel": "modello base dell'immagine di riferimento incompatibile",
|
||||
"ipAdapterNoImageSelected": "nessuna immagine di riferimento selezionata",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato",
|
||||
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill",
|
||||
"bboxHidden": "Il riquadro di delimitazione è nascosto (Shift+o per attivarlo)"
|
||||
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill"
|
||||
},
|
||||
"pasteTo": "Incolla su",
|
||||
"pasteToBboxDesc": "Nuovo livello (nel riquadro di delimitazione)",
|
||||
@@ -2515,12 +2478,11 @@
|
||||
"off": "Spento"
|
||||
},
|
||||
"invertMask": "Inverti maschera",
|
||||
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere",
|
||||
"maxRefImages": "Max Immagini di rif.to",
|
||||
"useAsReferenceImage": "Usa come immagine di riferimento"
|
||||
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "Generazione",
|
||||
"canvas": "Tela",
|
||||
"workflows": "Flussi di lavoro",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
@@ -2529,8 +2491,7 @@
|
||||
"queue": "Coda",
|
||||
"upscaling": "Amplia",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
|
||||
"gallery": "Galleria",
|
||||
"generate": "Genera"
|
||||
"gallery": "Galleria"
|
||||
},
|
||||
"launchpad": {
|
||||
"workflowsTitle": "Approfondisci i flussi di lavoro.",
|
||||
@@ -2578,43 +2539,8 @@
|
||||
"helpText": {
|
||||
"promptAdvice": "Durante l'ampliamento, utilizza un prompt che descriva il mezzo e lo stile. Evita di descrivere dettagli specifici del contenuto dell'immagine.",
|
||||
"styleAdvice": "L'ampliamento funziona meglio con lo stile generale dell'immagine."
|
||||
},
|
||||
"creativityAndStructure": {
|
||||
"title": "Creatività e struttura predefinite",
|
||||
"conservative": "Conservativo",
|
||||
"balanced": "Bilanciato",
|
||||
"creative": "Creativo",
|
||||
"artistic": "Artistico"
|
||||
}
|
||||
},
|
||||
"createNewWorkflowFromScratch": "Crea un nuovo flusso di lavoro da zero",
|
||||
"browseAndLoadWorkflows": "Sfoglia e carica i flussi di lavoro esistenti",
|
||||
"addStyleRef": {
|
||||
"title": "Aggiungi un riferimento di stile",
|
||||
"description": "Aggiungi un'immagine per trasferirne l'aspetto."
|
||||
},
|
||||
"editImage": {
|
||||
"title": "Modifica immagine",
|
||||
"description": "Aggiungi un'immagine da perfezionare."
|
||||
},
|
||||
"generateFromText": {
|
||||
"title": "Genera da testo",
|
||||
"description": "Inserisci un prompt e genera."
|
||||
},
|
||||
"useALayoutImage": {
|
||||
"description": "Aggiungi un'immagine per controllare la composizione.",
|
||||
"title": "Usa una immagine guida"
|
||||
},
|
||||
"generate": {
|
||||
"canvasCalloutTitle": "Vuoi avere più controllo, modificare e affinare le tue immagini?",
|
||||
"canvasCalloutLink": "Per ulteriori funzionalità, vai su Tela."
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"launchpad": "Rampa di lancio",
|
||||
"workflowEditor": "Editor del flusso di lavoro",
|
||||
"imageViewer": "Visualizzatore immagini",
|
||||
"canvas": "Tela"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
@@ -2705,8 +2631,9 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"Vari QoL: attiva/disattiva la visibilità del Riquadro di delimitazione, evidenzia i nodi con errori, evita di aggiungere più volte i campi dei nodi al modulo Generatore, i metadati CLIP Skip ora richiamabili",
|
||||
"Utilizzo ridotto di VRAM per immagini di riferimento Kontext multiple e codifica VAE"
|
||||
"Genera immagini più velocemente con le nuove Rampe di lancio e una scheda Genera semplificata.",
|
||||
"Modifica con prompt utilizzando Flux Kontext Dev.",
|
||||
"Esporta in PSD, nascondi sovrapposizioni in blocco, organizza modelli e immagini: il tutto in un'interfaccia riprogettata e pensata per il controllo."
|
||||
]
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -1783,6 +1783,7 @@
|
||||
"workflows": "ワークフロー",
|
||||
"models": "モデル",
|
||||
"gallery": "ギャラリー",
|
||||
"generation": "生成",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"upscaling": "アップスケーリング",
|
||||
|
||||
@@ -1931,6 +1931,7 @@
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "Генерация",
|
||||
"canvas": "Холст",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
"models": "Модели",
|
||||
|
||||
@@ -55,8 +55,7 @@
|
||||
"assetsWithCount_other": "{{count}} tài nguyên",
|
||||
"uncategorizedImages": "Ảnh Chưa Sắp Xếp",
|
||||
"deleteAllUncategorizedImages": "Xoá Tất Cả Ảnh Chưa Sắp Xếp",
|
||||
"deletedImagesCannotBeRestored": "Ảnh đã xoá không thể phục hồi lại.",
|
||||
"locateInGalery": "Vị Trí Ở Thư Viện Ảnh"
|
||||
"deletedImagesCannotBeRestored": "Ảnh đã xoá không thể phục hồi lại."
|
||||
},
|
||||
"gallery": {
|
||||
"swapImages": "Đổi Hình Ảnh",
|
||||
@@ -253,10 +252,7 @@
|
||||
"clear": "Dọn Dẹp",
|
||||
"compactView": "Chế Độ Xem Gọn",
|
||||
"fullView": "Chế Độ Xem Đầy Đủ",
|
||||
"options_withCount_other": "{{count}} thiết lập",
|
||||
"removeNegativePrompt": "Xóa Lệnh Tiêu Cực",
|
||||
"addNegativePrompt": "Thêm Lệnh Tiêu Cực",
|
||||
"selectYourModel": "Chọn Model"
|
||||
"options_withCount_other": "{{count}} thiết lập"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Trigger Cho Lệnh",
|
||||
@@ -303,7 +299,7 @@
|
||||
"pruneTooltip": "Cắt bớt {{item_count}} mục đã hoàn tất",
|
||||
"pruneSucceeded": "Đã cắt bớt {{item_count}} mục đã hoàn tất khỏi hàng",
|
||||
"clearTooltip": "Huỷ Và Dọn Dẹp Tất Cả Mục",
|
||||
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ và Vùng Dựng Canva sẽ được khởi động lại.",
|
||||
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ.",
|
||||
"session": "Phiên",
|
||||
"item": "Mục",
|
||||
"resumeFailed": "Có Vấn Đề Khi Tiếp Tục Bộ Xử Lý",
|
||||
@@ -347,14 +343,13 @@
|
||||
"retrySucceeded": "Mục Đã Thử Lại",
|
||||
"retryFailed": "Có Vấn Đề Khi Thử Lại Mục",
|
||||
"retryItem": "Thử Lại Mục",
|
||||
"credits": "Nguồn",
|
||||
"cancelAllExceptCurrent": "Huỷ Bỏ Tất Cả Ngoại Trừ Mục Hiện Tại"
|
||||
"credits": "Nguồn"
|
||||
},
|
||||
"hotkeys": {
|
||||
"canvas": {
|
||||
"fitLayersToCanvas": {
|
||||
"title": "Xếp Vừa Layers Vào Canvas",
|
||||
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer nhìn thấy dược."
|
||||
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer."
|
||||
},
|
||||
"setZoomTo800Percent": {
|
||||
"desc": "Phóng to canvas lên 800%.",
|
||||
@@ -478,32 +473,6 @@
|
||||
"toggleNonRasterLayers": {
|
||||
"title": "Bật/Tắt Layer Không Thuộc Dạng Raster",
|
||||
"desc": "Hiện hoặc ẩn tất cả layer không thuộc dạng raster (Layer Điều Khiển Được, Lớp Phủ Inpaint, Chỉ Dẫn Khu Vực)."
|
||||
},
|
||||
"invertMask": {
|
||||
"title": "Đảo Ngược Lớp Phủ",
|
||||
"desc": "Đảo ngược lớp phủ inpaint được chọn, tạo một lớp phủ mới với độ trong suốt đối nghịch."
|
||||
},
|
||||
"fitBboxToMasks": {
|
||||
"title": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
|
||||
"desc": "Tự động điểu chỉnh hộp giới hạn tạo sinh vừa vặn vào lớp phủ inpaint nhìn thấy được"
|
||||
},
|
||||
"applySegmentAnything": {
|
||||
"title": "Áp Dụng Segment Anything",
|
||||
"desc": "Áp dụng lớp phủ Segment Anything hiện tại.",
|
||||
"key": "enter"
|
||||
},
|
||||
"cancelSegmentAnything": {
|
||||
"title": "Huỷ Segment Anything",
|
||||
"desc": "Huỷ hoạt động Segment Anything hiện tại.",
|
||||
"key": "esc"
|
||||
},
|
||||
"fitBboxToLayers": {
|
||||
"title": "Xếp Vừa Hộp Giới Hạn Vào Layer",
|
||||
"desc": "Tự động điểu chỉnh hộp giới hạn tạo sinh vừa vặn vào layer nhìn thấy được"
|
||||
},
|
||||
"toggleBbox": {
|
||||
"title": "Bật/Tắt Hiển Thị Hộp Giới Hạn",
|
||||
"desc": "Ẩn hoặc hiện hộp giới hạn tạo sinh"
|
||||
}
|
||||
},
|
||||
"workflows": {
|
||||
@@ -633,10 +602,6 @@
|
||||
"clearSelection": {
|
||||
"desc": "Xoá phần lựa chọn hiện tại nếu có.",
|
||||
"title": "Xoá Phần Lựa Chọn"
|
||||
},
|
||||
"starImage": {
|
||||
"title": "Dấu/Huỷ Sao Hình Ảnh",
|
||||
"desc": "Đánh dấu sao hoặc huỷ đánh dấu sao ảnh được chọn."
|
||||
}
|
||||
},
|
||||
"app": {
|
||||
@@ -696,11 +661,6 @@
|
||||
"selectModelsTab": {
|
||||
"desc": "Chọn tab Model (Mô Hình).",
|
||||
"title": "Chọn Tab Model"
|
||||
},
|
||||
"selectGenerateTab": {
|
||||
"title": "Chọn Tab Tạo Sinh",
|
||||
"desc": "Chọn tab Tạo Sinh.",
|
||||
"key": "1"
|
||||
}
|
||||
},
|
||||
"searchHotkeys": "Tìm Phím tắt",
|
||||
@@ -877,10 +837,7 @@
|
||||
"stableDiffusion15": "Stable Diffusion 1.5",
|
||||
"sdxl": "SDXL",
|
||||
"fluxDev": "FLUX.1 dev"
|
||||
},
|
||||
"installBundle": "Tải Xuống Gói",
|
||||
"installBundleMsg1": "Bạn có chắc chắn muốn tải xuống gói {{bundleName}}?",
|
||||
"installBundleMsg2": "Gói này sẽ tải xuống {{count}} model sau đây:"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"guidance": "Hướng Dẫn",
|
||||
@@ -913,8 +870,7 @@
|
||||
"recallParameters": "Gợi Nhớ Tham Số",
|
||||
"scheduler": "Scheduler",
|
||||
"noMetaData": "Không tìm thấy metadata",
|
||||
"imageDimensions": "Kích Thước Ảnh",
|
||||
"clipSkip": "$t(parameters.clipSkip)"
|
||||
"imageDimensions": "Kích Thước Ảnh"
|
||||
},
|
||||
"accordions": {
|
||||
"generation": {
|
||||
@@ -1134,23 +1090,7 @@
|
||||
"unknownField_withName": "Vùng Dữ Liệu Không Rõ \"{{name}}\"",
|
||||
"unexpectedField_withName": "Sai Vùng Dữ Liệu \"{{name}}\"",
|
||||
"unknownFieldEditWorkflowToFix_withName": "Workflow chứa vùng dữ liệu không rõ \"{{name}}\".\nHãy biên tập workflow để sửa lỗi.",
|
||||
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\"",
|
||||
"layout": {
|
||||
"autoLayout": "Bố Cục Tự Động",
|
||||
"layeringStrategy": "Chiến Lược Phân Layer",
|
||||
"networkSimplex": "Network Simplex",
|
||||
"longestPath": "Đường Đi Dài Nhất",
|
||||
"nodeSpacing": "Khoảng Cách Node",
|
||||
"layerSpacing": "Khoảng Cách Layer",
|
||||
"layoutDirection": "Hướng Bố Cục",
|
||||
"layoutDirectionRight": "Phải",
|
||||
"layoutDirectionDown": "Xuống",
|
||||
"alignment": "Căn Chỉnh Node",
|
||||
"alignmentUL": "Trên Cùng Bên Trái",
|
||||
"alignmentDL": "Dưới Cùng Bên Trái",
|
||||
"alignmentUR": "Trên Cùng Bên Phải",
|
||||
"alignmentDR": "Dưới Cùng Bên Phải"
|
||||
}
|
||||
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\""
|
||||
},
|
||||
"popovers": {
|
||||
"paramCFGRescaleMultiplier": {
|
||||
@@ -1657,7 +1597,7 @@
|
||||
"modelIncompatibleScaledBboxHeight": "Chiều dài hộp giới hạn theo tỉ lệ là {{height}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
|
||||
"modelIncompatibleScaledBboxWidth": "Chiều rộng hộp giới hạn theo tỉ lệ là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
|
||||
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần thiết lập tài khoản để nâng cấp.",
|
||||
"fluxKontextMultipleReferenceImages": "Chỉ có thể dùng 1 Ảnh Mẫu cùng lúc với LUX Kontext thông qua BFL API",
|
||||
"fluxKontextMultipleReferenceImages": "Chỉ có thể dùng 1 Ảnh Mẫu cùng lúc với Flux Kontext",
|
||||
"promptExpansionPending": "Trong quá trình mở rộng lệnh",
|
||||
"promptExpansionResultPending": "Hãy chấp thuận hoặc huỷ bỏ kết quả mở rộng lệnh của bạn"
|
||||
},
|
||||
@@ -1723,8 +1663,7 @@
|
||||
"upscaling": "Upscale",
|
||||
"tileSize": "Kích Thước Khối",
|
||||
"disabledNoRasterContent": "Đã Tắt (Không Có Nội Dung Dạng Raster)",
|
||||
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần <LinkComponent>thiết lập tài khoản</LinkComponent> để nâng cấp.",
|
||||
"useClipSkip": "Dùng CLIP Skip"
|
||||
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần <LinkComponent>thiết lập tài khoản</LinkComponent> để nâng cấp."
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"seedBehaviour": {
|
||||
@@ -2215,8 +2154,7 @@
|
||||
"rgReferenceImagesNotSupported": "Ảnh Mẫu Khu Vực không được hỗ trợ cho model cơ sở được chọn",
|
||||
"rgAutoNegativeNotSupported": "Tự Động Đảo Chiều không được hỗ trợ cho model cơ sở được chọn",
|
||||
"rgNoRegion": "không có khu vực được vẽ",
|
||||
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill",
|
||||
"bboxHidden": "Hộp giới hạn đang ẩn (shift+o để bật/tắt)"
|
||||
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill"
|
||||
},
|
||||
"pasteTo": "Dán Vào",
|
||||
"pasteToAssets": "Tài Nguyên",
|
||||
@@ -2254,11 +2192,7 @@
|
||||
"off": "Tắt",
|
||||
"switchOnStart": "Khi Bắt Đầu",
|
||||
"switchOnFinish": "Khi Kết Thúc"
|
||||
},
|
||||
"fitBboxToMasks": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
|
||||
"invertMask": "Đảo Ngược Lớp Phủ",
|
||||
"maxRefImages": "Ảnh Mẫu Tối Đa",
|
||||
"useAsReferenceImage": "Dùng Làm Ảnh Mẫu"
|
||||
}
|
||||
},
|
||||
"stylePresets": {
|
||||
"negativePrompt": "Lệnh Tiêu Cực",
|
||||
@@ -2420,28 +2354,20 @@
|
||||
"noValidLayerAdapters": "Không có Layer Adaper Phù Hợp",
|
||||
"promptGenerationStarted": "Trình tạo sinh lệnh khởi động",
|
||||
"uploadAndPromptGenerationFailed": "Thất bại khi tải lên ảnh để tạo sinh lệnh",
|
||||
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại.",
|
||||
"maskInverted": "Đã Đảo Ngược Lớp Phủ",
|
||||
"maskInvertFailed": "Thất Bại Khi Đảo Ngược Lớp Phủ",
|
||||
"noVisibleMasks": "Không Có Lớp Phủ Đang Hiển Thị",
|
||||
"noVisibleMasksDesc": "Tạo hoặc bật ít nhất một lớp phủ inpaint để đảo ngược",
|
||||
"noInpaintMaskSelected": "Không Có Lớp Phủ Inpant Được Chọn",
|
||||
"noInpaintMaskSelectedDesc": "Chọn một lớp phủ inpaint để đảo ngược",
|
||||
"invalidBbox": "Hộp Giới Hạn Không Hợp Lệ",
|
||||
"invalidBboxDesc": "Hợp giới hạn có kích thước không hợp lệ"
|
||||
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại."
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"gallery": "Thư Viện Ảnh",
|
||||
"models": "Models",
|
||||
"generation": "Generation (Máy Tạo Sinh)",
|
||||
"upscaling": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)",
|
||||
"canvas": "Canvas (Vùng Ảnh)",
|
||||
"upscalingTab": "$t(common.tab) $t(ui.tabs.upscaling)",
|
||||
"modelsTab": "$t(common.tab) $t(ui.tabs.models)",
|
||||
"queue": "Queue (Hàng Đợi)",
|
||||
"workflows": "Workflow (Luồng Làm Việc)",
|
||||
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)",
|
||||
"generate": "Tạo Sinh"
|
||||
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)"
|
||||
},
|
||||
"launchpad": {
|
||||
"workflowsTitle": "Đi sâu hơn với Workflow.",
|
||||
@@ -2489,43 +2415,8 @@
|
||||
"promptAdvice": "Khi upscale, dùng lệnh để mô tả phương thức và phong cách. Tránh mô tả các chi tiết cụ thể trong ảnh.",
|
||||
"styleAdvice": "Upscale thích hợp nhất cho phong cách chung của ảnh."
|
||||
},
|
||||
"scale": "Kích Thước",
|
||||
"creativityAndStructure": {
|
||||
"title": "Độ Sáng Tạo & Cấu Trúc Mặc Định",
|
||||
"conservative": "Bảo toàn",
|
||||
"balanced": "Cân bằng",
|
||||
"creative": "Sáng tạo",
|
||||
"artistic": "Thẩm mỹ"
|
||||
}
|
||||
},
|
||||
"createNewWorkflowFromScratch": "Tạo workflow mới từ đầu",
|
||||
"browseAndLoadWorkflows": "Duyệt và tải workflow có sẵn",
|
||||
"addStyleRef": {
|
||||
"title": "Thêm Phong Cách Mẫu",
|
||||
"description": "Thêm ảnh để chuyển đổi diện mạo của nó."
|
||||
},
|
||||
"editImage": {
|
||||
"title": "Biên Tập Ảnh",
|
||||
"description": "Thêm ảnh để chỉnh sửa."
|
||||
},
|
||||
"generateFromText": {
|
||||
"title": "Tạo Sinh Từ Chữ",
|
||||
"description": "Nhập lệnh vào và Kích Hoạt."
|
||||
},
|
||||
"useALayoutImage": {
|
||||
"title": "Dùng Bố Cục Ảnh",
|
||||
"description": "Thêm ảnh để điều khiển bố cục."
|
||||
},
|
||||
"generate": {
|
||||
"canvasCalloutTitle": "Đang tìm cách để điều khiển, chỉnh sửa, và làm lại ảnh?",
|
||||
"canvasCalloutLink": "Vào Canvas cho nhiều tính năng hơn."
|
||||
"scale": "Kích Thước"
|
||||
}
|
||||
},
|
||||
"panels": {
|
||||
"launchpad": "Launchpad",
|
||||
"workflowEditor": "Trình Biên Tập Workflow",
|
||||
"imageViewer": "Trình Xem Ảnh",
|
||||
"canvas": "Canvas"
|
||||
}
|
||||
},
|
||||
"workflows": {
|
||||
@@ -2640,10 +2531,7 @@
|
||||
"publishingValidationRunInProgress": "Quá trình kiểm tra tính hợp lệ đang diễn ra.",
|
||||
"selectingOutputNodeDesc": "Bấm vào node để biến nó thành node đầu ra của workflow.",
|
||||
"selectingOutputNode": "Chọn node đầu ra",
|
||||
"errorWorkflowHasUnpublishableNodes": "Workflow có lô node, node sản sinh, hoặc node tách metadata",
|
||||
"removeFromForm": "Xóa Khỏi Vùng Nhập",
|
||||
"showShuffle": "Hiện Xáo Trộn",
|
||||
"shuffle": "Xáo Trộn"
|
||||
"errorWorkflowHasUnpublishableNodes": "Workflow có lô node, node sản sinh, hoặc node tách metadata"
|
||||
},
|
||||
"yourWorkflows": "Workflow Của Bạn",
|
||||
"browseWorkflows": "Khám Phá Workflow",
|
||||
@@ -2700,8 +2588,9 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"Misc QoL: Bật/Tắt hiển thị hộp giới hạn, đánh dấu node bị lỗi, chặn lỗi thêm node vào vùng nhập nhiều lần, khả năng đọc lại metadata của CLIP Skip",
|
||||
"Giảm lượng tiêu thụ VRAM cho các ảnh mẫu Kontext và mã hóa VAE"
|
||||
"Tạo sinh ảnh nhanh hơn với Launchpad và thẻ Tạo Sinh đã cơ bản hoá.",
|
||||
"Biên tập với lệnh bằng Flux Kontext Dev.",
|
||||
"Xuất ra file PSD, ẩn số lượng lớn lớp phủ, sắp xếp model & ảnh — tất cả cho một giao diện đã thiết kế lại để chuyên điều khiển."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
|
||||
@@ -1772,6 +1772,7 @@
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "生成",
|
||||
"queue": "队列",
|
||||
"canvas": "画布",
|
||||
"upscaling": "放大中",
|
||||
|
||||
@@ -3,9 +3,9 @@ import { useStore } from '@nanostores/react';
|
||||
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
|
||||
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
|
||||
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||
import { AppContent } from 'features/ui/components/AppContent';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
@@ -21,12 +21,13 @@ interface Props {
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
const didStudioInit = useStore($didStudioInit);
|
||||
const clearStorage = useClearStorage();
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
clearStorage();
|
||||
location.reload();
|
||||
return false;
|
||||
}, []);
|
||||
}, [clearStorage]);
|
||||
|
||||
return (
|
||||
<ThemeLocaleProvider>
|
||||
|
||||
@@ -5,7 +5,6 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
|
||||
import type { LoggingOverrides } from 'app/logging/logger';
|
||||
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
|
||||
import { addStorageListeners } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
@@ -36,7 +35,7 @@ import {
|
||||
import type { WorkflowCategory } from 'features/nodes/types/workflow';
|
||||
import type { ToastConfig } from 'features/toast/toast';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import React, { lazy, memo, useEffect, useLayoutEffect, useState } from 'react';
|
||||
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
import { $socketOptions } from 'services/events/stores';
|
||||
@@ -71,7 +70,6 @@ interface Props extends PropsWithChildren {
|
||||
* If provided, overrides in-app navigation to the model manager
|
||||
*/
|
||||
onClickGoToModelManager?: () => void;
|
||||
storagePersistDebounce?: number;
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({
|
||||
@@ -98,11 +96,7 @@ const InvokeAIUI = ({
|
||||
loggingOverrides,
|
||||
onClickGoToModelManager,
|
||||
whatsNew,
|
||||
storagePersistDebounce = 300,
|
||||
}: Props) => {
|
||||
const [store, setStore] = useState<ReturnType<typeof createStore> | undefined>(undefined);
|
||||
const [didRehydrate, setDidRehydrate] = useState(false);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
/*
|
||||
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
|
||||
@@ -314,30 +308,22 @@ const InvokeAIUI = ({
|
||||
};
|
||||
}, [isDebugging]);
|
||||
|
||||
const store = useMemo(() => {
|
||||
return createStore(projectId);
|
||||
}, [projectId]);
|
||||
|
||||
useEffect(() => {
|
||||
const onRehydrated = () => {
|
||||
setDidRehydrate(true);
|
||||
};
|
||||
const store = createStore({ persist: true, persistDebounce: storagePersistDebounce, onRehydrated });
|
||||
setStore(store);
|
||||
$store.set(store);
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = $store;
|
||||
}
|
||||
const removeStorageListeners = addStorageListeners();
|
||||
return () => {
|
||||
removeStorageListeners();
|
||||
setStore(undefined);
|
||||
$store.set(undefined);
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = undefined;
|
||||
}
|
||||
};
|
||||
}, [storagePersistDebounce]);
|
||||
|
||||
if (!store || !didRehydrate) {
|
||||
return <Loading />;
|
||||
}
|
||||
}, [store]);
|
||||
|
||||
return (
|
||||
<React.StrictMode>
|
||||
|
||||
@@ -93,7 +93,5 @@ export const configureLogging = (
|
||||
localStorage.setItem('ROARR_FILTER', filter);
|
||||
}
|
||||
|
||||
const styleOutput = localStorage.getItem('ROARR_STYLE_OUTPUT') === 'false' ? false : true;
|
||||
|
||||
ROARR.write = createLogWriter({ styleOutput });
|
||||
ROARR.write = createLogWriter();
|
||||
};
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
export const STORAGE_PREFIX = '@@invokeai-';
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
||||
|
||||
@@ -1,209 +1,40 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $projectId } from 'app/store/nanostores/projectId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { UseStore } from 'idb-keyval';
|
||||
import { createStore as idbCreateStore, del as idbDel, get as idbGet } from 'idb-keyval';
|
||||
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Driver } from 'redux-remember';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { buildV1Url, getBaseUrl } from 'services/api';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('system');
|
||||
// Create a custom idb-keyval store (just needed to customize the name)
|
||||
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
|
||||
|
||||
const getUrl = (endpoint: 'get_by_key' | 'set_by_key' | 'delete', key?: string) => {
|
||||
const baseUrl = getBaseUrl();
|
||||
const query: Record<string, string> = {};
|
||||
if (key) {
|
||||
query['key'] = key;
|
||||
}
|
||||
|
||||
const path = buildV1Url(`client_state/${$queueId.get()}/${endpoint}`, query);
|
||||
const url = `${baseUrl}/${path}`;
|
||||
return url;
|
||||
export const clearIdbKeyValStore = () => {
|
||||
clear($idbKeyValStore.get());
|
||||
};
|
||||
|
||||
const getHeaders = () => {
|
||||
const headers = new Headers();
|
||||
const authToken = $authToken.get();
|
||||
const projectId = $projectId.get();
|
||||
if (authToken) {
|
||||
headers.set('Authorization', `Bearer ${authToken}`);
|
||||
}
|
||||
if (projectId) {
|
||||
headers.set('project-id', projectId);
|
||||
}
|
||||
return headers;
|
||||
};
|
||||
|
||||
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
|
||||
// it when a slice is being persisted and decrementing it when the persistence is done.
|
||||
let persistRefCount = 0;
|
||||
|
||||
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
|
||||
//
|
||||
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
|
||||
// persist config.
|
||||
//
|
||||
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
|
||||
// way to do this directly.
|
||||
//
|
||||
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
|
||||
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
|
||||
// the implementation in `store.ts` for this logic.
|
||||
//
|
||||
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
|
||||
// whole slice, even if the final, _serialized_ slice value is unchanged.
|
||||
//
|
||||
// To avoid unnecessary network requests, we keep track of the last persisted state for each key in this map.
|
||||
// If the value to be persisted is the same as the last persisted value, we will skip the network request.
|
||||
const lastPersistedState = new Map<string, string | undefined>();
|
||||
|
||||
// As of v6.3.0, we use server-backed storage for client state. This replaces the previous IndexedDB-based storage,
|
||||
// which was implemented using `idb-keyval`.
|
||||
//
|
||||
// To facilitate a smooth transition, we implement a migration strategy that attempts to retrieve values from IndexedDB
|
||||
// and persist them to the new server-backed storage. This is done on a best-effort basis.
|
||||
|
||||
// These constants were used in the previous IndexedDB-based storage implementation.
|
||||
const IDB_DB_NAME = 'invoke';
|
||||
const IDB_STORE_NAME = 'invoke-store';
|
||||
const IDB_STORAGE_PREFIX = '@@invokeai-';
|
||||
|
||||
// Lazy store creation
|
||||
let _idbKeyValStore: UseStore | null = null;
|
||||
const getIdbKeyValStore = () => {
|
||||
if (_idbKeyValStore === null) {
|
||||
_idbKeyValStore = idbCreateStore(IDB_DB_NAME, IDB_STORE_NAME);
|
||||
}
|
||||
return _idbKeyValStore;
|
||||
};
|
||||
|
||||
const getIdbKey = (key: string) => {
|
||||
return `${IDB_STORAGE_PREFIX}${key}`;
|
||||
};
|
||||
|
||||
const getItem = async (key: string) => {
|
||||
try {
|
||||
const url = getUrl('get_by_key', key);
|
||||
const headers = getHeaders();
|
||||
const res = await fetch(url, { method: 'GET', headers });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
const value = await res.json();
|
||||
|
||||
// Best-effort migration from IndexedDB to the new storage system
|
||||
log.trace({ key, value }, 'Server-backed storage value retrieved');
|
||||
|
||||
if (!value) {
|
||||
const idbKey = getIdbKey(key);
|
||||
try {
|
||||
// It's a bit tricky to query IndexedDB directly to check if value exists, so we use `idb-keyval` to do it.
|
||||
// Thing is, `idb-keyval` requires you to create a store to query it. End result - we are creating a store
|
||||
// even if we don't use it for anything besides checking if the key is present.
|
||||
const idbKeyValStore = getIdbKeyValStore();
|
||||
const idbValue = await idbGet(idbKey, idbKeyValStore);
|
||||
if (idbValue) {
|
||||
log.debug(
|
||||
{ key, idbKey, idbValue },
|
||||
'No value in server-backed storage, but found value in IndexedDB - attempting migration'
|
||||
);
|
||||
await idbDel(idbKey, idbKeyValStore);
|
||||
await setItem(key, idbValue);
|
||||
log.debug({ key, idbKey, idbValue }, 'Migration successful');
|
||||
return idbValue;
|
||||
}
|
||||
} catch (error) {
|
||||
// Just log if IndexedDB retrieval fails - this is a best-effort migration.
|
||||
log.debug(
|
||||
{ key, idbKey, error: serializeError(error) } as JsonObject,
|
||||
'Error checking for or migrating from IndexedDB'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
lastPersistedState.set(key, value);
|
||||
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
|
||||
return value;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const setItem = async (key: string, value: string) => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
if (lastPersistedState.get(key) === value) {
|
||||
log.trace(
|
||||
{ key, last: lastPersistedState.get(key), next: value },
|
||||
`Skipping persist for ${key} as value is unchanged`
|
||||
);
|
||||
return value;
|
||||
}
|
||||
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
|
||||
const url = getUrl('set_by_key', key);
|
||||
const headers = getHeaders();
|
||||
const res = await fetch(url, { method: 'POST', headers, body: value });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
const resultValue = await res.json();
|
||||
lastPersistedState.set(key, resultValue);
|
||||
return resultValue;
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const reduxRememberDriver: Driver = { getItem, setItem };
|
||||
|
||||
export const clearStorage = async () => {
|
||||
try {
|
||||
persistRefCount++;
|
||||
const url = getUrl('delete');
|
||||
const headers = getHeaders();
|
||||
const res = await fetch(url, { method: 'POST', headers });
|
||||
if (!res.ok) {
|
||||
throw new Error(`Response status: ${res.status}`);
|
||||
}
|
||||
} catch {
|
||||
log.error('Failed to reset client state');
|
||||
} finally {
|
||||
persistRefCount--;
|
||||
lastPersistedState.clear();
|
||||
if (persistRefCount < 0) {
|
||||
log.trace('Persist ref count is negative, resetting to 0');
|
||||
persistRefCount = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const addStorageListeners = () => {
|
||||
const onBeforeUnload = (e: BeforeUnloadEvent) => {
|
||||
if (persistRefCount > 0) {
|
||||
e.preventDefault();
|
||||
}
|
||||
};
|
||||
window.addEventListener('beforeunload', onBeforeUnload);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('beforeunload', onBeforeUnload);
|
||||
};
|
||||
// Create redux-remember driver, wrapping idb-keyval
|
||||
export const idbKeyValDriver: Driver = {
|
||||
getItem: (key) => {
|
||||
try {
|
||||
return get(key, $idbKeyValStore.get());
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
setItem: (key, value) => {
|
||||
try {
|
||||
return set(key, value, $idbKeyValStore.get());
|
||||
} catch (originalError) {
|
||||
throw new StorageError({
|
||||
key,
|
||||
value,
|
||||
projectId: $projectId.get(),
|
||||
originalError,
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -33,9 +33,8 @@ export class StorageError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const errorHandler = (err: PersistError | RehydrateError) => {
|
||||
const log = logger('system');
|
||||
if (err instanceof PersistError) {
|
||||
log.error({ error: serializeError(err) }, 'Problem persisting state');
|
||||
} else if (err instanceof RehydrateError) {
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
|
||||
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
const startAppListening = listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
* Most side effect logic should live in a listener.
|
||||
*/
|
||||
|
||||
// Image uploaded
|
||||
addImageUploadedFulfilledListener(startAppListening);
|
||||
|
||||
// Image deleted
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
|
||||
// Gallery bulk download
|
||||
addBulkDownloadListeners(startAppListening);
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener(startAppListening);
|
||||
addImageRemovedFromBoardFulfilledListener(startAppListening);
|
||||
addBoardIdSelectedListener(startAppListening);
|
||||
addArchivedOrDeletedBoardListener(startAppListening);
|
||||
|
||||
// Node schemas
|
||||
addGetOpenAPISchemaListener(startAppListening);
|
||||
|
||||
// Models
|
||||
addModelSelectedListener(startAppListening);
|
||||
|
||||
// app startup
|
||||
addAppStartedListener(startAppListening);
|
||||
addModelsLoadedListener(startAppListening);
|
||||
addAppConfigReceivedListener(startAppListening);
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addAdHocPostProcessingRequestedListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/store';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
autoAddBoardIdChanged,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user