Compare commits

..

14 Commits

Author SHA1 Message Date
Kent Keirsey
3a14791da3 bria-ui-updates-wip 2025-07-25 12:58:27 -04:00
Ilan Tchenak
711a579945 fixed schema 2025-07-24 17:22:15 +00:00
Ilan Tchenak
1ac5a24a8a ruff fix 2025-07-24 19:10:29 +03:00
Ubuntu
282df322d5 fixed node issue 2025-07-24 10:56:37 -04:00
Ilan Tchenak
8523ea88f2 moved bria's nodes to invocations folder 2025-07-24 10:56:37 -04:00
Ubuntu
cad97d3da3 Small cosmetic fixes 2025-07-24 10:56:37 -04:00
Ubuntu
efc5a762fc removed unused file 2025-07-24 10:56:37 -04:00
Ubuntu
9131c45645 Added scikit-image required for Bria's OpenposeDetector model 2025-07-24 10:56:37 -04:00
Ilan Tchenak
75ca44d5f9 Add Bria text to image model and controlnet support 2025-07-24 10:56:37 -04:00
Ilan Tchenak
8b08af3949 Setup Probe and UI to accept bria controlnet models 2025-07-24 10:56:37 -04:00
Ubuntu
df9ea8dcc1 addded bria nodes for bria3.1 and bria3.2 2025-07-24 10:56:37 -04:00
Ubuntu
25a57326b3 front end support for bria 2025-07-24 10:56:37 -04:00
Ubuntu
7f3e8087ba added support for loading bria transformer 2025-07-24 10:56:37 -04:00
Brandon Rising
dfc7835359 Setup Probe and UI to accept bria main models 2025-07-24 10:56:37 -04:00
580 changed files with 13642 additions and 22187 deletions

View File

@@ -18,6 +18,5 @@
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _❗Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_

View File

@@ -45,9 +45,6 @@ jobs:
steps:
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
# the /mnt dir has 70GBs of free space
# /dev/sda1 74G 28K 70G 1% /mnt
# According to some online posts the /mnt is not always there, so checking before setting docker to use it
run: |
echo "----- Free space before cleanup"
df -h
@@ -55,11 +52,6 @@ jobs:
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
if [ -d /mnt ]; then
sudo chmod -R 777 /mnt
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
sudo systemctl restart docker
fi
echo "----- Free space after cleanup"
df -h

View File

@@ -1,30 +0,0 @@
# Checks that large files and LFS-tracked files are properly checked in with pointer format.
# Uses https://github.com/ppremk/lfs-warning to detect LFS issues.
name: 'lfs checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
jobs:
lfs-check:
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
# Required to label and comment on the PRs
pull-requests: write
steps:
- name: checkout
uses: actions/checkout@v4
- name: check lfs files
uses: ppremk/lfs-warning@v3.3

View File

@@ -39,18 +39,6 @@ jobs:
- name: checkout
uses: actions/checkout@v4
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
echo "----- Free space before cleanup"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
echo "----- Free space after cleanup"
df -h
- name: check for changed files
if: ${{ inputs.always_run != true }}
id: changed-files

View File

@@ -22,10 +22,6 @@
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=cuda #| rocm
## If you are using ROCM, you will need to ensure that the render group within the container and the host system use the same group ID.
## To obtain the group ID of the render group on the host system, run `getent group render` and grab the number.
# RENDER_GROUP_ID=
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
# CONTAINER_UID=1000

View File

@@ -43,6 +43,7 @@ ENV \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
UV_INDEX="https://download.pytorch.org/whl/cu124" \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
@@ -73,18 +74,20 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids" && groupadd render
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv sync --frozen
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
@@ -102,6 +105,8 @@ COPY invokeai ${INVOKEAI_SRC}/invokeai
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv pip install -e .

View File

@@ -1,136 +0,0 @@
# syntax=docker/dockerfile:1.4
#### Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
## Backend ---------------------------------------
FROM library/ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt update && apt install -y --no-install-recommends \
ca-certificates \
git \
gosu \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev \
wget
ENV \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
VIRTUAL_ENV=/opt/venv \
INVOKEAI_SRC=/opt/invokeai \
PYTHON_VERSION=3.12 \
UV_PYTHON=3.12 \
UV_COMPILE_BYTECODE=1 \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
PATH="/opt/venv/bin:$PATH" \
CONTAINER_UID=${CONTAINER_UID:-1000} \
CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install ${PYTHON_VERSION} && \
# chmod --recursive a+rX /root/.local/share/uv/python
chmod 711 /root
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
if [ "$GPU_DRIVER" = "rocm" ]; then \
wget -O /tmp/amdgpu-install.deb \
https://repo.radeon.com/amdgpu-install/6.3.4/ubuntu/noble/amdgpu-install_6.3.60304-1_all.deb && \
apt install -y /tmp/amdgpu-install.deb && \
apt update && \
amdgpu-install --usecase=rocm -y && \
apt-get autoclean && \
apt clean && \
rm -rf /tmp/* /var/tmp/* && \
usermod -a -G render ubuntu && \
usermod -a -G video ubuntu && \
echo "\\n/opt/rocm/lib\\n/opt/rocm/lib64" >> /etc/ld.so.conf.d/rocm.conf && \
ldconfig && \
update-alternatives --auto rocm; \
fi
## Heathen711: Leaving this for review input, will remove before merge
# RUN --mount=type=cache,target=/var/cache/apt \
# --mount=type=cache,target=/var/lib/apt \
# if [ "$GPU_DRIVER" = "rocm" ]; then \
# groupadd render && \
# usermod -a -G render ubuntu && \
# usermod -a -G video ubuntu; \
# fi
## Link amdgpu.ids for ROCm builds
## contributed by https://github.com/Rubonnek
# RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
# ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
# this should not increase image size because we've already installed dependencies
# in a previous layer
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]

View File

@@ -47,9 +47,8 @@ services:
invokeai-rocm:
<<: *invokeai
environment:
- AMD_VISIBLE_DEVICES=all
- RENDER_GROUP_ID=${RENDER_GROUP_ID}
runtime: amd
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
profiles:
- rocm

View File

@@ -21,17 +21,6 @@ _=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
# ensure the UID is correct
usermod -u ${USER_ID} ${USER} 1>/dev/null
## ROCM specific configuration
# render group within the container must match the host render group
# otherwise the container will not be able to access the host GPU.
if [[ -v "RENDER_GROUP_ID" ]] && [[ ! -z "${RENDER_GROUP_ID}" ]]; then
# ensure the render group exists
groupmod -g ${RENDER_GROUP_ID} render
usermod -a -G render ${USER}
usermod -a -G video ${USER}
fi
### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat.
# but it is useful to have the full SSH server e.g. on Runpod.

View File

@@ -13,7 +13,7 @@ run() {
# parse .env file for build args
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER=/ {print $2}' .env)"
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
# default to 'cuda' profile
[[ -z "$profile" ]] && profile="cuda"
@@ -30,7 +30,7 @@ run() {
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose --profile "$profile" logs -f
docker compose logs -f
}
run

View File

@@ -265,7 +265,7 @@ If the key is unrecognized, this call raises an
#### exists(key) -> AnyModelConfig
Returns True if a model with the given key exists in the database.
Returns True if a model with the given key exists in the databsae.
#### search_by_path(path) -> AnyModelConfig
@@ -718,7 +718,7 @@ When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
retrieved from the HuggingFace and Civitai model repositories.
The probed values can be overridden by providing a dictionary in the
The probed values can be overriden by providing a dictionary in the
optional `config` argument passed to `import_model()`. You may provide
overriding values for any of the model's configuration
attributes. Here is an example of setting the
@@ -841,7 +841,7 @@ variable.
#### installer.start(invoker)
The `start` method is called by the API initialization routines when
The `start` method is called by the API intialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.

View File

@@ -16,7 +16,7 @@ We thank [all contributors](https://github.com/invoke-ai/InvokeAI/graphs/contrib
- @psychedelicious (Spencer Mabrito) - Web Team Leader
- @joshistoast (Josh Corbett) - Web Development
- @cheerio (Mary Rogers) - Lead Engineer & Web App Development
- @ebr (Eugene Brodsky) - Cloud/DevOps/Software engineer; your friendly neighbourhood cluster-autoscaler
- @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
- @sunija - Standalone version
- @brandon (Brandon Rising) - Platform, Infrastructure, Backend Systems
- @ryanjdick (Ryan Dick) - Machine Learning & Training

View File

@@ -69,34 +69,34 @@ The following commands vary depending on the version of Invoke being installed a
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
7. Determine the torch backend to use for installation, if any. This is necessary to get the right version of torch installed. This is acheived by using [UV's built in torch support.](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection)
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
=== "Invoke v5.12 and later"
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu128`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.3`.
- **In all other cases, do not use a torch backend.**
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.10.0 to v5.11.0"
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu126`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.2.4`.
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.0.0 to v5.9.1"
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.1`.
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm5.2`.
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.
@@ -105,10 +105,10 @@ The following commands vary depending on the version of Invoke being installed a
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a torch backend in the previous step, you'll need to set the backend like this:
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --torch-backend=<VERSION> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -33,45 +33,30 @@ Hardware requirements vary significantly depending on model and image output siz
More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download and Set Up the Launcher
## Step 2: Download
The Launcher manages your Invoke install. Follow these instructions to download and set up the Launcher.
Download the most recent launcher for your operating system:
!!! info "Instructions for each OS"
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)
=== "Windows"
## Step 3: Install or Update
- [Download for Windows](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition.Setup.latest.exe)
- Run the `EXE` to install the Launcher and start it.
- A desktop shortcut will be created; use this to run the Launcher in the future.
- You can delete the `EXE` file you downloaded.
=== "macOS"
- [Download for macOS](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest-arm64.dmg)
- Open the `DMG` and drag the app into `Applications`.
- Run the Launcher using its entry in `Applications`.
- You can delete the `DMG` file you downloaded.
=== "Linux"
- [Download for Linux](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest.AppImage)
- You may need to edit the `AppImage` file properties and make it executable.
- Optionally move the file to a location that does not require admin privileges and add a desktop shortcut for it.
- Run the Launcher by double-clicking the `AppImage` or the shortcut you made.
## Step 3: Install Invoke
Run the Launcher you just set up if you haven't already. Click **Install** and follow the instructions to install (or update) Invoke.
Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.
If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.
!!! tip "Updating"
!!! warning "Problem running the launcher on macOS"
The Launcher will check for updates for itself _and_ Invoke.
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can manually flag the launcher as safe:
- When the Launcher detects an update is available for itself, you'll get a small popup window. Click through this and the Launcher will update itself.
- When the Launcher detects an update for Invoke, you'll see a small green alert in the Launcher. Click that and follow the instructions to update Invoke.
- Open the **Invoke Community Edition.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
You should now be able to run the launcher.
## Step 4: Launch

View File

@@ -41,7 +41,7 @@ Nodes have a "Use Cache" option in their footer. This allows for performance imp
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
### Create Latent Noise
### Noise
An initial noise tensor is necessary for the latent diffusion process. As a result, the Denoising node requires a noise node input.

View File

@@ -10,7 +10,6 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -152,7 +151,6 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -183,7 +181,6 @@ class ApiDependencies:
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -1,39 +0,0 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from invokeai.app.services.videos_common import AddVideosToBoardResult, RemoveVideosFromBoardResult
board_videos_router = APIRouter(prefix="/v1/board_videos", tags=["boards"])
@board_videos_router.post(
"/batch",
operation_id="add_videos_to_board",
responses={
201: {"description": "Videos were added to board successfully"},
},
status_code=201,
response_model=AddVideosToBoardResult,
)
async def add_videos_to_board(
board_id: str = Body(description="The id of the board to add to"),
video_ids: list[str] = Body(description="The ids of the videos to add", embed=True),
) -> AddVideosToBoardResult:
"""Adds a list of videos to a board"""
raise HTTPException(status_code=501, detail="Not implemented")
@board_videos_router.post(
"/batch/delete",
operation_id="remove_videos_from_board",
responses={
201: {"description": "Videos were removed from board successfully"},
},
status_code=201,
response_model=RemoveVideosFromBoardResult,
)
async def remove_videos_from_board(
video_ids: list[str] = Body(description="The ids of the videos to remove", embed=True),
) -> RemoveVideosFromBoardResult:
"""Removes a list of videos from their board, if they had one"""
raise HTTPException(status_code=501, detail="Not implemented")

View File

@@ -1,58 +0,0 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"])
@client_state_router.get(
"/{queue_id}/get_by_key",
operation_id="get_client_state_by_key",
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/set_by_key",
operation_id="set_client_state",
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/delete",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
@@ -17,7 +18,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -25,7 +25,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.pagination import CursorPaginatedResults
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -68,6 +68,36 @@ async def enqueue_batch(
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItem]},
},
)
async def list_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
limit: int = Query(default=50, description="The number of items to fetch"),
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
@@ -89,56 +119,6 @@ async def list_all_queue_items(
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
@session_queue_router.get(
"/{queue_id}/item_ids",
operation_id="get_queue_item_ids",
responses={
200: {"model": ItemIdsResult},
},
)
async def get_queue_item_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
@session_queue_router.post(
"/{queue_id}/items_by_ids",
operation_id="get_queue_items_by_item_ids",
responses={200: {"model": list[SessionQueueItem]}},
)
async def get_queue_items_by_item_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(
embed=True, description="Object containing list of queue item ids to fetch queue items for"
),
) -> list[SessionQueueItem]:
"""Gets queue items for the specified queue item ids. Maintains order of item ids."""
try:
session_queue_service = ApiDependencies.invoker.services.session_queue
# Fetch queue items preserving the order of requested item ids
queue_items: list[SessionQueueItem] = []
for item_id in item_ids:
try:
queue_item = session_queue_service.get_queue_item(item_id=item_id)
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
continue
queue_items.append(queue_item)
except Exception:
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
continue
return queue_items
except Exception:
raise HTTPException(status_code=500, detail="Failed to get queue items")
@session_queue_router.put(
"/{queue_id}/processor/resume",
operation_id="resume",
@@ -374,10 +354,7 @@ async def get_queue_item(
) -> SessionQueueItem:
"""Gets a queue item"""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
if queue_item.queue_id != queue_id:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
return queue_item
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:

View File

@@ -1,119 +0,0 @@
from typing import Optional
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.videos_common import (
DeleteVideosResult,
StarredVideosResult,
UnstarredVideosResult,
VideoDTO,
VideoIdsResult,
VideoRecordChanges,
)
videos_router = APIRouter(prefix="/v1/videos", tags=["videos"])
@videos_router.patch(
"/i/{video_id}",
operation_id="update_video",
response_model=VideoDTO,
)
async def update_video(
video_id: str = Path(description="The id of the video to update"),
video_changes: VideoRecordChanges = Body(description="The changes to apply to the video"),
) -> VideoDTO:
"""Updates a video"""
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.get(
"/i/{video_id}",
operation_id="get_video_dto",
response_model=VideoDTO,
)
async def get_video_dto(
video_id: str = Path(description="The id of the video to get"),
) -> VideoDTO:
"""Gets a video's DTO"""
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.post("/delete", operation_id="delete_videos_from_list", response_model=DeleteVideosResult)
async def delete_videos_from_list(
video_ids: list[str] = Body(description="The list of ids of videos to delete", embed=True),
) -> DeleteVideosResult:
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.post("/star", operation_id="star_videos_in_list", response_model=StarredVideosResult)
async def star_videos_in_list(
video_ids: list[str] = Body(description="The list of ids of videos to star", embed=True),
) -> StarredVideosResult:
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.post("/unstar", operation_id="unstar_videos_in_list", response_model=UnstarredVideosResult)
async def unstar_videos_in_list(
video_ids: list[str] = Body(description="The list of ids of videos to unstar", embed=True),
) -> UnstarredVideosResult:
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.delete("/uncategorized", operation_id="delete_uncategorized_videos", response_model=DeleteVideosResult)
async def delete_uncategorized_videos() -> DeleteVideosResult:
"""Deletes all videos that are uncategorized"""
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.get("/", operation_id="list_video_dtos", response_model=OffsetPaginatedResults[VideoDTO])
async def list_video_dtos(
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate videos."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find videos without a board.",
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of videos per page"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred videos first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[VideoDTO]:
"""Lists video DTOs"""
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.get("/ids", operation_id="get_video_ids")
async def get_video_ids(
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate videos."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find videos without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred videos first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> VideoIdsResult:
"""Gets ordered list of video ids with metadata for optimistic updates"""
raise HTTPException(status_code=501, detail="Not implemented")
@videos_router.post(
"/videos_by_ids",
operation_id="get_videos_by_ids",
responses={200: {"model": list[VideoDTO]}},
)
async def get_videos_by_ids(
video_ids: list[str] = Body(embed=True, description="Object containing list of video ids to fetch DTOs for"),
) -> list[VideoDTO]:
"""Gets video DTOs for the specified video ids. Maintains order of input ids."""
raise HTTPException(status_code=501, detail="Not implemented")

View File

@@ -18,9 +18,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.api.routers import (
app_info,
board_images,
board_videos,
boards,
client_state,
download_queue,
images,
model_manager,
@@ -28,7 +26,6 @@ from invokeai.app.api.routers import (
session_queue,
style_presets,
utilities,
videos,
workflows,
)
from invokeai.app.api.sockets import SocketIO
@@ -127,16 +124,13 @@ app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(videos.videos_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(board_videos.board_videos_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.include_router(client_state.client_state_router, prefix="/api")
app.openapi = get_openapi_func(app)
@@ -161,12 +155,6 @@ def overridden_redoc() -> HTMLResponse:
web_root_path = Path(list(web_dir.__path__)[0])
if app_config.unsafe_disable_picklescan:
logger.warning(
"The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and"
"loading models, which may allow malicious code to be executed. Use at your own risk."
)
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:

View File

@@ -0,0 +1,158 @@
import cv2
import numpy as np
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.invocation_api import Classification, ImageOutput
DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"
class BriaControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
model: ModelIdentifierField = Field(description="The ControlNet model to use")
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
conditioning_scale: float = Field(description="The weight given to the ControlNet")
@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
"""Bria ControlNet info"""
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@invocation(
"bria_controlnet",
title="ControlNet - Bria",
tags=["controlnet", "bria"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Collect Bria ControlNet info to pass to denoiser node."""
control_image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
)
control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")
def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
image_in = resize_img(context.images.get_pil(self.control_image.image_name))
if self.control_mode == "canny":
control_image = extract_canny(image_in)
elif self.control_mode == "depth":
control_image = extract_depth(image_in, context)
elif self.control_mode == "pose":
control_image = extract_openpose(image_in, context)
elif self.control_mode == "colorgrid":
control_image = tile(64, image_in)
elif self.control_mode == "recolor":
control_image = convert_to_grayscale(image_in)
elif self.control_mode == "tile":
control_image = tile(16, image_in)
control_image = resize_img(control_image)
image_dto = context.images.save(image=control_image)
image_output = ImageOutput.build(image_dto)
return BriaControlNetOutput(
preprocessed_images=image_output.image,
control=BriaControlNetField(
image=ImageField(image_name=image_dto.image_name),
model=self.control_model,
mode=self.control_mode,
conditioning_scale=self.control_weight,
),
)
RATIO_CONFIGS_1024 = {
0.6666666666666666: {"width": 832, "height": 1248},
0.7432432432432432: {"width": 880, "height": 1184},
0.8028169014084507: {"width": 912, "height": 1136},
1.0: {"width": 1024, "height": 1024},
1.2456140350877194: {"width": 1136, "height": 912},
1.3454545454545455: {"width": 1184, "height": 880},
1.4339622641509433: {"width": 1216, "height": 848},
1.5: {"width": 1248, "height": 832},
1.5490196078431373: {"width": 1264, "height": 816},
1.62: {"width": 1296, "height": 800},
1.7708333333333333: {"width": 1360, "height": 768},
}
def extract_depth(image: Image.Image, context: InvocationContext):
loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
return depth_map
def extract_openpose(image: Image.Image, context: InvocationContext):
body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
with body_model as body_model, hand_model as hand_model, face_model as face_model:
open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
processed_image_open_pose = open_pose_model(image, hand_and_face=True)
processed_image_open_pose = processed_image_open_pose.resize(image.size)
return processed_image_open_pose
def extract_canny(input_image):
image = np.array(input_image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
return canny_image
def convert_to_grayscale(image):
gray_image = image.convert("L").convert("RGB")
return gray_image
def tile(downscale_factor, input_image):
control_image = input_image.resize(
(input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
).resize(input_image.size, Image.Resampling.NEAREST)
return control_image
def resize_img(control_image):
image_ratio = control_image.width / control_image.height
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
to_height = RATIO_CONFIGS_1024[ratio]["height"]
to_width = RATIO_CONFIGS_1024[ratio]["width"]
resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
return resized_image

View File

@@ -0,0 +1,46 @@
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
@invocation(
"bria_decoder",
title="Decoder - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
with context.models.load(self.vae.vae) as vae:
assert isinstance(vae, AutoencoderKL)
latents = latents / vae.config.scaling_factor
latents = latents.to(device=vae.device, dtype=vae.dtype)
decoded_output = vae.decode(latents)
image = decoded_output.sample
# Convert to numpy with proper gradient handling
image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
img = Image.fromarray(image)
image_dto = context.images.save(image=img)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,180 @@
from typing import List, Tuple
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
@invocation(
"bria_denoise",
title="Denoise - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
num_steps: int = InputField(
default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
)
guidance_scale: float = InputField(
default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
latents: LatentsField = InputField(
description="Latents to denoise",
input=Input.Connection,
title="Latents",
)
latent_image_ids: LatentsField = InputField(
description="Latent Image IDs to denoise",
input=Input.Connection,
title="Latent Image IDs",
)
pos_embeds: LatentsField = InputField(
description="Positive Prompt Embeds",
input=Input.Connection,
title="Positive Prompt Embeds",
)
neg_embeds: LatentsField = InputField(
description="Negative Prompt Embeds",
input=Input.Connection,
title="Negative Prompt Embeds",
)
text_ids: LatentsField = InputField(
description="Text IDs",
input=Input.Connection,
title="Text IDs",
)
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
description="ControlNet",
input=Input.Connection,
title="ControlNet",
default=None,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
latents = context.tensors.load(self.latents.latents_name)
pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
text_ids = context.tensors.load(self.text_ids.latents_name)
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
device = None
dtype = None
with (
context.models.load(self.transformer.transformer) as transformer,
context.models.load(scheduler_identifier) as scheduler,
context.models.load(self.vae.vae) as vae,
context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(transformer, BriaTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
assert isinstance(vae, AutoencoderKL)
dtype = transformer.dtype
device = transformer.device
latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))
control_model, control_images, control_modes, control_scales = None, None, None, None
if self.control is not None:
control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
context=context,
vae=vae,
width=1024,
height=1024,
device=vae.device,
)
pipeline = BriaControlNetPipeline(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=t5_encoder,
tokenizer=t5_tokenizer,
controlnet=control_model,
)
pipeline.to(device=transformer.device, dtype=transformer.dtype)
latents = pipeline(
control_image=control_images,
control_mode=control_modes,
width=1024,
height=1024,
controlnet_conditioning_scale=control_scales,
num_inference_steps=self.num_steps,
max_sequence_length=128,
guidance_scale=self.guidance_scale,
latents=latents,
latent_image_ids=latent_image_ids,
text_ids=text_ids,
prompt_embeds=pos_embeds,
negative_prompt_embeds=neg_embeds,
output_type="latent",
)[0]
assert isinstance(latents, torch.Tensor)
saved_input_latents_tensor = context.tensors.save(latents)
latents_output = LatentsField(latents_name=saved_input_latents_tensor)
return BriaDenoiseInvocationOutput(latents=latents_output)
def _prepare_multi_control(
self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
control = self.control if isinstance(self.control, list) else [self.control]
control_images, control_models, control_modes, control_scales = [], [], [], []
for controlnet in control:
if controlnet is not None:
control_models.append(context.models.load(controlnet.model).model)
control_modes.append(BriaControlModes[controlnet.mode].value)
control_scales.append(controlnet.conditioning_scale)
try:
control_images.append(context.images.get_pil(controlnet.image.image_name))
except Exception:
raise FileNotFoundError(
f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
)
control_model = BriaMultiControlNetModel(control_models).to(device)
tensored_control_images, tensored_control_modes = prepare_control_images(
vae=vae,
control_images=control_images,
control_modes=control_modes,
width=width,
height=height,
device=device,
)
return control_model, tensored_control_images, tensored_control_modes, control_scales

View File

@@ -0,0 +1,76 @@
import torch
from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import (
BaseInvocationOutput,
FieldDescriptions,
LatentsField,
)
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
latents: LatentsField = OutputField(description=FieldDescriptions.cond)
latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_latent_sampler",
title="Latent Sampler - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
seed: int = InputField(
default=42,
title="Seed",
description="The seed to use for the latent sampler",
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
with context.models.load(self.transformer.transformer) as transformer:
device = transformer.device
dtype = transformer.dtype
height, width = 1024, 1024
generator = torch.Generator(device=device).manual_seed(self.seed)
num_channels_latents = 4
latents, latent_image_ids = prepare_latents(
batch_size=1,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=dtype,
device=device,
generator=generator,
)
saved_latents_tensor = context.tensors.save(latents)
saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
latents_output = LatentsField(latents_name=saved_latents_tensor)
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
return BriaLatentSamplerInvocationOutput(
latents=latents_output,
latent_image_ids=latent_image_ids_output,
)

View File

@@ -0,0 +1,58 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
ModelIdentifierField,
SubModelType,
T5EncoderField,
TransformerField,
VAEField,
)
from invokeai.invocation_api import (
BaseInvocation,
BaseInvocationOutput,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
"""Bria base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"bria_model_loader",
title="Main Model - Bria",
tags=["model", "bria"],
version="1.0.0",
classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
"""Loads a bria base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description="Bria model (Transformer) to load",
ui_type=UIType.BriaMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
for key in [self.model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return BriaModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,93 @@
from typing import Optional
import torch
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.app.invocations.model import T5EncoderField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InputField,
LatentsField,
invocation,
invocation_output,
)
@invocation_output("bria_text_encoder_output")
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_text_encoder",
title="Prompt - Bria",
tags=["prompt", "conditioning", "bria"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaTextEncoderInvocation(BaseInvocation):
prompt: str = InputField(
title="Prompt",
description="The prompt to encode",
)
negative_prompt: Optional[str] = InputField(
title="Negative Prompt",
description="The negative prompt to encode",
default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
)
max_length: int = InputField(
default=128,
title="Max Length",
description="The maximum length of the prompt",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
with (
t5_encoder_info as text_encoder,
t5_tokenizer_info as tokenizer,
):
assert isinstance(tokenizer, T5TokenizerFast)
assert isinstance(text_encoder, T5EncoderModel)
(prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
prompt=self.prompt,
tokenizer=tokenizer,
text_encoder=text_encoder,
negative_prompt=self.negative_prompt,
device=text_encoder.device,
num_images_per_prompt=1,
max_sequence_length=self.max_length,
lora_scale=1.0,
)
saved_pos_tensor = context.tensors.save(prompt_embeds)
saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
saved_text_ids_tensor = context.tensors.save(text_ids)
pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
return BriaTextEncoderInvocationOutput(
pos_embeds=pos_embeds_output,
neg_embeds=neg_embeds_output,
text_ids=text_ids_output,
)

View File

@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
# should refactor to avoid this duplication.
@@ -39,11 +38,7 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
@@ -67,8 +62,6 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

View File

@@ -6,6 +6,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
# should refactor to avoid this duplication.
@@ -39,15 +39,22 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),

View File

@@ -1,12 +1,11 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.util.metaenum import MetaEnum
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -43,6 +42,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
BriaControlNetModel = "BriaControlNetModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
@@ -65,16 +66,12 @@ class UIType(str, Enum, metaclass=MetaEnum):
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
Gemini2_5Model = "Gemini2_5ModelField"
FluxKontextModel = "FluxKontextModelField"
Veo3Model = "Veo3ModelField"
RunwayModel = "RunwayModelField"
# endregion
# region Misc Field Types
Scheduler = "SchedulerField"
Any = "AnyField"
Video = "VideoField"
# endregion
# region Internal Field Types
@@ -229,12 +226,6 @@ class ImageField(BaseModel):
image_name: str = Field(description="The name of the image")
class VideoField(BaseModel):
"""A video primitive field"""
video_id: str = Field(description="The id of the video")
class BoardField(BaseModel):
"""A board primitive field"""
@@ -332,9 +323,14 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BoundingBox):
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
@@ -343,6 +339,21 @@ class BoundingBoxField(BoundingBox):
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> Tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class MetadataField(RootModel[dict[str, Any]]):
"""

View File

@@ -63,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.1.0",
version="4.0.0",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
@@ -153,7 +153,7 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
@@ -328,21 +328,6 @@ class FluxDenoiseInvocation(BaseInvocation):
cfg_scale_end_step=self.cfg_scale_end_step,
)
kontext_extension = None
if self.kontext_conditioning:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning
if isinstance(self.kontext_conditioning, list)
else [self.kontext_conditioning],
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -400,6 +385,19 @@ class FluxDenoiseInvocation(BaseInvocation):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None

View File

@@ -3,6 +3,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -17,7 +18,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
@invocation(
@@ -39,11 +39,17 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype

View File

@@ -15,7 +15,6 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
@invocation(
@@ -42,12 +41,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

View File

@@ -1347,96 +1347,3 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
image_dto = context.images.save(image=target_image)
return ImageOutput.build(image_dto)
@invocation(
"flux_kontext_image_prep",
title="FLUX Kontext Image Prep",
tags=["image", "concatenate", "flux", "kontext"],
category="image",
version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
images: list[ImageField] = InputField(
description="The images to concatenate",
min_length=1,
max_length=10,
)
use_preferred_resolution: bool = InputField(
default=True, description="Use FLUX preferred resolutions for the first image"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
# Step 1: Load all images
pil_images = []
for image_field in self.images:
image = context.images.get_pil(image_field.image_name, mode="RGBA")
pil_images.append(image)
# Step 2: Determine target resolution for the first image
first_image = pil_images[0]
width, height = first_image.size
if self.use_preferred_resolution:
aspect_ratio = width / height
# Find the closest preferred resolution for the first image
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
scaled_height = 2 * int(target_height / 16)
final_height = 8 * scaled_height # This will be consistent for all images
scaled_width = 2 * int(target_width / 16)
first_width = 8 * scaled_width
else:
# Use original dimensions of first image, ensuring divisibility by 16
final_height = 16 * (height // 16)
first_width = 16 * (width // 16)
# Ensure minimum dimensions
if final_height < 16:
final_height = 16
if first_width < 16:
first_width = 16
# Step 3: Process and resize all images with consistent height
processed_images = []
total_width = 0
for i, image in enumerate(pil_images):
if i == 0:
# First image uses the calculated dimensions
final_width = first_width
else:
# Subsequent images maintain aspect ratio with the same height
img_aspect_ratio = image.width / image.height
# Calculate width that maintains aspect ratio at the target height
calculated_width = int(final_height * img_aspect_ratio)
# Ensure width is divisible by 16 for proper VAE encoding
final_width = 16 * (calculated_width // 16)
# Ensure minimum width
if final_width < 16:
final_width = 16
# Resize image to calculated dimensions
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
processed_images.append(resized_image)
total_width += final_width
# Step 4: Concatenate images horizontally
concatenated_image = Image.new("RGB", (total_width, final_height))
x_offset = 0
for img in processed_images:
concatenated_image.paste(img, (x_offset, 0))
x_offset += img.width
# Save the concatenated image
image_dto = context.images.save(image=concatenated_image)
return ImageOutput.build(image_dto)

View File

@@ -27,7 +27,6 @@ from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
@invocation(
@@ -53,24 +52,11 @@ class ImageToLatentsInvocation(BaseInvocation):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
@classmethod
@staticmethod
def vae_encode(
cls,
vae_info: LoadedModel,
upcast: bool,
tiled: bool,
image_tensor: torch.Tensor,
tile_size: int = 0,
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
tile_size=tile_size if tiled else None,
fp32=upcast,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
@@ -127,7 +113,6 @@ class ImageToLatentsInvocation(BaseInvocation):
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
@@ -135,11 +120,7 @@ class ImageToLatentsInvocation(BaseInvocation):
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info,
upcast=self.fp32,
tiled=self.tiled or context.config.get().force_tiled_decode,
image_tensor=image_tensor,
tile_size=self.tile_size,
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")

View File

@@ -27,7 +27,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
@invocation(
@@ -54,6 +53,39 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 2200 # Determined experimentally.
if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
working_memory = out_h * out_w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant
if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
@@ -62,13 +94,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
tile_size=self.tile_size if use_tiling else None,
fp32=self.fp32,
)
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),

View File

@@ -27,7 +27,6 @@ from invokeai.app.invocations.fields import (
SD3ConditioningField,
TensorField,
UIComponent,
VideoField,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -288,30 +287,6 @@ class ImageCollectionInvocation(BaseInvocation):
return ImageCollectionOutput(collection=self.collection)
# endregion
# region Video
@invocation_output("video_output")
class VideoOutput(BaseInvocationOutput):
"""Base class for nodes that output a video"""
video: VideoField = OutputField(description="The output video")
width: int = OutputField(description="The width of the video in pixels")
height: int = OutputField(description="The height of the video in pixels")
duration_seconds: float = OutputField(description="The duration of the video in seconds")
@classmethod
def build(cls, video_id: str, width: int, height: int, duration_seconds: float) -> "VideoOutput":
return cls(
video=VideoField(video_id=video_id),
width=width,
height=height,
duration_seconds=duration_seconds,
)
# endregion
# region DenoiseMask

View File

@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
@invocation(
@@ -35,11 +34,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_sd3(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
@@ -63,8 +58,6 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

View File

@@ -6,6 +6,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
@invocation(
@@ -41,15 +41,22 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = estimate_vae_working_memory_sd3(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),

View File

@@ -1,75 +1,72 @@
from itertools import zip_longest
from enum import Enum
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from transformers import AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from transformers.models.sam2 import Sam2Model
from transformers.models.sam2.processing_sam2 import Sam2Processor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_2_pipeline import SegmentAnything2Pipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.image_util.segment_anything.shared import SAMInput, SAMPoint
SegmentAnythingModelKey = Literal[
"segment-anything-base",
"segment-anything-large",
"segment-anything-huge",
"segment-anything-2-tiny",
"segment-anything-2-small",
"segment-anything-2-base",
"segment-anything-2-large",
]
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
"segment-anything-base": "facebook/sam-vit-base",
"segment-anything-large": "facebook/sam-vit-large",
"segment-anything-huge": "facebook/sam-vit-huge",
"segment-anything-2-tiny": "facebook/sam2.1-hiera-tiny",
"segment-anything-2-small": "facebook/sam2.1-hiera-small",
"segment-anything-2-base": "facebook/sam2.1-hiera-base-plus",
"segment-anything-2-large": "facebook/sam2.1-hiera-large",
}
class SAMPointsField(BaseModel):
points: list[SAMPoint] = Field(..., description="The points of the object", min_length=1)
class SAMPointLabel(Enum):
negative = -1
neutral = 0
positive = 1
def to_list(self) -> list[list[float]]:
class SAMPoint(BaseModel):
x: int = Field(..., description="The x-coordinate of the point")
y: int = Field(..., description="The y-coordinate of the point")
label: SAMPointLabel = Field(..., description="The label of the point")
class SAMPointsField(BaseModel):
points: list[SAMPoint] = Field(..., description="The points of the object")
def to_list(self) -> list[list[int]]:
return [[point.x, point.y, point.label.value] for point in self.points]
@invocation(
"segment_anything",
title="Segment Anything",
tags=["prompt", "segmentation", "sam", "sam2"],
tags=["prompt", "segmentation"],
category="segmentation",
version="1.3.0",
version="1.2.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model (SAM or SAM2)."""
"""Runs a Segment Anything Model."""
# Reference:
# - https://arxiv.org/pdf/2304.02643
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use (SAM or SAM2).")
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] | None = InputField(
default=None, description="The bounding boxes to prompt the model with."
default=None, description="The bounding boxes to prompt the SAM model with."
)
point_lists: list[SAMPointsField] | None = InputField(
default=None,
description="The list of point lists to prompt the model with. Each list of points represents a single object.",
description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
)
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
@@ -80,18 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def validate_points_and_boxes_len(self):
if self.point_lists is not None and self.bounding_boxes is not None:
if len(self.point_lists) != len(self.bounding_boxes):
raise ValueError("If both point_lists and bounding_boxes are provided, they must have the same length.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):
@@ -118,38 +111,26 @@ class SegmentAnythingInvocation(BaseInvocation):
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
sam_processor = SamProcessor.from_pretrained(model_path, local_files_only=True)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
@staticmethod
def _load_sam_2_model(model_path: Path):
sam2_model = Sam2Model.from_pretrained(model_path, local_files_only=True)
sam2_processor = Sam2Processor.from_pretrained(model_path, local_files_only=True)
return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
"""Use Segment Anything (SAM or SAM2) to generate masks given an image + a set of bounding boxes."""
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = (
[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
)
sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None
source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
inputs: list[SAMInput] = []
for bbox_field, point_field in zip_longest(self.bounding_boxes or [], self.point_lists or [], fillvalue=None):
inputs.append(
SAMInput(
bounding_box=bbox_field,
points=point_field.points if point_field else None,
)
)
if "sam2" in source:
loader = SegmentAnythingInvocation._load_sam_2_model
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
assert isinstance(pipeline, SegmentAnything2Pipeline)
masks = pipeline.segment(image=image, inputs=inputs)
else:
loader = SegmentAnythingInvocation._load_sam_model
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
assert isinstance(pipeline, SegmentAnythingPipeline)
masks = pipeline.segment(image=image, inputs=inputs)
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:

View File

@@ -49,11 +49,3 @@ class BoardImageRecordStorageBase(ABC):
) -> int:
"""Gets the number of images for a board."""
pass
@abstractmethod
def get_asset_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of assets for a board."""
pass

View File

@@ -3,8 +3,6 @@ from typing import Optional, cast
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
ImageCategory,
ImageRecord,
deserialize_image_record,
@@ -153,38 +151,15 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def get_image_count_for_board(self, board_id: str) -> int:
with self._db.transaction() as cursor:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(IMAGE_CATEGORIES)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
cursor.execute(
f"""--sql
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE AND images.image_category IN ( {placeholders} )
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(*category_strings, board_id),
)
count = cast(int, cursor.fetchone()[0])
return count
def get_asset_count_for_board(self, board_id: str) -> int:
with self._db.transaction() as cursor:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(ASSETS_CATEGORIES)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE AND images.image_category IN ( {placeholders} )
AND board_images.board_id = ?;
""",
(*category_strings, board_id),
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count

View File

@@ -12,20 +12,12 @@ class BoardDTO(BoardRecord):
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
asset_count: int = Field(description="The number of assets in the board.")
"""The number of assets in the board."""
video_count: int = Field(description="The number of videos in the board.")
"""The number of videos in the board."""
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int, asset_count: int, video_count: int
) -> BoardDTO:
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
asset_count=asset_count,
video_count=video_count,
)

View File

@@ -17,7 +17,7 @@ class BoardService(BoardServiceABC):
board_name: str,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0, 0, 0)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self.__invoker.services.board_records.get(board_id)
@@ -27,9 +27,7 @@ class BoardService(BoardServiceABC):
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(board_id)
video_count = 0 # noop for OSS
return board_record_to_dto(board_record, cover_image_name, image_count, asset_count, video_count)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
self,
@@ -44,9 +42,7 @@ class BoardService(BoardServiceABC):
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(board_id)
video_count = 0 # noop for OSS
return board_record_to_dto(board_record, cover_image_name, image_count, asset_count, video_count)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
@@ -71,9 +67,7 @@ class BoardService(BoardServiceABC):
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
video_count = 0 # noop for OSS
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, video_count))
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
@@ -90,8 +84,6 @@ class BoardService(BoardServiceABC):
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
asset_count = self.__invoker.services.board_image_records.get_asset_count_for_board(r.board_id)
video_count = 0 # noop for OSS
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count, asset_count, video_count))
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@@ -150,15 +150,4 @@ class BulkDownloadService(BulkDownloadBase):
def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
# Resolve the path to handle any path traversal attempts (e.g., ../)
resolved_path = path.resolve()
# The path may not traverse out of the bulk downloads folder or its subfolders
does_not_traverse = resolved_path.parent == self._bulk_downloads_folder.resolve()
# The path must exist and be a .zip file
does_exist = resolved_path.exists()
is_zip_file = resolved_path.suffix == ".zip"
return does_exist and is_zip_file and does_not_traverse
return path.exists()

View File

@@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
"""
Set a key-value pair for the client.
Args:
key (str): The key to set.
value (str): The value to set for the key.
Returns:
str: The value that was set.
"""
pass
@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
"""
Get the value for a specific key of the client.
Args:
key (str): The key to retrieve the value for.
Returns:
str | None: The value associated with the key, or None if the key does not exist.
"""
pass
@abstractmethod
def delete(self, queue_id: str) -> None:
"""
Delete all client state.
"""
pass

View File

@@ -1,65 +0,0 @@
import json
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def _get(self) -> dict[str, str] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
return value
def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)
def delete(self, queue_id: str) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -107,7 +107,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `blake3_multi`, `blake3_single`, `random`, `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -197,7 +196,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3_single", description="Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
# fmt: on

View File

@@ -234,8 +234,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: str = Field(description="The timestamp when the queue item was created")
updated_at: str = Field(description="The timestamp when the queue item was last updated")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
@@ -258,8 +258,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at),
updated_at=str(queue_item.updated_at),
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,

View File

@@ -58,15 +58,6 @@ class ImageCategory(str, Enum, metaclass=MetaEnum):
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
IMAGE_CATEGORIES: list[ImageCategory] = [ImageCategory.GENERAL]
ASSETS_CATEGORIES: list[ImageCategory] = [
ImageCategory.CONTROL,
ImageCategory.MASK,
ImageCategory.USER,
ImageCategory.OTHER,
]
class InvalidImageCategoryException(ValueError):
"""Raised when a provided value is not a valid ImageCategory.

View File

@@ -17,7 +17,6 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@@ -74,7 +73,6 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -104,4 +102,3 @@ class InvocationServices:
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -7,7 +7,7 @@ import threading
import time
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -186,15 +186,13 @@ class ModelInstallService(ModelInstallServiceBase):
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
if Path(model_path).is_file():
# Careful! Don't use pathlib.Path(...).with_suffix - it can will strip everything after the first dot.
preferred_name = f"{preferred_name}{model_path.suffix}"
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
)
try:
new_path = self._move_model(model_path, dest_path)
new_path = self._copy_model(model_path, dest_path)
except FileExistsError as excp:
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
@@ -619,17 +617,30 @@ class ModelInstallService(ModelInstallServiceBase):
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
if new_path.exists():
raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists")
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
move(old_path, new_path)
return new_path
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):

View File

@@ -87,21 +87,9 @@ class ModelLoadService(ModelLoadServiceBase):
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Model at {checkpoint} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
if scan_result.scan_err:
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Error scanning model at {checkpoint} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result

View File

@@ -15,7 +15,6 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlAdapterDefaultSettings,
LoraModelDefaultSettings,
MainModelDefaultSettings,
)
from invokeai.backend.model_manager.taxonomy import (
@@ -84,8 +83,8 @@ class ModelRecordChanges(BaseModelExcludeNull):
file_size: Optional[int] = Field(description="Size of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = (
Field(description="Default settings for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
# Checkpoint-specific changes

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import Any, Coroutine, Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
@@ -14,7 +15,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -22,7 +22,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.pagination import CursorPaginatedResults
class SessionQueueBase(ABC):
@@ -135,6 +135,19 @@ class SessionQueueBase(ABC):
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
@@ -144,18 +157,9 @@ class SessionQueueBase(ABC):
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item_ids(
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID for a given queue"""
"""Gets a session queue item by ID"""
pass
@abstractmethod

View File

@@ -176,14 +176,6 @@ DEFAULT_QUEUE_ID = "default"
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
class ItemIdsResult(BaseModel):
"""Response containing ordered item ids with metadata for optimistic updates."""
item_ids: list[int] = Field(description="Ordered list of item ids")
total_count: int = Field(description="Total number of queue items matching the query")
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])

View File

@@ -22,7 +22,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -34,7 +33,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -588,6 +587,59 @@ class SqliteSessionQueue(SessionQueueBase):
)
return self.get_queue_item(item_id)
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def list_all_queue_items(
self,
queue_id: str,
@@ -619,26 +671,6 @@ class SqliteSessionQueue(SessionQueueBase):
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_item_ids(
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
) -> ItemIdsResult:
with self._db.transaction() as cursor_:
query = f"""--sql
SELECT item_id
FROM session_queue
WHERE queue_id = ?
ORDER BY created_at {order_dir.value}
"""
query_params = [queue_id]
cursor_.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor_.fetchall())
item_ids = [row[0] for row in result]
return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
with self._db.transaction() as cursor:
cursor.execute(

View File

@@ -23,7 +23,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -64,7 +63,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations()
return db

View File

@@ -1,40 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -1,179 +0,0 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
VIDEO_DTO_COLS = ", ".join(
[
"videos." + c
for c in [
"video_id",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
]
]
)
class VideoRecord(BaseModelExcludeNull):
"""Deserialized video record without metadata."""
video_id: str = Field(description="The unique id of the video.")
"""The unique id of the video."""
width: int = Field(description="The width of the video in px.")
"""The actual width of the video in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the video in px.")
"""The actual height of the video in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the video.")
"""The created timestamp of the video."""
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the video.")
"""The updated timestamp of the video."""
deleted_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The deleted timestamp of the video."
)
"""The deleted timestamp of the video."""
is_intermediate: bool = Field(description="Whether this is an intermediate video.")
"""Whether this is an intermediate video."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this video, if it is a generated video.",
)
"""The session ID that generated this video, if it is a generated video."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this video, if it is a generated video.",
)
"""The node ID that generated this video, if it is a generated video."""
starred: bool = Field(description="Whether this video is starred.")
"""Whether this video is starred."""
class VideoRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a video record.
Only limited changes are valid:
- `session_id`: change the session associated with a video
- `is_intermediate`: change the video's `is_intermediate` flag
- `starred`: change whether the video is starred
"""
session_id: Optional[StrictStr] = Field(
default=None,
description="The video's new session ID.",
)
"""The video's new session ID."""
is_intermediate: Optional[StrictBool] = Field(default=None, description="The video's new `is_intermediate` flag.")
"""The video's new `is_intermediate` flag."""
starred: Optional[StrictBool] = Field(default=None, description="The video's new `starred` state")
"""The video's new `starred` state."""
def deserialize_video_record(video_dict: dict) -> VideoRecord:
"""Deserializes a video record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
video_id = video_dict.get("video_id", "unknown")
width = video_dict.get("width", 0)
height = video_dict.get("height", 0)
session_id = video_dict.get("session_id", None)
node_id = video_dict.get("node_id", None)
created_at = video_dict.get("created_at", get_iso_timestamp())
updated_at = video_dict.get("updated_at", get_iso_timestamp())
deleted_at = video_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = video_dict.get("is_intermediate", False)
starred = video_dict.get("starred", False)
return VideoRecord(
video_id=video_id,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
starred=starred,
)
class VideoCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred videos in the collection.")
unstarred_count: int = Field(description="The number of unstarred videos in the collection.")
class VideoIdsResult(BaseModel):
"""Response containing ordered video ids with metadata for optimistic updates."""
video_ids: list[str] = Field(description="Ordered list of video ids")
starred_count: int = Field(description="Number of starred videos (when starred_first=True)")
total_count: int = Field(description="Total number of videos matching the query")
class VideoUrlsDTO(BaseModelExcludeNull):
"""The URLs for an image and its thumbnail."""
video_id: str = Field(description="The unique id of the video.")
"""The unique id of the video."""
video_url: str = Field(description="The URL of the video.")
"""The URL of the video."""
thumbnail_url: str = Field(description="The URL of the video's thumbnail.")
"""The URL of the video's thumbnail."""
class VideoDTO(VideoRecord, VideoUrlsDTO):
"""Deserialized video record, enriched for the frontend."""
board_id: Optional[str] = Field(
default=None, description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
def video_record_to_dto(
video_record: VideoRecord,
video_url: str,
thumbnail_url: str,
board_id: Optional[str],
) -> VideoDTO:
"""Converts a video record to a video DTO."""
return VideoDTO(
**video_record.model_dump(),
video_url=video_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ResultWithAffectedBoards(BaseModel):
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
class DeleteVideosResult(ResultWithAffectedBoards):
deleted_videos: list[str] = Field(description="The ids of the videos that were deleted")
class StarredVideosResult(ResultWithAffectedBoards):
starred_videos: list[str] = Field(description="The ids of the videos that were starred")
class UnstarredVideosResult(ResultWithAffectedBoards):
unstarred_videos: list[str] = Field(description="The ids of the videos that were unstarred")
class AddVideosToBoardResult(ResultWithAffectedBoards):
added_videos: list[str] = Field(description="The video ids that were added to the board")
class RemoveVideosFromBoardResult(ResultWithAffectedBoards):
removed_videos: list[str] = Field(description="The video ids that were removed from their board")

View File

View File

@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_t5_prompt_embeds(
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str], None] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds = prompt_embeds.to(device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# in order the get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
)
class CudaTimerContext:
def __init__(self, times_arr):
self.times_arr = times_arr
def __enter__(self):
self.before_event = torch.cuda.Event(enable_timing=True)
self.after_event = torch.cuda.Event(enable_timing=True)
self.before_event.record()
def __exit__(self, type, value, traceback):
self.after_event.record()
torch.cuda.synchronize()
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
self.times_arr.append(elapsed_time)
def get_env_prefix():
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
if env == "AWS":
return "SM_CHANNEL"
elif env == "AZURE":
return "AZUREML_DATAREFERENCE"
raise Exception(f"Env {env} not supported")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def initialize_distributed():
# Initialize the process group for distributed training
dist.init_process_group("nccl")
# Get the current process's rank (ID) and the total number of processes (world size)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized distributed training: Rank {rank}/{world_size}")
def get_clip_prompt_embeds(
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
assert max_sequence_length == tokenizer.model_max_length
prompt = [prompt] if isinstance(prompt, str) else prompt
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, pooled_prompt_embeds
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concateanted with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,6 @@
__version__ = "0.0.9"
from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector
__all__ = ["CannyDetector", "OpenposeDetector"]

View File

@@ -0,0 +1,48 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
class CannyDetector:
def __call__(
self,
input_image=None,
low_threshold=100,
high_threshold=200,
detect_resolution=512,
image_resolution=512,
output_type=None,
**kwargs,
):
if "img" in kwargs:
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
input_image = kwargs.pop("img")
if input_image is None:
raise ValueError("input_image must be defined.")
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,108 @@
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
SOFTWARE LICENSE AGREEMENT
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
COPYRIGHT: The Software is owned by Licensor and is protected by United
States copyright laws and applicable international treaties and/or conventions.
PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
EXPORT REGULATION: Licensee agrees to comply with any and all applicable
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
************************************************************************
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
COPYRIGHT
All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.
All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.
Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.
LICENSE
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CONTRIBUTION AGREEMENT
By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.
************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********

View File

@@ -0,0 +1,267 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
# This preprocessor is licensed by CMU for non-commercial use only.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import warnings
from typing import List, NamedTuple, Tuple, Union
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
HandResult = List[Keypoint]
FaceResult = List[Keypoint]
class PoseResult(NamedTuple):
body: BodyResult
left_hand: Union[HandResult, None]
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Draw the detected poses on an empty canvas.
Args:
poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
H (int): The height of the canvas.
W (int): The width of the canvas.
draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
Returns:
numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
"""
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
for pose in poses:
if draw_body:
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
if draw_hand:
canvas = util.draw_handpose(canvas, pose.left_hand)
canvas = util.draw_handpose(canvas, pose.right_hand)
if draw_face:
canvas = util.draw_facepose(canvas, pose.face)
return canvas
class OpenposeDetector:
"""
A class for detecting human poses in images using the Openpose model.
Attributes:
model_dir (str): Path to the directory where the pose models are stored.
"""
def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
self.body_estimation = body_estimation
self.hand_estimation = hand_estimation
self.face_estimation = face_estimation
@classmethod
def from_pretrained(
cls,
pretrained_model_or_path,
filename=None,
hand_filename=None,
face_filename=None,
cache_dir=None,
local_files_only=False,
):
if pretrained_model_or_path == "lllyasviel/ControlNet":
filename = filename or "annotator/ckpts/body_pose_model.pth"
hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = "lllyasviel/Annotators"
else:
filename = filename or "body_pose_model.pth"
hand_filename = hand_filename or "hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = pretrained_model_or_path
if os.path.isdir(pretrained_model_or_path):
body_model_path = os.path.join(pretrained_model_or_path, filename)
hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
else:
body_model_path = hf_hub_download(
pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only
)
hand_model_path = hf_hub_download(
pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only
)
face_model_path = hf_hub_download(
face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only
)
body_estimation = Body(body_model_path)
hand_estimation = Hand(hand_model_path)
face_estimation = Face(face_model_path)
return cls(body_estimation, hand_estimation, face_estimation)
def to(self, device):
self.body_estimation.to(device)
self.hand_estimation.to(device)
self.face_estimation.to(device)
return self
def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
left_hand = None
right_hand = None
H, W, _ = oriImg.shape
for x, y, w, is_left in util.handDetect(body, oriImg):
peaks = self.hand_estimation(oriImg[y : y + w, x : x + w, :]).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
hand_result = [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
if is_left:
left_hand = hand_result
else:
right_hand = hand_result
return left_hand, right_hand
def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
face = util.faceDetect(body, oriImg)
if face is None:
return None
x, y, w = face
H, W, _ = oriImg.shape
heatmaps = self.face_estimation(oriImg[y : y + w, x : x + w, :])
peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
return [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
return None
def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
"""
Detect poses in the given image.
Args:
oriImg (numpy.ndarray): The input image for pose detection.
include_hand (bool, optional): Whether to include hand detection. Defaults to False.
include_face (bool, optional): Whether to include face detection. Defaults to False.
Returns:
List[PoseResult]: A list of PoseResult objects containing the detected poses.
"""
oriImg = oriImg[:, :, ::-1].copy()
H, W, C = oriImg.shape
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
bodies = self.body_estimation.format_body_result(candidate, subset)
results = []
for body in bodies:
left_hand, right_hand, face = (None,) * 3
if include_hand:
left_hand, right_hand = self.detect_hands(body, oriImg)
if include_face:
face = self.detect_face(body, oriImg)
results.append(
PoseResult(
BodyResult(
keypoints=[
Keypoint(x=keypoint.x / float(W), y=keypoint.y / float(H))
if keypoint is not None
else None
for keypoint in body.keypoints
],
total_score=body.total_score,
total_parts=body.total_parts,
),
left_hand,
right_hand,
face,
)
)
return results
def __call__(
self,
input_image,
detect_resolution=512,
image_resolution=512,
include_body=True,
include_hand=False,
include_face=False,
hand_and_face=None,
output_type="pil",
**kwargs,
):
if hand_and_face is not None:
warnings.warn(
"hand_and_face is deprecated. Use include_hand and include_face instead.",
DeprecationWarning,
stacklevel=2,
)
include_hand = hand_and_face
include_face = hand_and_face
if "return_pil" in kwargs:
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning, stacklevel=2)
output_type = "pil" if kwargs["return_pil"] else "np"
if type(output_type) is bool:
warnings.warn(
"Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions",
stacklevel=2,
)
if output_type:
output_type = "pil"
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
H, W, C = input_image.shape
poses = self.detect_poses(input_image, include_hand, include_face)
canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
detected_map = canvas
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,319 @@
import math
from typing import List, NamedTuple, Union
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model
class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1
class BodyResult(NamedTuple):
# Note: Using `Union` instead of `|` operator as the ladder is a Python
# 3.10 feature.
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
# Python 3.8 environment.
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
keypoints: List[Union[Keypoint, None]]
total_score: float
total_parts: int
class Body(object):
def __init__(self, model_path):
self.model = bodypose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImg):
device = next(iter(self.model.parameters())).device
# scale_search = [0.5, 1.0, 1.5, 2.0]
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
paf = util.smart_resize_k(paf, fx=stride, fy=stride)
paf = paf[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += +paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0], strict=False)) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correpondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(candA[i][0], candB[j][0], num=mid_num),
np.linspace(candA[i][1], candB[j][1], num=mid_num),
strict=False,
)
)
vec_x = np.array(
[
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0]
for i in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1]
for i in range(len(startend))
]
)
score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
0.5 * oriImg.shape[0] / norm - 1, 0
)
criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]
)
connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array([item for sublist in all_peaks for item in sublist])
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
# candidate: x, y, score, id
return candidate, subset
@staticmethod
def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
"""
Format the body results from the candidate and subset arrays into a list of BodyResult objects.
Args:
candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
for each body part.
subset (np.ndarray): An array of subsets containing indices to the candidate array for each
person detected. The last two columns of each row hold the total score and total parts
of the person.
Returns:
List[BodyResult]: A list of BodyResult objects, where each object represents a person with
detected keypoints, total score, and total parts.
"""
return [
BodyResult(
keypoints=[
Keypoint(
x=candidate[candidate_index][0],
y=candidate[candidate_index][1],
score=candidate[candidate_index][2],
id=candidate[candidate_index][3],
)
if candidate_index != -1
else None
for candidate_index in person[:18].astype(int)
],
total_score=person[18],
total_parts=person[19],
)
for person in subset
]

View File

@@ -0,0 +1,307 @@
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
from torchvision.transforms import ToPILImage, ToTensor
from invokeai.backend.bria.controlnet_aux.open_pose import util
class FaceNet(Module):
"""Model the cascading heatmaps."""
def __init__(self):
super(FaceNet, self).__init__()
# cnn to make feature map
self.relu = ReLU()
self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_4 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_3_CPM = Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)
# stage1
self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0)
self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage2
self.Mconv1_stage2 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage2 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage3
self.Mconv1_stage3 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage3 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage4
self.Mconv1_stage4 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage4 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage5
self.Mconv1_stage5 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage5 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage6
self.Mconv1_stage6 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage6 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
for m in self.modules():
if isinstance(m, Conv2d):
init.constant_(m.bias, 0)
def forward(self, x):
"""Return a list of heatmaps."""
heatmaps = []
h = self.relu(self.conv1_1(x))
h = self.relu(self.conv1_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv2_1(h))
h = self.relu(self.conv2_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv3_1(h))
h = self.relu(self.conv3_2(h))
h = self.relu(self.conv3_3(h))
h = self.relu(self.conv3_4(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv4_1(h))
h = self.relu(self.conv4_2(h))
h = self.relu(self.conv4_3(h))
h = self.relu(self.conv4_4(h))
h = self.relu(self.conv5_1(h))
h = self.relu(self.conv5_2(h))
h = self.relu(self.conv5_3_CPM(h))
feature_map = h
# stage1
h = self.relu(self.conv6_1_CPM(h))
h = self.conv6_2_CPM(h)
heatmaps.append(h)
# stage2
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage2(h))
h = self.relu(self.Mconv2_stage2(h))
h = self.relu(self.Mconv3_stage2(h))
h = self.relu(self.Mconv4_stage2(h))
h = self.relu(self.Mconv5_stage2(h))
h = self.relu(self.Mconv6_stage2(h))
h = self.Mconv7_stage2(h)
heatmaps.append(h)
# stage3
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage3(h))
h = self.relu(self.Mconv2_stage3(h))
h = self.relu(self.Mconv3_stage3(h))
h = self.relu(self.Mconv4_stage3(h))
h = self.relu(self.Mconv5_stage3(h))
h = self.relu(self.Mconv6_stage3(h))
h = self.Mconv7_stage3(h)
heatmaps.append(h)
# stage4
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage4(h))
h = self.relu(self.Mconv2_stage4(h))
h = self.relu(self.Mconv3_stage4(h))
h = self.relu(self.Mconv4_stage4(h))
h = self.relu(self.Mconv5_stage4(h))
h = self.relu(self.Mconv6_stage4(h))
h = self.Mconv7_stage4(h)
heatmaps.append(h)
# stage5
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage5(h))
h = self.relu(self.Mconv2_stage5(h))
h = self.relu(self.Mconv3_stage5(h))
h = self.relu(self.Mconv4_stage5(h))
h = self.relu(self.Mconv5_stage5(h))
h = self.relu(self.Mconv6_stage5(h))
h = self.Mconv7_stage5(h)
heatmaps.append(h)
# stage6
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage6(h))
h = self.relu(self.Mconv2_stage6(h))
h = self.relu(self.Mconv3_stage6(h))
h = self.relu(self.Mconv4_stage6(h))
h = self.relu(self.Mconv5_stage6(h))
h = self.relu(self.Mconv6_stage6(h))
h = self.Mconv7_stage6(h)
heatmaps.append(h)
return heatmaps
LOG = logging.getLogger(__name__)
TOTEN = ToTensor()
TOPIL = ToPILImage()
params = {
"gaussian_sigma": 2.5,
"inference_img_size": 736, # 368, 736, 1312
"heatmap_peak_thresh": 0.1,
"crop_scale": 1.5,
"line_indices": [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6],
[6, 7],
[7, 8],
[8, 9],
[9, 10],
[10, 11],
[11, 12],
[12, 13],
[13, 14],
[14, 15],
[15, 16],
[17, 18],
[18, 19],
[19, 20],
[20, 21],
[22, 23],
[23, 24],
[24, 25],
[25, 26],
[27, 28],
[28, 29],
[29, 30],
[31, 32],
[32, 33],
[33, 34],
[34, 35],
[36, 37],
[37, 38],
[38, 39],
[39, 40],
[40, 41],
[41, 36],
[42, 43],
[43, 44],
[44, 45],
[45, 46],
[46, 47],
[47, 42],
[48, 49],
[49, 50],
[50, 51],
[51, 52],
[52, 53],
[53, 54],
[54, 55],
[55, 56],
[56, 57],
[57, 58],
[58, 59],
[59, 48],
[60, 61],
[61, 62],
[62, 63],
[63, 64],
[64, 65],
[65, 66],
[66, 67],
[67, 60],
],
}
class Face(object):
"""
The OpenPose face landmark detector model.
Args:
inference_size: set the size of the inference image size, suggested:
368, 736, 1312, default 736
gaussian_sigma: blur the heatmaps, default 2.5
heatmap_peak_thresh: return landmark if over threshold, default 0.1
"""
def __init__(self, face_model_path, inference_size=None, gaussian_sigma=None, heatmap_peak_thresh=None):
self.inference_size = inference_size or params["inference_img_size"]
self.sigma = gaussian_sigma or params["gaussian_sigma"]
self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
self.model = FaceNet()
self.model.load_state_dict(torch.load(face_model_path))
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, face_img):
device = next(iter(self.model.parameters())).device
H, W, C = face_img.shape
w_size = 384
x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
x_data = x_data.to(device)
with torch.no_grad():
hs = self.model(x_data[None, ...])
heatmaps = F.interpolate(hs[-1], (H, W), mode="bilinear", align_corners=True).cpu().numpy()[0]
return heatmaps
def compute_peaks_from_heatmaps(self, heatmaps):
all_peaks = []
for part in range(heatmaps.shape[0]):
map_ori = heatmaps[part].copy()
binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
if np.sum(binary) == 0:
continue
positions = np.where(binary > 0.5)
intensities = map_ori[positions]
mi = np.argmax(intensities)
y, x = positions[0][mi], positions[1][mi]
all_peaks.append([x, y])
return np.array(all_peaks)

View File

@@ -0,0 +1,91 @@
import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from skimage.measure import label
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model
class Hand(object):
def __init__(self, model_path):
self.model = handpose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImgRaw):
device = next(iter(self.model.parameters())).device
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize for x in scale_search]
wsize = 128
heatmap_avg = np.zeros((wsize, wsize, 22))
Hr, Wr, Cr = oriImgRaw.shape
oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize(oriImg, (scale, scale))
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
with torch.no_grad():
output = self.model(data).cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (wsize, wsize))
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = util.npmax(map_ori)
y = int(float(y) * float(Hr) / float(wsize))
x = int(float(x) * float(Wr) / float(wsize))
all_peaks.append([x, y])
return np.array(all_peaks)
if __name__ == "__main__":
hand_estimation = Hand("../model/hand_pose_model.pth")
# test_image = '../images/hand.jpg'
test_image = "../images/hand.jpg"
oriImg = cv2.imread(test_image) # B,G,R order
peaks = hand_estimation(oriImg)
canvas = util.draw_handpose(oriImg, peaks, True)
cv2.imshow("", canvas)
cv2.waitKey(0)

View File

@@ -0,0 +1,240 @@
from collections import OrderedDict
import torch
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if "pool" in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
class bodypose_model(nn.Module):
def __init__(self):
super(bodypose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class handpose_model(nn.Module):
def __init__(self):
super(handpose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict([("conv6_1_CPM", [128, 512, 1, 1, 0]), ("conv6_2_CPM", [512, 22, 1, 1, 0])])
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6

View File

@@ -0,0 +1,436 @@
import math
from typing import List, Tuple, Union
import cv2
import numpy as np
from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint
eps = 0.01
def smart_resize(x, s):
Ht, Wt = s
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
def smart_resize_k(x, fx, fy):
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
Ht, Wt = Ho * fy, Wo * fx
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights[".".join(weights_name.split(".")[1:])]
return transfered_model_weights
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
"""
Draw keypoints and limbs representing body pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
H, W, C = canvas.shape
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False):
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None:
continue
Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
X = np.array([keypoint1.y, keypoint2.y]) * float(H)
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
for keypoint, color in zip(keypoints, colors, strict=False):
if keypoint is None:
continue
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
return canvas
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
import matplotlib
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
if not keypoints:
return canvas
H, W, C = canvas.shape
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for ie, (e1, e2) in enumerate(edges):
k1 = keypoints[e1]
k2 = keypoints[e2]
if k1 is None or k2 is None:
continue
x1 = int(k1.x * W)
y1 = int(k1.y * H)
x2 = int(k2.x * W)
y2 = int(k2.y * H)
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
thickness=2,
)
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
return canvas
def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
"""
Draw keypoints representing face pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
if not keypoints:
return canvas
H, W, C = canvas.shape
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
"""
Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
corner of the bounding box, the width (height) of the bounding box, and
a boolean flag indicating whether the hand is a left hand (True) or a
right hand (False).
Notes:
- The width and height of the bounding boxes are equal since the network requires squared input.
- The minimum bounding box size is 20 pixels.
"""
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
left_shoulder = keypoints[5]
left_elbow = keypoints[6]
left_wrist = keypoints[7]
right_shoulder = keypoints[2]
right_elbow = keypoints[3]
right_wrist = keypoints[4]
# if any of three not detected
has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
if not (has_left or has_right):
return []
hands = []
# left hand
if has_left:
hands.append([left_shoulder.x, left_shoulder.y, left_elbow.x, left_elbow.y, left_wrist.x, left_wrist.y, True])
# right hand
if has_right:
hands.append(
[right_shoulder.x, right_shoulder.y, right_elbow.x, right_elbow.y, right_wrist.x, right_wrist.y, False]
)
for x1, y1, x2, y2, x3, y3, is_left in hands:
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
# handRectangle.x -= handRectangle.width / 2.f;
# handRectangle.y -= handRectangle.height / 2.f;
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# the max hand box value is 20 pixels
if width >= 20:
detect_result.append((int(x), int(y), int(width), is_left))
"""
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left.
"""
return detect_result
# Written by Lvmin
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
"""
Detect the face in the input body pose keypoints and calculate the bounding box for the face.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
bounding box and the width (height) of the bounding box, or None if the
face is not detected or the bounding box width is less than 20 pixels.
Notes:
- The width and height of the bounding box are equal.
- The minimum bounding box size is 20 pixels.
"""
# left right eye ear 14 15 16 17
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
head = keypoints[0]
left_eye = keypoints[14]
right_eye = keypoints[15]
left_ear = keypoints[16]
right_ear = keypoints[17]
if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
return None
width = 0.0
x0, y0 = head.x, head.y
if left_eye is not None:
x1, y1 = left_eye.x, left_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if right_eye is not None:
x1, y1 = right_eye.x, right_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if left_ear is not None:
x1, y1 = left_ear.x, left_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
if right_ear is not None:
x1, y1 = right_ear.x, right_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
x, y = x0, y0
x -= width
y -= width
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width * 2
width2 = width * 2
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
if width >= 20:
return int(x), int(y), int(width)
else:
return None
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return i, j

View File

@@ -0,0 +1,260 @@
import os
import random
import cv2
import numpy as np
import torch
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def make_noise_disk(H, W, C, F):
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = noise[F : F + H, F : F + W]
noise -= np.min(noise)
noise /= np.max(noise)
if C == 1:
noise = noise[:, :, None]
return noise
def nms(x, t, s):
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
z = np.zeros_like(y, dtype=np.uint8)
z[y > t] = 255
return z
def min_max_norm(x):
x -= np.min(x)
x /= np.maximum(np.max(x), 1e-5)
return x
def safe_step(x, step=2):
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)
return y
def img2mask(img, H, W, low=10, high=90):
assert img.ndim == 3 or img.ndim == 2
assert img.dtype == np.uint8
if img.ndim == 3:
y = img[:, :, random.randrange(0, img.shape[2])]
else:
y = img
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
if random.uniform(0, 1) < 0.5:
y = 255 - y
return y < np.percentile(y, random.randrange(low, high))
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def ade_palette():
"""ADE20K palette that maps each class to RGB values."""
return [
[120, 120, 120],
[180, 120, 120],
[6, 230, 230],
[80, 50, 50],
[4, 200, 3],
[120, 120, 80],
[140, 140, 140],
[204, 5, 255],
[230, 230, 230],
[4, 250, 7],
[224, 5, 255],
[235, 255, 7],
[150, 5, 61],
[120, 120, 70],
[8, 255, 51],
[255, 6, 82],
[143, 255, 140],
[204, 255, 4],
[255, 51, 7],
[204, 70, 3],
[0, 102, 200],
[61, 230, 250],
[255, 6, 51],
[11, 102, 255],
[255, 7, 71],
[255, 9, 224],
[9, 7, 230],
[220, 220, 220],
[255, 9, 92],
[112, 9, 255],
[8, 255, 214],
[7, 255, 224],
[255, 184, 6],
[10, 255, 71],
[255, 41, 10],
[7, 255, 255],
[224, 255, 8],
[102, 8, 255],
[255, 61, 6],
[255, 194, 7],
[255, 122, 8],
[0, 255, 20],
[255, 8, 41],
[255, 5, 153],
[6, 51, 255],
[235, 12, 255],
[160, 150, 20],
[0, 163, 255],
[140, 140, 140],
[250, 10, 15],
[20, 255, 0],
[31, 255, 0],
[255, 31, 0],
[255, 224, 0],
[153, 255, 0],
[0, 0, 255],
[255, 71, 0],
[0, 235, 255],
[0, 173, 255],
[31, 0, 255],
[11, 200, 200],
[255, 82, 0],
[0, 255, 245],
[0, 61, 255],
[0, 255, 112],
[0, 255, 133],
[255, 0, 0],
[255, 163, 0],
[255, 102, 0],
[194, 255, 0],
[0, 143, 255],
[51, 255, 0],
[0, 82, 255],
[0, 255, 41],
[0, 255, 173],
[10, 0, 255],
[173, 255, 0],
[0, 255, 153],
[255, 92, 0],
[255, 0, 255],
[255, 0, 245],
[255, 0, 102],
[255, 173, 0],
[255, 0, 20],
[255, 184, 184],
[0, 31, 255],
[0, 255, 61],
[0, 71, 255],
[255, 0, 204],
[0, 255, 194],
[0, 255, 82],
[0, 10, 255],
[0, 112, 255],
[51, 0, 255],
[0, 194, 255],
[0, 122, 255],
[0, 255, 163],
[255, 153, 0],
[0, 255, 10],
[255, 112, 0],
[143, 255, 0],
[82, 0, 255],
[163, 255, 0],
[255, 235, 0],
[8, 184, 170],
[133, 0, 255],
[0, 255, 92],
[184, 0, 255],
[255, 0, 31],
[0, 184, 255],
[0, 214, 255],
[255, 0, 112],
[92, 255, 0],
[0, 224, 255],
[112, 224, 255],
[70, 184, 160],
[163, 0, 255],
[153, 0, 255],
[71, 255, 0],
[255, 0, 163],
[255, 204, 0],
[255, 0, 143],
[0, 255, 235],
[133, 255, 0],
[255, 0, 235],
[245, 0, 255],
[255, 0, 122],
[255, 245, 0],
[10, 190, 212],
[214, 255, 0],
[0, 204, 255],
[20, 0, 255],
[255, 255, 0],
[0, 153, 255],
[0, 41, 255],
[0, 255, 204],
[41, 0, 255],
[41, 255, 0],
[173, 0, 255],
[0, 245, 255],
[71, 0, 255],
[122, 0, 255],
[0, 255, 184],
[0, 92, 255],
[184, 255, 0],
[0, 133, 255],
[255, 214, 0],
[25, 194, 194],
[102, 255, 0],
[92, 0, 255],
]

View File

@@ -0,0 +1,559 @@
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.controlnet import zero_module
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.outputs import BaseOutput
from invokeai.backend.bria.transformer_bria import (
EmbedND,
FluxSingleTransformerBlock,
FluxTransformerBlock,
TimestepProjEmbeddings,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"]
class BriaControlModes(Enum):
depth = 0
canny = 1
colorgrid = 2
recolor = 3
tile = 4
pose = 5
@dataclass
class BriaControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
num_mode: int = None,
rope_theta: int = 10000,
time_theta: int = 10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
# self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
# text_time_guidance_cls = (
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
# )
# self.time_text_embed = text_time_guidance_cls(
# embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
# )
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None and num_mode > 0
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if guidance is not None:
print("guidance is not supported in BriaControlNetModel")
if pooled_projections is not None:
print("pooled_projections is not supported in BriaControlNetModel")
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
# Convert controlnet_cond to the same dtype as the model weights
controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# Validate controlnet_mode values are within the valid range
if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
raise ValueError(
f"`controlnet_mode` values must be in range [0, {self.num_mode - 1}], but got values outside this range"
)
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch
controlnet_mode_emb = controlnet_mode_emb.expand(
encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]
)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for _, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for _, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=False
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return BriaControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class BriaMultiControlNetModel(ModelMixin):
r"""
`BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
compatible with `BriaControlNetModel`.
Args:
controlnets (`List[BriaControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`BriaControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
controlnet_mode: List[torch.tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[BriaControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(
control_block_samples, block_samples, strict=False
)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(
control_block_samples, block_samples, strict=False
)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
return control_block_samples, control_single_block_samples

View File

@@ -0,0 +1,68 @@
from typing import List, Tuple
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
@torch.no_grad()
def prepare_control_images(
vae: AutoencoderKL,
control_images: list[Image.Image],
control_modes: list[int],
width: int,
height: int,
device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
tensored_control_images = []
tensored_control_modes = []
for idx, control_image_ in enumerate(control_images):
tensored_control_image = _prepare_image(
image=control_image_,
width=width,
height=height,
device=device,
dtype=vae.dtype,
)
height, width = tensored_control_image.shape[-2:]
# vae encode
tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
tensored_control_image = (tensored_control_image) * vae.config.scaling_factor
# pack
height_control_image, width_control_image = tensored_control_image.shape[2:]
tensored_control_image = _pack_latents(
tensored_control_image,
height_control_image,
width_control_image,
)
tensored_control_images.append(tensored_control_image)
tensored_control_modes.append(
torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long)
)
return tensored_control_images, tensored_control_modes
def _prepare_image(
image: Image.Image,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
image = image.convert("RGB")
image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
image = image.repeat_interleave(1, dim=0)
image = image.to(device=device, dtype=dtype)
return image
def _pack_latents(latents, height, width):
latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(1, (height // 2) * (width // 2), 16)
return latents

View File

@@ -0,0 +1,636 @@
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusion3Pipeline
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt).images[0]
>>> image.save("sd3.png")
```
"""
T5_PRECISION = torch.float16
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaPipeline(FluxPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
def __init__(
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
):
self.register_modules(
vae=vae,
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# TODO - why different than offical flux (-1)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
if self.vae.config.shift_factor is None:
self.vae.config.shift_factor = 0
self.vae.to(dtype=torch.float32)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
if do_classifier_free_guidance and negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = (
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
clip_value: Union[None, float] = None,
normalize: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = (
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
)
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
(prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if (
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
and self.scheduler.config["use_dynamic_shifting"]
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1] # Shift by height - Why just height?
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
else:
# 4. Prepare timesteps
# Sample from training sigmas
if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, None
)
else:
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps,
num_inference_steps=num_inference_steps,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Supprot different diffusers versions
if diffusers.__version__ >= "0.32.0":
latent_image_ids = latent_image_ids[0]
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
cfg_noise_pred_text = noise_pred_text.std()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if normalize:
noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
if clip_value:
assert clip_value > 0
noise_pred = noise_pred.clip(-clip_value, clip_value)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def to(self, *args, **kwargs):
DiffusionPipeline.to(self, *args, **kwargs)
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
for block in self.text_encoder.encoder.block:
block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32:
self.vae.to(dtype=torch.float32)
return self
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,671 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL # Waiting for diffusers udpdate
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND, logging
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.controlnet_bria import BriaControlNetModel
from invokeai.backend.bria.pipeline_bria import BriaPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BriaControlNetPipeline(BriaPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__( # EYAL - removed clip text encoder + tokenizer
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
controlnet: BriaControlNetModel,
):
super().__init__(
transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
)
self.register_modules(controlnet=controlnet)
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
num_channels_latents = self.transformer.config.in_channels // 4
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
return control_image, control_mode
def prepare_multi_control(
self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode
):
num_channels_latents = self.transformer.config.in_channels // 4
control_images = []
for _, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes
return control_image, control_mode
def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
return controlnet_keep
def get_control_start_end(self, control_guidance_start, control_guidance_end):
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = 1 # TODO - why is this 1?
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
return control_guidance_start, control_guidance_end
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: Optional[PipelineImageInput] = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
latent_image_ids: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
text_ids: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
control_guidance_start, control_guidance_end = self.get_control_start_end(
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
)
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = (
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
)
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
device = self._execution_device
# 4. Prepare timesteps
if (
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
and self.scheduler.config["use_dynamic_shifting"]
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# Determine image sequence length
if control_image is not None:
if isinstance(control_image, list):
image_seq_len = control_image[0].shape[1]
else:
image_seq_len = control_image.shape[1]
else:
# Use latents sequence length when no control image is provided
image_seq_len = latents.shape[1]
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
else:
# 5. Prepare timesteps
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Create tensor stating which controlnets to keep
if control_image is not None:
controlnet_keep = self.get_controlnet_keep(
timesteps=timesteps,
control_guidance_start=control_guidance_start,
control_guidance_end=control_guidance_end,
)
if diffusers.__version__ >= "0.32.0":
latent_image_ids = latent_image_ids[0]
text_ids = text_ids[0]
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# EYAL - added the CFG loop
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# Handling ControlNet
if control_image is not None:
if isinstance(controlnet_keep[i], list):
if isinstance(controlnet_conditioning_scale, list):
cond_scale = controlnet_conditioning_scale
else:
cond_scale = [
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)
]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
timestep=timestep,
# guidance=guidance,
# pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)
else:
controlnet_block_samples, controlnet_single_block_samples = None, None
# This is predicts "v" from flow-matching
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
def encode_prompt(
prompt: Union[str, List[str]],
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or torch.device("cuda")
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
# dynamically adjust the LoRA scale
if text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
dtype = text_encoder.dtype if text_encoder is not None else torch.float32
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
if negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if text_encoder is not None:
if USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(text_encoder, lora_scale)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
def prepare_latents(
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator,
latents: Optional[torch.FloatTensor] = None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
vae_scale_factor = 16
height = 2 * (int(height) // vae_scale_factor)
width = 2 * (int(width) // vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents

View File

@@ -0,0 +1,322 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Timesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
rope_theta=10000,
time_theta=10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
# if pooled_projection_dim:
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
if guidance_embeds:
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
# temb = (
# self.time_text_embed(timestep, pooled_projections)
# if guidance is None
# else self.time_text_embed(timestep, guidance, pooled_projections)
# )
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
# if pooled_projections:
# temb+=self.pooled_text_embed(pooled_projections)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 2:
ids = torch.cat((txt_ids, img_ids), dim=0)
else:
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -112,7 +112,7 @@ def denoise(
)
# Slice prediction to only include the main image tokens
if img_cond_seq is not None:
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
@@ -125,26 +125,9 @@ def denoise(
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
# For negative prediction with Kontext, we need to include the reference images
# to maintain consistency between positive and negative passes. Without this,
# CFG would create artifacts as the attention mechanism would see different
# spatial structures in each pass
neg_img_input = img
neg_img_input_ids = img_ids
# Add channel-wise conditioning for negative pass if present
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (Kontext) for negative pass
# This ensures reference images are processed consistently
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
img=img,
img_ids=img_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -157,10 +140,6 @@ def denoise(
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
# Slice negative prediction to match main image tokens
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred

View File

@@ -1,14 +1,15 @@
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
def generate_img_ids_with_offset(
@@ -18,10 +19,8 @@ def generate_img_ids_with_offset(
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
h_offset: int = 0,
w_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with optional index and spatial offsets.
"""Generate tensor of image position ids with an optional offset.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
@@ -29,9 +28,7 @@ def generate_img_ids_with_offset(
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
idx_offset (int): Offset to add to the first dimension of the image ids.
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
@@ -45,10 +42,6 @@ def generate_img_ids_with_offset(
packed_height = latent_height // 2
packed_width = latent_width // 2
# Convert spatial offsets from latent space to packed space
packed_h_offset = h_offset // 2
packed_w_offset = w_offset // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
@@ -56,13 +49,13 @@ def generate_img_ids_with_offset(
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions) with spatial offset
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions) with spatial offset
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
@@ -80,14 +73,14 @@ class KontextExtension:
def __init__(
self,
kontext_conditioning: list[FluxKontextConditioningField],
kontext_conditioning: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference images
Initializes the KontextExtension, pre-processing the reference image
into latents and positional IDs.
"""
self._context = context
@@ -100,116 +93,54 @@ class KontextExtension:
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
all_latents = []
all_ids = []
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Track cumulative dimensions for spatial tiling
# These track the running extent of the virtual canvas in latent space
canvas_h = 0 # Running canvas height
canvas_w = 0 # Running canvas width
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
for idx, kontext_field in enumerate(self.kontext_conditioning):
image = self._context.images.get_pil(kontext_field.image.image_name)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Convert to RGB
image = image.convert("RGB")
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
# Convert to tensor using torchvision transforms for consistency
transformation = T.Compose(
[
T.ToTensor(), # Converts PIL image to tensor and scales to [0, 1]
]
)
image_tensor = transformation(image)
# Convert from [0, 1] to [-1, 1] range expected by VAE
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
# Estimate working memory for encode operation (50% of decode memory requirements)
img_h = image_tensor.shape[-2]
img_w = image_tensor.shape[-1]
element_size = next(vae_info.model.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
# Use sample=False to get the distribution mean without noise
kontext_latents_unpacked = vae.encode(image_tensor, sample=False)
TorchDevice.empty_cache()
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pad latents to be compatible with patch_size=2
# This ensures dimensions are even for the pack() function
pad_h = (2 - latent_height % 2) % 2
pad_w = (2 - latent_width % 2) % 2
if pad_h > 0 or pad_w > 0:
kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
# Update dimensions after padding
_, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
# Determine spatial offsets for this reference image
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Calculate potential canvas dimensions for each tiling option
# Option 1: Tile vertically (below existing content)
potential_h_vertical = canvas_h + latent_height
# Option 2: Tile horizontally (to the right of existing content)
potential_w_horizontal = canvas_w + latent_width
# Choose arrangement that minimizes the maximum dimension
# This keeps the canvas closer to square, optimizing attention computation
if potential_h_vertical > potential_w_horizontal:
# Tile horizontally (to the right of existing images)
w_offset = canvas_w
canvas_w = canvas_w + latent_width
canvas_h = max(canvas_h, latent_height)
else:
# Tile vertically (below existing images)
h_offset = canvas_h
canvas_h = canvas_h + latent_height
canvas_w = max(canvas_w, latent_width)
else:
# First image - just set canvas dimensions
canvas_h = latent_height
canvas_w = latent_width
# Generate IDs with both index offset and spatial offsets
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1, # All reference images use index=1 (matching ComfyUI implementation)
h_offset=h_offset,
w_offset=w_offset,
)
all_latents.append(kontext_latents_packed)
all_ids.append(kontext_ids)
# Concatenate all latents and IDs along the sequence dimension
concatenated_latents = torch.cat(all_latents, dim=1) # Concatenate along sequence dimension
concatenated_ids = torch.cat(all_ids, dim=1) # Concatenate along sequence dimension
return concatenated_latents, concatenated_ids
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""

View File

@@ -1,304 +0,0 @@
# This file is vendored from https://github.com/ShieldMnt/invisible-watermark
#
# `invisible-watermark` is MIT licensed as of August 23, 2025, when the code was copied into this repo.
#
# Why we vendored it in:
# `invisible-watermark` has a dependency on `opencv-python`, which conflicts with Invoke's dependency on
# `opencv-contrib-python`. It's easier to copy the code over than complicate the installation process by
# requiring an extra post-install step of removing `opencv-python` and installing `opencv-contrib-python`.
import struct
import uuid
import base64
import cv2
import numpy as np
import pywt
class WatermarkEncoder(object):
def __init__(self, content=b""):
seq = np.array([n for n in content], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
self._wmType = "bytes"
def set_by_ipv4(self, addr):
bits = []
ips = addr.split(".")
for ip in ips:
bits += list(np.unpackbits(np.array([ip % 255], dtype=np.uint8)))
self._watermarks = bits
self._wmLen = len(self._watermarks)
self._wmType = "ipv4"
assert self._wmLen == 32
def set_by_uuid(self, uid):
u = uuid.UUID(uid)
self._wmType = "uuid"
seq = np.array([n for n in u.bytes], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
def set_by_bytes(self, content):
self._wmType = "bytes"
seq = np.array([n for n in content], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
def set_by_b16(self, b16):
content = base64.b16decode(b16)
self.set_by_bytes(content)
self._wmType = "b16"
def set_by_bits(self, bits=[]):
self._watermarks = [int(bit) % 2 for bit in bits]
self._wmLen = len(self._watermarks)
self._wmType = "bits"
def set_watermark(self, wmType="bytes", content=""):
if wmType == "ipv4":
self.set_by_ipv4(content)
elif wmType == "uuid":
self.set_by_uuid(content)
elif wmType == "bits":
self.set_by_bits(content)
elif wmType == "bytes":
self.set_by_bytes(content)
elif wmType == "b16":
self.set_by_b16(content)
else:
raise NameError("%s is not supported" % wmType)
def get_length(self):
return self._wmLen
# @classmethod
# def loadModel(cls):
# RivaWatermark.loadModel()
def encode(self, cv2Image, method="dwtDct", **configs):
(r, c, channels) = cv2Image.shape
if r * c < 256 * 256:
raise RuntimeError("image too small, should be larger than 256x256")
if method == "dwtDct":
embed = EmbedMaxDct(self._watermarks, wmLen=self._wmLen, **configs)
return embed.encode(cv2Image)
# elif method == 'dwtDctSvd':
# embed = EmbedDwtDctSvd(self._watermarks, wmLen=self._wmLen, **configs)
# return embed.encode(cv2Image)
# elif method == 'rivaGan':
# embed = RivaWatermark(self._watermarks, self._wmLen)
# return embed.encode(cv2Image)
else:
raise NameError("%s is not supported" % method)
class WatermarkDecoder(object):
def __init__(self, wm_type="bytes", length=0):
self._wmType = wm_type
if wm_type == "ipv4":
self._wmLen = 32
elif wm_type == "uuid":
self._wmLen = 128
elif wm_type == "bytes":
self._wmLen = length
elif wm_type == "bits":
self._wmLen = length
elif wm_type == "b16":
self._wmLen = length
else:
raise NameError("%s is unsupported" % wm_type)
def reconstruct_ipv4(self, bits):
ips = [str(ip) for ip in list(np.packbits(bits))]
return ".".join(ips)
def reconstruct_uuid(self, bits):
nums = np.packbits(bits)
bstr = b""
for i in range(16):
bstr += struct.pack(">B", nums[i])
return str(uuid.UUID(bytes=bstr))
def reconstruct_bits(self, bits):
# return ''.join([str(b) for b in bits])
return bits
def reconstruct_b16(self, bits):
bstr = self.reconstruct_bytes(bits)
return base64.b16encode(bstr)
def reconstruct_bytes(self, bits):
nums = np.packbits(bits)
bstr = b""
for i in range(self._wmLen // 8):
bstr += struct.pack(">B", nums[i])
return bstr
def reconstruct(self, bits):
if len(bits) != self._wmLen:
raise RuntimeError("bits are not matched with watermark length")
if self._wmType == "ipv4":
return self.reconstruct_ipv4(bits)
elif self._wmType == "uuid":
return self.reconstruct_uuid(bits)
elif self._wmType == "bits":
return self.reconstruct_bits(bits)
elif self._wmType == "b16":
return self.reconstruct_b16(bits)
else:
return self.reconstruct_bytes(bits)
def decode(self, cv2Image, method="dwtDct", **configs):
(r, c, channels) = cv2Image.shape
if r * c < 256 * 256:
raise RuntimeError("image too small, should be larger than 256x256")
bits = []
if method == "dwtDct":
embed = EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
bits = embed.decode(cv2Image)
# elif method == 'dwtDctSvd':
# embed = EmbedDwtDctSvd(watermarks=[], wmLen=self._wmLen, **configs)
# bits = embed.decode(cv2Image)
# elif method == 'rivaGan':
# embed = RivaWatermark(watermarks=[], wmLen=self._wmLen, **configs)
# bits = embed.decode(cv2Image)
else:
raise NameError("%s is not supported" % method)
return self.reconstruct(bits)
# @classmethod
# def loadModel(cls):
# RivaWatermark.loadModel()
class EmbedMaxDct(object):
def __init__(self, watermarks=[], wmLen=8, scales=[0, 36, 36], block=4):
self._watermarks = watermarks
self._wmLen = wmLen
self._scales = scales
self._block = block
def encode(self, bgr):
(row, col, channels) = bgr.shape
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
for channel in range(2):
if self._scales[channel] <= 0:
continue
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
self.encode_frame(ca1, self._scales[channel])
yuv[: row // 4 * 4, : col // 4 * 4, channel] = pywt.idwt2((ca1, (v1, h1, d1)), "haar")
bgr_encoded = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
return bgr_encoded
def decode(self, bgr):
(row, col, channels) = bgr.shape
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
scores = [[] for i in range(self._wmLen)]
for channel in range(2):
if self._scales[channel] <= 0:
continue
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
scores = self.decode_frame(ca1, self._scales[channel], scores)
avgScores = list(map(lambda l: np.array(l).mean(), scores))
bits = np.array(avgScores) * 255 > 127
return bits
def decode_frame(self, frame, scale, scores):
(row, col) = frame.shape
num = 0
for i in range(row // self._block):
for j in range(col // self._block):
block = frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
]
score = self.infer_dct_matrix(block, scale)
# score = self.infer_dct_svd(block, scale)
wmBit = num % self._wmLen
scores[wmBit].append(score)
num = num + 1
return scores
def diffuse_dct_svd(self, block, wmBit, scale):
u, s, v = np.linalg.svd(cv2.dct(block))
s[0] = (s[0] // scale + 0.25 + 0.5 * wmBit) * scale
return cv2.idct(np.dot(u, np.dot(np.diag(s), v)))
def infer_dct_svd(self, block, scale):
u, s, v = np.linalg.svd(cv2.dct(block))
score = 0
score = int((s[0] % scale) > scale * 0.5)
return score
if score >= 0.5:
return 1.0
else:
return 0.0
def diffuse_dct_matrix(self, block, wmBit, scale):
pos = np.argmax(abs(block.flatten()[1:])) + 1
i, j = pos // self._block, pos % self._block
val = block[i][j]
if val >= 0.0:
block[i][j] = (val // scale + 0.25 + 0.5 * wmBit) * scale
else:
val = abs(val)
block[i][j] = -1.0 * (val // scale + 0.25 + 0.5 * wmBit) * scale
return block
def infer_dct_matrix(self, block, scale):
pos = np.argmax(abs(block.flatten()[1:])) + 1
i, j = pos // self._block, pos % self._block
val = block[i][j]
if val < 0:
val = abs(val)
if (val % scale) > 0.5 * scale:
return 1
else:
return 0
def encode_frame(self, frame, scale):
"""
frame is a matrix (M, N)
we get K (watermark bits size) blocks (self._block x self._block)
For i-th block, we encode watermark[i] bit into it
"""
(row, col) = frame.shape
num = 0
for i in range(row // self._block):
for j in range(col // self._block):
block = frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
]
wmBit = self._watermarks[(num % self._wmLen)]
diffusedBlock = self.diffuse_dct_matrix(block, wmBit, scale)
# diffusedBlock = self.diffuse_dct_svd(block, wmBit, scale)
frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
] = diffusedBlock
num = num + 1

View File

@@ -6,10 +6,13 @@ configuration variable, that allows the watermarking to be supressed.
import cv2
import numpy as np
from imwatermark import WatermarkEncoder
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.image_util.imwatermark.vendor import WatermarkEncoder
from invokeai.app.services.config.config_default import get_config
config = get_config()
class InvisibleWatermark:

View File

@@ -1,109 +0,0 @@
from typing import Optional
import torch
from PIL import Image
# Import SAM2 components - these should be available in transformers 4.56.0+
from transformers.models.sam2 import Sam2Model
from transformers.models.sam2.processing_sam2 import Sam2Processor
from invokeai.backend.image_util.segment_anything.shared import SAMInput
from invokeai.backend.raw_model import RawModel
class SegmentAnything2Pipeline(RawModel):
"""A wrapper class for the transformers SAM2 model and processor that makes it compatible with the model manager."""
def __init__(self, sam2_model: Sam2Model, sam2_processor: Sam2Processor):
"""Initialize the SAM2 pipeline.
Args:
sam2_model: The SAM2 model
sam2_processor: The SAM2 processor (can be Sam2Processor or Sam2VideoProcessor)
"""
self._sam2_model = sam2_model
self._sam2_processor = sam2_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK: The SAM2 pipeline may not work on MPS devices. We only allow it to be moved to CPU or CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._sam2_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
# HACK: Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam2_model)
def segment(
self,
image: Image.Image,
inputs: list[SAMInput],
) -> torch.Tensor:
"""Segment the image using the provided inputs.
Args:
image: The image to segment.
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
input_boxes: list[list[float]] = []
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
for i in inputs:
box: list[float] | None = None
points: list[list[float]] | None = None
labels: list[int] | None = None
if i.bounding_box is not None:
box: list[float] | None = [
i.bounding_box.x_min,
i.bounding_box.y_min,
i.bounding_box.x_max,
i.bounding_box.y_max,
]
if i.points is not None:
points = []
labels = []
for point in i.points:
points.append([point.x, point.y])
labels.append(point.label.value)
if box is not None:
input_boxes.append(box)
if points is not None:
input_points.append(points)
if labels is not None:
input_labels.append(labels)
batched_input_boxes = [input_boxes] if input_boxes else None
batched_input_points = [input_points] if input_points else None
batched_input_labels = [input_labels] if input_labels else None
processed_inputs = self._sam2_processor(
images=image,
input_boxes=batched_input_boxes,
input_points=batched_input_points,
input_labels=batched_input_labels,
return_tensors="pt",
).to(self._sam2_model.device)
# Generate masks using the SAM2 model
outputs = self._sam2_model(**processed_inputs)
# Post-process the masks to get the final segmentation
masks = self._sam2_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=processed_inputs.original_sizes,
reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
)
# There should be only one batch.
assert len(masks) == 1
return masks[0]

View File

@@ -1,13 +1,20 @@
from typing import Optional
from typing import Optional, TypeAlias
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.image_util.segment_anything.shared import SAMInput
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""
ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -31,65 +38,55 @@ class SegmentAnythingPipeline(RawModel):
def segment(
self,
image: Image.Image,
inputs: list[SAMInput],
bounding_boxes: list[list[int]] | None = None,
point_lists: list[list[list[int]]] | None = None,
) -> torch.Tensor:
"""Segment the image using the provided inputs.
"""Run the SAM model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
Args:
image: The image to segment.
inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
input_boxes: list[list[float]] = []
input_points: list[list[list[float]]] = []
input_labels: list[list[int]] = []
# Prep the inputs:
# - Create a list of bounding boxes or points and labels.
# - Add a batch dimension of 1 to the inputs.
if bounding_boxes:
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
input_points: list[ListOfPoints] | None = None
input_labels: list[ListOfPointLabels] | None = None
elif point_lists:
input_boxes: list[ListOfBoundingBoxes] | None = None
input_points: list[ListOfPoints] | None = []
input_labels: list[ListOfPointLabels] | None = []
for point_list in point_lists:
input_points.append([[p[0], p[1]] for p in point_list])
input_labels.append([p[2] for p in point_list])
for i in inputs:
box: list[float] | None = None
points: list[list[float]] | None = None
labels: list[int] | None = None
else:
raise ValueError("Either bounding_boxes or points and labels must be provided.")
if i.bounding_box is not None:
box: list[float] | None = [
i.bounding_box.x_min,
i.bounding_box.y_min,
i.bounding_box.x_max,
i.bounding_box.y_max,
]
if i.points is not None:
points = []
labels = []
for point in i.points:
points.append([point.x, point.y])
labels.append(point.label.value)
if box is not None:
input_boxes.append(box)
if points is not None:
input_points.append(points)
if labels is not None:
input_labels.append(labels)
batched_input_boxes = [input_boxes] if input_boxes else None
batched_input_points = input_points if input_points else None
batched_input_labels = input_labels if input_labels else None
processed_inputs = self._sam_processor(
inputs = self._sam_processor(
images=image,
input_boxes=batched_input_boxes,
input_points=batched_input_points,
input_labels=batched_input_labels,
input_boxes=input_boxes,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt",
).to(self._sam_model.device)
outputs = self._sam_model(**processed_inputs)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=processed_inputs.original_sizes,
reshaped_input_sizes=processed_inputs.reshaped_input_sizes,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)
# There should be only one batch.

View File

@@ -1,49 +0,0 @@
from enum import Enum
from pydantic import BaseModel, model_validator
from pydantic.fields import Field
class BoundingBox(BaseModel):
x_min: int = Field(..., description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(..., description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(..., description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(..., description="The maximum y-coordinate of the bounding box (exclusive).")
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class SAMPointLabel(Enum):
negative = -1
neutral = 0
positive = 1
class SAMPoint(BaseModel):
x: int = Field(..., description="The x-coordinate of the point")
y: int = Field(..., description="The y-coordinate of the point")
label: SAMPointLabel = Field(..., description="The label of the point")
class SAMInput(BaseModel):
bounding_box: BoundingBox | None = Field(None, description="The bounding box to use for segmentation")
points: list[SAMPoint] | None = Field(None, description="The points to use for segmentation")
@model_validator(mode="after")
def check_input(self):
if not self.bounding_box and not self.points:
raise ValueError("Either bounding_box or points must be provided")
return self

View File

@@ -90,11 +90,6 @@ class MainModelDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
model_config = ConfigDict(extra="forbid")
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
@@ -292,9 +287,6 @@ class LoRAConfigBase(ABC, BaseModel):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[LoraModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
@@ -500,15 +492,6 @@ class MainConfigBase(ABC, BaseModel):
variant: AnyVariant = ModelVariantType.Normal
class VideoConfigBase(ABC, BaseModel):
type: Literal[ModelType.Video] = ModelType.Video
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
@@ -666,21 +649,6 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
raise NotImplementedError("API models are not parsed from disk.")
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
"""Model config for API-based video models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
# API models are not stored on disk, so we can't match them.
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -750,13 +718,12 @@ AnyModelConfig = Annotated[
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory:

View File

@@ -9,7 +9,6 @@ import spandrel
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
@@ -23,7 +22,6 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
LoraModelDefaultSettings,
MainModelDefaultSettings,
ModelConfigFactory,
SubmodelDefinition,
@@ -127,6 +125,8 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"BriaPipeline": ModelType.Main,
"BriaTransformer2DModel": ModelType.ControlNet,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -218,8 +218,6 @@ class ModelProbe(object):
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
if fields["type"] in {ModelType.LoRA}:
fields["default_settings"] = get_default_settings_lora()
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])
@@ -497,21 +495,9 @@ class ModelProbe(object):
# scan model
scan_result = pscan.scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {model_name} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.")
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {model_name} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.")
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
# Probing utilities
@@ -546,10 +532,6 @@ def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAd
return None
def get_default_settings_lora() -> LoraModelDefaultSettings:
return LoraModelDefaultSettings()
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
@@ -881,6 +863,8 @@ class PipelineFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion3
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
return BaseModelType.CogView4
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
return BaseModelType.Bria
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
@@ -1030,6 +1014,9 @@ class ControlNetFolderProbe(FolderProbeBase):
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
if config.get("_class_name", None) == "BriaTransformer2DModel":
return BaseModelType.Bria
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
if dimension == 768:

View File

@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria control net models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path)
repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
use_safetensors=False,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result

View File

@@ -80,7 +80,13 @@ class GenericDiffusersLoader(ModelLoader):
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
"transformer_bria",
]:
if module == "transformer_bria":
module = "invokeai.backend.bria.transformer_bria"
elif class_name == "BriaTransformer2DModel":
class_name = "BriaControlNetModel"
module = "invokeai.backend.bria.controlnet_bria"
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@@ -12,6 +12,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -62,6 +65,8 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
else:
# If neither is available, return 0
return 0
elif isinstance(model, (Body, Hand, Face)):
return calc_module_size(model.model)
elif isinstance(
model,
(

View File

@@ -6,17 +6,13 @@ import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
logger = InvokeAILogger.get_logger()
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
@@ -83,24 +79,8 @@ class ModelOnDisk:
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path.stem} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(
f"The model {path.stem} is potentially infected by malware. Aborting import."
)
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path.stem} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path.stem} for malware. Aborting import.")
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):

View File

@@ -149,29 +149,13 @@ flux_kontext = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (quantized)",
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea = StarterModel(
name="FLUX.1 Krea dev",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea_quantized = StarterModel(
name="FLUX.1 Krea dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -596,14 +580,13 @@ t2i_sketch_sdxl = StarterModel(
)
# endregion
# region SpandrelImageToImage
animesharp_v4_rcan = StarterModel(
name="2x-AnimeSharpV4_RCAN",
realesrgan_anime = StarterModel(
name="RealESRGAN_x4plus_anime_6B",
base=BaseModelType.Any,
source="https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN.safetensors",
description="A 2x upscaling model (optimized for anime images).",
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
type=ModelType.SpandrelImageToImage,
)
realesrgan_x4 = StarterModel(
name="RealESRGAN_x4plus",
base=BaseModelType.Any,
@@ -749,7 +732,7 @@ STARTER_MODELS: list[StarterModel] = [
t2i_lineart_sdxl,
t2i_sketch_sdxl,
realesrgan_x4,
animesharp_v4_rcan,
realesrgan_anime,
realesrgan_x2,
swinir,
t5_base_encoder,
@@ -760,8 +743,6 @@ STARTER_MODELS: list[StarterModel] = [
llava_onevision,
flux_fill,
cogview4,
flux_krea,
flux_krea_quantized,
]
sd1_bundle: list[StarterModel] = [
@@ -813,7 +794,6 @@ flux_bundle: list[StarterModel] = [
flux_redux,
flux_fill,
flux_kontext_quantized,
flux_krea_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {

View File

@@ -28,11 +28,9 @@ class BaseModelType(str, Enum):
CogView4 = "cogview4"
Imagen3 = "imagen3"
Imagen4 = "imagen4"
Gemini2_5 = "gemini-2.5"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"
Veo3 = "veo3"
Runway = "runway"
Bria = "bria"
class ModelType(str, Enum):
@@ -54,7 +52,6 @@ class ModelType(str, Enum):
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
Video = "video"
class SubModelType(str, Enum):

View File

@@ -8,12 +8,8 @@ import picklescan.scanner as pscan
import safetensors
import torch
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.taxonomy import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
@@ -63,21 +59,9 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
if scan:
scan_result = pscan.scan_file_path(path)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {path} is potentially infected by malware. Aborting import.")
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path} for malware. Aborting import.")
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint

View File

@@ -18,25 +18,16 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Check if keys use transformer prefix
transformer_prefix_keys = [
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
transformer_keys_present = all(k in state_dict for k in transformer_prefix_keys)
all_expected_keys_present = all(k in state_dict for k in expected_keys)
# Check if keys use base_model.model prefix
base_model_prefix_keys = [
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
base_model_keys_present = all(k in state_dict for k in base_model_prefix_keys)
return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(
@@ -58,16 +49,8 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Determine which prefix is used and remove it from all keys.
# Check if any key starts with "base_model.model." prefix
has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict.keys())
if has_base_model_prefix:
# Remove the "base_model.model." prefix from all keys.
grouped_state_dict = {k.replace("base_model.model.", ""): v for k, v in grouped_state_dict.items()}
else:
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19

View File

@@ -20,7 +20,7 @@ def main():
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Initialize FLUX transformer on meta device"):
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]

View File

@@ -33,7 +33,7 @@ def main():
)
# inference_dtype = torch.bfloat16
with log_time("Initialize FLUX transformer on meta device"):
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]

View File

@@ -27,7 +27,7 @@ def main():
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Initialize T5 on meta device"):
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)

View File

@@ -1,117 +0,0 @@
from typing import Literal
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def estimate_vae_working_memory_sd15_sdxl(
operation: Literal["encode", "decode"],
image_tensor: torch.Tensor,
vae: AutoencoderKL | AutoencoderTiny,
tile_size: int | None,
fp32: bool,
) -> int:
"""Estimate the working memory required to encode or decode the given tensor."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if fp32 else 2
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
if tile_size is not None:
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
h = tile_size
w = tile_size
working_memory = h * w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
working_memory = h * w * element_size * scaling_constant
if fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_cogview4(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = h * w * element_size * scaling_constant
print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_flux(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = out_h * out_w * element_size * scaling_constant
print(f"estimate_vae_working_memory_flux: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_sd3(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = h * w * element_size * scaling_constant
print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")
return int(working_memory)

View File

@@ -44,5 +44,4 @@ yalc.lock
# vitest
tsconfig.vitest-temp.json
coverage/
*.tgz
coverage/

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore();
const store = createStore(undefined, false);
$store.set(store);
$baseUrl.set('http://localhost:9090');

Some files were not shown because too many files have changed in this diff Show More