Compare commits

..

27 Commits

Author SHA1 Message Date
psychedelicious
54bda8e8e4 chore: bump version to v6.2.0a1 2025-07-25 17:16:51 +10:00
psychedelicious
c14889f055 tidy(ui): enable devmode redux checks 2025-07-25 17:15:19 +10:00
psychedelicious
86680f296a chore(ui): lint 2025-07-25 17:15:19 +10:00
psychedelicious
de0b7801a6 fix(ui): infinite loop when setting tile controlnet model 2025-07-25 17:15:19 +10:00
psychedelicious
a03c7ca4e3 fix(ui): do not store whole model configs in state 2025-07-25 17:15:19 +10:00
psychedelicious
32af53779d refactor(ui): just manually validate async stuff 2025-07-25 17:15:19 +10:00
psychedelicious
a8662953fc refactor(ui): work around zod async validation issue 2025-07-25 17:15:19 +10:00
psychedelicious
82cdfd83e4 fix(ui): check initial retrieval and set as last persisted 2025-07-25 17:15:19 +10:00
psychedelicious
3f3fdf0b43 chore(ui): bump zod to latest
Checking if it fixes an issue w/ async validators
2025-07-25 17:15:18 +10:00
psychedelicious
53dbd5a7c9 refactor(ui): use zod for all redux state 2025-07-25 17:15:18 +10:00
psychedelicious
bbe5979349 refactor(ui): use zod for all redux state (wip)
needed for confidence w/ state rehydration logic
2025-07-25 17:15:18 +10:00
psychedelicious
ca70540ddd feat(ui): iterate on storage api 2025-07-25 17:15:18 +10:00
psychedelicious
37e25ccbf7 refactor(ui): restructure persistence driver creation to support custom drivers 2025-07-25 17:15:18 +10:00
psychedelicious
28e7a83f98 revert(ui): temp changes to main.tsx for testing 2025-07-25 17:15:18 +10:00
psychedelicious
3b39912b1c revert(ui): temp disable eslint rule 2025-07-25 17:15:18 +10:00
psychedelicious
c76698f205 git: update gitignore 2025-07-25 17:15:18 +10:00
psychedelicious
8f27a393d8 wip 2025-07-25 17:15:18 +10:00
psychedelicious
84ff6dbe69 chore: ruff 2025-07-25 17:15:18 +10:00
psychedelicious
4620a2137c tests(app): service mocks 2025-07-25 17:15:18 +10:00
psychedelicious
8ddbd979dd chore(ui): lint 2025-07-25 17:15:17 +10:00
psychedelicious
19ec9d268e refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
ab683802ba refactor(ui): iterate on persistence 2025-07-25 17:15:17 +10:00
psychedelicious
98957ec9ea refactor(ui): alternate approach to slice configs 2025-07-25 17:15:17 +10:00
psychedelicious
7936ee9b7f chore(ui): typegen 2025-07-25 17:15:17 +10:00
psychedelicious
a96b7afdfb feat(api): make client state key query not body 2025-07-25 17:15:17 +10:00
psychedelicious
bb58a70b70 refactor(ui): cleaner slice definitions 2025-07-25 17:15:17 +10:00
psychedelicious
aaa1e1a480 feat: server-side client state persistence 2025-07-25 17:15:17 +10:00
263 changed files with 3962 additions and 7985 deletions

View File

@@ -45,9 +45,6 @@ jobs:
steps:
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
# the /mnt dir has 70GBs of free space
# /dev/sda1 74G 28K 70G 1% /mnt
# According to some online posts the /mnt is not always there, so checking before setting docker to use it
run: |
echo "----- Free space before cleanup"
df -h
@@ -55,11 +52,6 @@ jobs:
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
if [ -d /mnt ]; then
sudo chmod -R 777 /mnt
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
sudo systemctl restart docker
fi
echo "----- Free space after cleanup"
df -h

View File

@@ -1,30 +0,0 @@
# Checks that large files and LFS-tracked files are properly checked in with pointer format.
# Uses https://github.com/ppremk/lfs-warning to detect LFS issues.
name: 'lfs checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
jobs:
lfs-check:
runs-on: ubuntu-latest
timeout-minutes: 5
permissions:
# Required to label and comment on the PRs
pull-requests: write
steps:
- name: checkout
uses: actions/checkout@v4
- name: check lfs files
uses: ppremk/lfs-warning@v3.3

View File

@@ -39,18 +39,6 @@ jobs:
- name: checkout
uses: actions/checkout@v4
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
echo "----- Free space before cleanup"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
echo "----- Free space after cleanup"
df -h
- name: check for changed files
if: ${{ inputs.always_run != true }}
id: changed-files

View File

@@ -22,10 +22,6 @@
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=cuda #| rocm
## If you are using ROCM, you will need to ensure that the render group within the container and the host system use the same group ID.
## To obtain the group ID of the render group on the host system, run `getent group render` and grab the number.
# RENDER_GROUP_ID=
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
# CONTAINER_UID=1000

View File

@@ -43,6 +43,7 @@ ENV \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
UV_INDEX="https://download.pytorch.org/whl/cu124" \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
@@ -73,18 +74,20 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids" && groupadd render
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv sync --frozen
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
@@ -102,6 +105,8 @@ COPY invokeai ${INVOKEAI_SRC}/invokeai
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv pip install -e .
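The change above drops the per-driver uv extras (`--extra $GPU_DRIVER`) in favor of a PyTorch wheel index chosen through `UV_INDEX`: the cu124 index is baked in as the default, and the RUN steps override it for arm64/CPU and ROCm builds. A minimal build sketch under assumptions not shown in this diff (the Dockerfile living at `docker/Dockerfile`, a `GPU_DRIVER` build arg, and illustrative image tags):

```sh
# CUDA build: uses the default UV_INDEX (cu124) from the ENV layer
docker build -f docker/Dockerfile -t invokeai:cuda .

# ROCm build: the RUN step swaps UV_INDEX to the rocm6.2 wheel index
docker build -f docker/Dockerfile --build-arg GPU_DRIVER=rocm -t invokeai:rocm .

# CPU-only build (also what linux/arm64 falls back to, since there are no arm64 CUDA wheels)
docker build -f docker/Dockerfile --build-arg GPU_DRIVER=cpu -t invokeai:cpu .
```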

View File

@@ -1,136 +0,0 @@
# syntax=docker/dockerfile:1.4
#### Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
## Backend ---------------------------------------
FROM library/ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt update && apt install -y --no-install-recommends \
ca-certificates \
git \
gosu \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev \
wget
ENV \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
VIRTUAL_ENV=/opt/venv \
INVOKEAI_SRC=/opt/invokeai \
PYTHON_VERSION=3.12 \
UV_PYTHON=3.12 \
UV_COMPILE_BYTECODE=1 \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
PATH="/opt/venv/bin:$PATH" \
CONTAINER_UID=${CONTAINER_UID:-1000} \
CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install ${PYTHON_VERSION} && \
# chmod --recursive a+rX /root/.local/share/uv/python
chmod 711 /root
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
if [ "$GPU_DRIVER" = "rocm" ]; then \
wget -O /tmp/amdgpu-install.deb \
https://repo.radeon.com/amdgpu-install/6.3.4/ubuntu/noble/amdgpu-install_6.3.60304-1_all.deb && \
apt install -y /tmp/amdgpu-install.deb && \
apt update && \
amdgpu-install --usecase=rocm -y && \
apt-get autoclean && \
apt clean && \
rm -rf /tmp/* /var/tmp/* && \
usermod -a -G render ubuntu && \
usermod -a -G video ubuntu && \
echo "\\n/opt/rocm/lib\\n/opt/rocm/lib64" >> /etc/ld.so.conf.d/rocm.conf && \
ldconfig && \
update-alternatives --auto rocm; \
fi
## Heathen711: Leaving this for review input, will remove before merge
# RUN --mount=type=cache,target=/var/cache/apt \
# --mount=type=cache,target=/var/lib/apt \
# if [ "$GPU_DRIVER" = "rocm" ]; then \
# groupadd render && \
# usermod -a -G render ubuntu && \
# usermod -a -G video ubuntu; \
# fi
## Link amdgpu.ids for ROCm builds
## contributed by https://github.com/Rubonnek
# RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
# ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
# this should not increase image size because we've already installed dependencies
# in a previous layer
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]

View File

@@ -47,9 +47,8 @@ services:
invokeai-rocm:
<<: *invokeai
environment:
- AMD_VISIBLE_DEVICES=all
- RENDER_GROUP_ID=${RENDER_GROUP_ID}
runtime: amd
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
profiles:
- rocm

View File

@@ -21,17 +21,6 @@ _=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
# ensure the UID is correct
usermod -u ${USER_ID} ${USER} 1>/dev/null
## ROCM specific configuration
# render group within the container must match the host render group
# otherwise the container will not be able to access the host GPU.
if [[ -v "RENDER_GROUP_ID" ]] && [[ ! -z "${RENDER_GROUP_ID}" ]]; then
# ensure the render group exists
groupmod -g ${RENDER_GROUP_ID} render
usermod -a -G render ${USER}
usermod -a -G video ${USER}
fi
### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat.
# but it is useful to have the full SSH server e.g. on Runpod.

View File

@@ -13,7 +13,7 @@ run() {
# parse .env file for build args
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER=/ {print $2}' .env)"
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
# default to 'cuda' profile
[[ -z "$profile" ]] && profile="cuda"
@@ -30,7 +30,7 @@ run() {
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose --profile "$profile" logs -f
docker compose logs -f
}
run
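For reference, both values are scraped directly from `.env` by `awk`: the `GPU_DRIVER` pattern picks up the assignment, and the `build_args` expression turns every uncommented assignment into a `--build-arg` flag. A quick sketch of what those extractions produce with an illustrative `.env` (using the anchored `GPU_DRIVER=` form of the pattern):

```sh
$ cat .env
GPU_DRIVER=rocm
CONTAINER_UID=1000

$ awk -F '=' '/GPU_DRIVER=/ {print $2}' .env
rocm

$ awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env
--build-arg GPU_DRIVER=rocm
--build-arg CONTAINER_UID=1000
```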

View File

@@ -265,7 +265,7 @@ If the key is unrecognized, this call raises an
#### exists(key) -> AnyModelConfig
Returns True if a model with the given key exists in the database.
Returns True if a model with the given key exists in the databsae.
#### search_by_path(path) -> AnyModelConfig
@@ -718,7 +718,7 @@ When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
retrieved from the HuggingFace and Civitai model repositories.
The probed values can be overridden by providing a dictionary in the
The probed values can be overriden by providing a dictionary in the
optional `config` argument passed to `import_model()`. You may provide
overriding values for any of the model's configuration
attributes. Here is an example of setting the
@@ -841,7 +841,7 @@ variable.
#### installer.start(invoker)
The `start` method is called by the API initialization routines when
The `start` method is called by the API intialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.

View File

@@ -16,7 +16,7 @@ We thank [all contributors](https://github.com/invoke-ai/InvokeAI/graphs/contrib
- @psychedelicious (Spencer Mabrito) - Web Team Leader
- @joshistoast (Josh Corbett) - Web Development
- @cheerio (Mary Rogers) - Lead Engineer & Web App Development
- @ebr (Eugene Brodsky) - Cloud/DevOps/Software engineer; your friendly neighbourhood cluster-autoscaler
- @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
- @sunija - Standalone version
- @brandon (Brandon Rising) - Platform, Infrastructure, Backend Systems
- @ryanjdick (Ryan Dick) - Machine Learning & Training

View File

@@ -69,34 +69,34 @@ The following commands vary depending on the version of Invoke being installed a
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
7. Determine the torch backend to use for installation, if any. This is necessary to get the right version of torch installed. This is achieved by using [UV's built-in torch support.](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection)
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
=== "Invoke v5.12 and later"
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu128`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.3`.
- **In all other cases, do not use a torch backend.**
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.10.0 to v5.11.0"
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu126`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.2.4`.
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.0.0 to v5.9.1"
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.1`.
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm5.2`.
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.
@@ -105,10 +105,10 @@ The following commands vary depending on the version of Invoke being installed a
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a torch backend in the previous step, you'll need to set the backend like this:
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --torch-backend=<VERSION> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
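As a concrete, hypothetical substitution of that template for Invoke v5.12+ on Linux with an Nvidia GPU (the package version and backend below are illustrative only):

```sh
uv pip install invokeai==5.12.0 --python 3.12 --python-preference only-managed --torch-backend=cu128 --force-reinstall
```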
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -33,45 +33,30 @@ Hardware requirements vary significantly depending on model and image output siz
More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download and Set Up the Launcher
## Step 2: Download
The Launcher manages your Invoke install. Follow these instructions to download and set up the Launcher.
Download the most recent launcher for your operating system:
!!! info "Instructions for each OS"
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)
=== "Windows"
## Step 3: Install or Update
- [Download for Windows](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition.Setup.latest.exe)
- Run the `EXE` to install the Launcher and start it.
- A desktop shortcut will be created; use this to run the Launcher in the future.
- You can delete the `EXE` file you downloaded.
=== "macOS"
- [Download for macOS](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest-arm64.dmg)
- Open the `DMG` and drag the app into `Applications`.
- Run the Launcher using its entry in `Applications`.
- You can delete the `DMG` file you downloaded.
=== "Linux"
- [Download for Linux](https://github.com/invoke-ai/launcher/releases/latest/download/Invoke.Community.Edition-latest.AppImage)
- You may need to edit the `AppImage` file properties and make it executable.
- Optionally move the file to a location that does not require admin privileges and add a desktop shortcut for it.
- Run the Launcher by double-clicking the `AppImage` or the shortcut you made.
## Step 3: Install Invoke
Run the Launcher you just set up if you haven't already. Click **Install** and follow the instructions to install (or update) Invoke.
Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.
If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.
!!! tip "Updating"
!!! warning "Problem running the launcher on macOS"
The Launcher will check for updates for itself _and_ Invoke.
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can manually flag the launcher as safe:
- When the Launcher detects an update is available for itself, you'll get a small popup window. Click through this and the Launcher will update itself.
- When the Launcher detects an update for Invoke, you'll see a small green alert in the Launcher. Click that and follow the instructions to update Invoke.
- Open the **Invoke Community Edition.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
You should now be able to run the launcher.
## Step 4: Launch

View File

@@ -41,7 +41,7 @@ Nodes have a "Use Cache" option in their footer. This allows for performance imp
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
### Create Latent Noise
### Noise
An initial noise tensor is necessary for the latent diffusion process. As a result, the Denoising node requires a noise node input.

View File

@@ -5,9 +5,9 @@ from pathlib import Path
from typing import Optional
import torch
from fastapi import Body
from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to get"),
) -> JsonValue | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
)
async def set_client_state(
key: str = Query(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.delete(
"/client_state",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state() -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete()
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -1,58 +0,0 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"])
@client_state_router.get(
"/{queue_id}/get_by_key",
operation_id="get_client_state_by_key",
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/set_by_key",
operation_id="set_client_state",
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/delete",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -19,7 +19,6 @@ from invokeai.app.api.routers import (
app_info,
board_images,
boards,
client_state,
download_queue,
images,
model_manager,
@@ -132,7 +131,6 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.include_router(client_state.client_state_router, prefix="/api")
app.openapi = get_openapi_func(app)
@@ -157,12 +155,6 @@ def overridden_redoc() -> HTMLResponse:
web_root_path = Path(list(web_dir.__path__)[0])
if app_config.unsafe_disable_picklescan:
logger.warning(
"The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and"
"loading models, which may allow malicious code to be executed. Use at your own risk."
)
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:

View File

@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
# should refactor to avoid this duplication.
@@ -39,11 +38,7 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
@@ -67,8 +62,6 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

View File

@@ -6,6 +6,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
# should refactor to avoid this duplication.
@@ -39,15 +39,22 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = estimate_vae_working_memory_cogview4(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),

View File

@@ -64,7 +64,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
Gemini2_5Model = "Gemini2_5ModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion

View File

@@ -63,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.1.0",
version="4.0.0",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
@@ -153,7 +153,7 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
@@ -328,21 +328,6 @@ class FluxDenoiseInvocation(BaseInvocation):
cfg_scale_end_step=self.cfg_scale_end_step,
)
kontext_extension = None
if self.kontext_conditioning:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning
if isinstance(self.kontext_conditioning, list)
else [self.kontext_conditioning],
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -400,6 +385,19 @@ class FluxDenoiseInvocation(BaseInvocation):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None

View File

@@ -3,6 +3,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -17,7 +18,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
@invocation(
@@ -39,11 +39,17 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype

View File

@@ -15,7 +15,6 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
@invocation(
@@ -42,12 +41,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
assert isinstance(vae_info.model, AutoEncoder)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

View File

@@ -1347,96 +1347,3 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
image_dto = context.images.save(image=target_image)
return ImageOutput.build(image_dto)
@invocation(
"flux_kontext_image_prep",
title="FLUX Kontext Image Prep",
tags=["image", "concatenate", "flux", "kontext"],
category="image",
version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
images: list[ImageField] = InputField(
description="The images to concatenate",
min_length=1,
max_length=10,
)
use_preferred_resolution: bool = InputField(
default=True, description="Use FLUX preferred resolutions for the first image"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
# Step 1: Load all images
pil_images = []
for image_field in self.images:
image = context.images.get_pil(image_field.image_name, mode="RGBA")
pil_images.append(image)
# Step 2: Determine target resolution for the first image
first_image = pil_images[0]
width, height = first_image.size
if self.use_preferred_resolution:
aspect_ratio = width / height
# Find the closest preferred resolution for the first image
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
scaled_height = 2 * int(target_height / 16)
final_height = 8 * scaled_height # This will be consistent for all images
scaled_width = 2 * int(target_width / 16)
first_width = 8 * scaled_width
else:
# Use original dimensions of first image, ensuring divisibility by 16
final_height = 16 * (height // 16)
first_width = 16 * (width // 16)
# Ensure minimum dimensions
if final_height < 16:
final_height = 16
if first_width < 16:
first_width = 16
# Step 3: Process and resize all images with consistent height
processed_images = []
total_width = 0
for i, image in enumerate(pil_images):
if i == 0:
# First image uses the calculated dimensions
final_width = first_width
else:
# Subsequent images maintain aspect ratio with the same height
img_aspect_ratio = image.width / image.height
# Calculate width that maintains aspect ratio at the target height
calculated_width = int(final_height * img_aspect_ratio)
# Ensure width is divisible by 16 for proper VAE encoding
final_width = 16 * (calculated_width // 16)
# Ensure minimum width
if final_width < 16:
final_width = 16
# Resize image to calculated dimensions
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
processed_images.append(resized_image)
total_width += final_width
# Step 4: Concatenate images horizontally
concatenated_image = Image.new("RGB", (total_width, final_height))
x_offset = 0
for img in processed_images:
concatenated_image.paste(img, (x_offset, 0))
x_offset += img.width
# Save the concatenated image
image_dto = context.images.save(image=concatenated_image)
return ImageOutput.build(image_dto)
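For a sense of the rounding applied here: with `use_preferred_resolution` disabled, a 1000×750 first image becomes 992×736, since each dimension is snapped down to a multiple of 16 via `16 * (x // 16)`; subsequent images are resized to that same 736-pixel height with their widths likewise rounded down to multiples of 16 before horizontal concatenation.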

View File

@@ -27,7 +27,6 @@ from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
@invocation(
@@ -53,24 +52,11 @@ class ImageToLatentsInvocation(BaseInvocation):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
@classmethod
@staticmethod
def vae_encode(
cls,
vae_info: LoadedModel,
upcast: bool,
tiled: bool,
image_tensor: torch.Tensor,
tile_size: int = 0,
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
tile_size=tile_size if tiled else None,
fp32=upcast,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
@@ -127,7 +113,6 @@ class ImageToLatentsInvocation(BaseInvocation):
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
@@ -135,11 +120,7 @@ class ImageToLatentsInvocation(BaseInvocation):
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info,
upcast=self.fp32,
tiled=self.tiled or context.config.get().force_tiled_decode,
image_tensor=image_tensor,
tile_size=self.tile_size,
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")

View File

@@ -27,7 +27,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
@invocation(
@@ -54,6 +53,39 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 2200 # Determined experimentally.
if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
working_memory = out_h * out_w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant
if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
@@ -62,13 +94,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
tile_size=self.tile_size if use_tiling else None,
fp32=self.fp32,
)
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
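As a worked example of that estimate: decoding latents for a 1024×1024 output at fp16 (`element_size = 2`) gives roughly 1024 × 1024 × 2 × 2200 ≈ 4.6 GB of working memory. With `fp32` enabled the per-pixel term doubles and a further ~250 MB is added for the larger model, while the tiled path sizes the estimate from a single tile and then pads it by 25%.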

View File

@@ -17,7 +17,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
@invocation(
@@ -35,11 +34,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
assert isinstance(vae_info.model, AutoencoderKL)
estimated_working_memory = estimate_vae_working_memory_sd3(
operation="encode", image_tensor=image_tensor, vae=vae_info.model
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
@@ -63,8 +58,6 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKL)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

View File

@@ -6,6 +6,7 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -19,7 +20,6 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
@invocation(
@@ -41,15 +41,22 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = estimate_vae_working_memory_sd3(
operation="decode", image_tensor=latents, vae=vae_info.model
)
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),

View File

@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
@@ -8,35 +10,26 @@ class ClientStatePersistenceABC(ABC):
"""
@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
def set_by_key(self, key: str, value: JsonValue) -> None:
"""
Set a key-value pair for the client.
Store the data for the client.
Args:
key (str): The key to set.
value (str): The value to set for the key.
Returns:
str: The value that was set.
:param data: The client data to be stored.
"""
pass
@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
def get_by_key(self, key: str) -> JsonValue | None:
"""
Get the value for a specific key of the client.
Get the data for the client.
Args:
key (str): The key to retrieve the value for.
Returns:
str | None: The value associated with the key, or None if the key does not exist.
:return: The client data.
"""
pass
@abstractmethod
def delete(self, queue_id: str) -> None:
def delete(self) -> None:
"""
Delete all client state.
Delete the data for the client.
"""
pass

View File

@@ -1,5 +1,7 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -19,21 +21,8 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def _get(self) -> dict[str, str] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
@@ -47,15 +36,26 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
(json.dumps(state),),
)
return value
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
if state is None:
return None
return state.get(key, None)
def delete(self, queue_id: str) -> None:
def delete(self) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""

View File

@@ -107,7 +107,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `blake3_multi`, `blake3_single`, `random`, `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -197,7 +196,6 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3_single", description="Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
# fmt: on

View File

@@ -7,7 +7,7 @@ import threading
import time
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -186,15 +186,13 @@ class ModelInstallService(ModelInstallServiceBase):
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
if Path(model_path).is_file():
# Careful! Don't use pathlib.Path(...).with_suffix - it will strip everything after the first dot.
preferred_name = f"{preferred_name}{model_path.suffix}"
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
)
try:
new_path = self._move_model(model_path, dest_path)
new_path = self._copy_model(model_path, dest_path)
except FileExistsError as excp:
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
@@ -619,17 +617,30 @@ class ModelInstallService(ModelInstallServiceBase):
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
if new_path.exists():
raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists")
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
move(old_path, new_path)
return new_path
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
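For the counter-based rename shown above: a colliding `model.safetensors` would land as `model_01.safetensors` (then `model_02.safetensors`, and so on), since `with_stem` appends the counter before the original suffix rather than replacing it.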

View File

@@ -87,21 +87,9 @@ class ModelLoadService(ModelLoadServiceBase):
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Model at {checkpoint} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
if scan_result.scan_err:
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Error scanning model at {checkpoint} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result

View File

@@ -112,7 +112,7 @@ def denoise(
)
# Slice prediction to only include the main image tokens
if img_cond_seq is not None:
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
@@ -125,26 +125,9 @@ def denoise(
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
# For negative prediction with Kontext, we need to include the reference images
# to maintain consistency between positive and negative passes. Without this,
# CFG would create artifacts as the attention mechanism would see different
# spatial structures in each pass
neg_img_input = img
neg_img_input_ids = img_ids
# Add channel-wise conditioning for negative pass if present
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (Kontext) for negative pass
# This ensures reference images are processed consistently
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
img=img,
img_ids=img_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -157,10 +140,6 @@ def denoise(
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
# Slice negative prediction to match main image tokens
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred
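
The combination on the last lines is ordinary classifier-free guidance. A tiny numeric sketch (values made up) shows that cfg_scale = 1.0 leaves the positive prediction unchanged, while larger scales amplify the positive direction:

import torch

def apply_cfg(pos_pred: torch.Tensor, neg_pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Equivalent to: neg + scale * (pos - neg)
    return neg_pred + cfg_scale * (pos_pred - neg_pred)

pos = torch.tensor([1.0, 2.0])
neg = torch.tensor([0.5, 0.5])
assert torch.allclose(apply_cfg(pos, neg, 1.0), pos)   # scale 1.0 -> positive prediction unchanged
print(apply_cfg(pos, neg, 3.5))                        # tensor([2.2500, 5.7500])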

View File

@@ -1,14 +1,15 @@
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
def generate_img_ids_with_offset(
@@ -18,10 +19,8 @@ def generate_img_ids_with_offset(
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
h_offset: int = 0,
w_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with optional index and spatial offsets.
"""Generate tensor of image position ids with an optional offset.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
@@ -29,9 +28,7 @@ def generate_img_ids_with_offset(
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
idx_offset (int): Offset to add to the first dimension of the image ids.
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
@@ -45,10 +42,6 @@ def generate_img_ids_with_offset(
packed_height = latent_height // 2
packed_width = latent_width // 2
# Convert spatial offsets from latent space to packed space
packed_h_offset = h_offset // 2
packed_w_offset = w_offset // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
@@ -56,13 +49,13 @@ def generate_img_ids_with_offset(
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions) with spatial offset
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions) with spatial offset
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
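
A self-contained sketch of the same ID layout (packed height/width are the latent dims halved; channel 0 = index offset, channel 1 = y, channel 2 = x), handy for checking shapes. The einops repeat to batch size is an assumption based on the documented return shape:

import torch
from einops import repeat

def make_img_ids(latent_h: int, latent_w: int, batch: int, idx_offset: int = 0) -> torch.Tensor:
    ph, pw = latent_h // 2, latent_w // 2           # packing halves each spatial dim
    ids = torch.zeros(ph, pw, 3)
    ids[..., 0] = idx_offset                        # which image these tokens belong to
    ids[..., 1] = torch.arange(ph)[:, None]         # y position
    ids[..., 2] = torch.arange(pw)[None, :]         # x position
    return repeat(ids, "h w c -> b (h w) c", b=batch)

ids = make_img_ids(latent_h=8, latent_w=12, batch=1, idx_offset=1)
print(ids.shape)   # torch.Size([1, 24, 3]) -> 4 * 6 packed positions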
@@ -80,14 +73,14 @@ class KontextExtension:
def __init__(
self,
kontext_conditioning: list[FluxKontextConditioningField],
kontext_conditioning: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference images
Initializes the KontextExtension, pre-processing the reference image
into latents and positional IDs.
"""
self._context = context
@@ -100,116 +93,54 @@ class KontextExtension:
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
all_latents = []
all_ids = []
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Track cumulative dimensions for spatial tiling
# These track the running extent of the virtual canvas in latent space
canvas_h = 0 # Running canvas height
canvas_w = 0 # Running canvas width
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
for idx, kontext_field in enumerate(self.kontext_conditioning):
image = self._context.images.get_pil(kontext_field.image.image_name)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Convert to RGB
image = image.convert("RGB")
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
# Convert to tensor using torchvision transforms for consistency
transformation = T.Compose(
[
T.ToTensor(), # Converts PIL image to tensor and scales to [0, 1]
]
)
image_tensor = transformation(image)
# Convert from [0, 1] to [-1, 1] range expected by VAE
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
# Estimate working memory for encode operation (50% of decode memory requirements)
img_h = image_tensor.shape[-2]
img_w = image_tensor.shape[-1]
element_size = next(vae_info.model.parameters()).element_size()
scaling_constant = 1100 # 50% of decode scaling constant (2200)
estimated_working_memory = int(img_h * img_w * element_size * scaling_constant)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
# Use sample=False to get the distribution mean without noise
kontext_latents_unpacked = vae.encode(image_tensor, sample=False)
TorchDevice.empty_cache()
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pad latents to be compatible with patch_size=2
# This ensures dimensions are even for the pack() function
pad_h = (2 - latent_height % 2) % 2
pad_w = (2 - latent_width % 2) % 2
if pad_h > 0 or pad_w > 0:
kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
# Update dimensions after padding
_, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
# Determine spatial offsets for this reference image
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Calculate potential canvas dimensions for each tiling option
# Option 1: Tile vertically (below existing content)
potential_h_vertical = canvas_h + latent_height
# Option 2: Tile horizontally (to the right of existing content)
potential_w_horizontal = canvas_w + latent_width
# Choose arrangement that minimizes the maximum dimension
# This keeps the canvas closer to square, optimizing attention computation
if potential_h_vertical > potential_w_horizontal:
# Tile horizontally (to the right of existing images)
w_offset = canvas_w
canvas_w = canvas_w + latent_width
canvas_h = max(canvas_h, latent_height)
else:
# Tile vertically (below existing images)
h_offset = canvas_h
canvas_h = canvas_h + latent_height
canvas_w = max(canvas_w, latent_width)
else:
# First image - just set canvas dimensions
canvas_h = latent_height
canvas_w = latent_width
# Generate IDs with both index offset and spatial offsets
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1, # All reference images use index=1 (matching ComfyUI implementation)
h_offset=h_offset,
w_offset=w_offset,
)
all_latents.append(kontext_latents_packed)
all_ids.append(kontext_ids)
# Concatenate all latents and IDs along the sequence dimension
concatenated_latents = torch.cat(all_latents, dim=1) # Concatenate along sequence dimension
concatenated_ids = torch.cat(all_ids, dim=1) # Concatenate along sequence dimension
return concatenated_latents, concatenated_ids
return kontext_latents_packed, kontext_ids
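
The tiling decision inside the loop above (grow the virtual canvas in whichever direction keeps it closer to square) can be isolated into a small helper. The helper name and the 64x64 sizes below are illustrative, not from the codebase:

def place_reference(canvas_h: int, canvas_w: int, h: int, w: int) -> tuple[int, int, int, int]:
    """Return (h_offset, w_offset, new_canvas_h, new_canvas_w) for the next reference latent."""
    if canvas_h == 0 and canvas_w == 0:
        return 0, 0, h, w                                      # first image anchors the canvas at (0, 0)
    if canvas_h + h > canvas_w + w:
        return 0, canvas_w, max(canvas_h, h), canvas_w + w     # tile horizontally (to the right)
    return canvas_h, 0, canvas_h + h, max(canvas_w, w)         # tile vertically (below)

state = (0, 0)
for h, w in [(64, 64), (64, 64), (64, 64)]:
    h_off, w_off, ch, cw = place_reference(*state, h, w)
    print(h_off, w_off, ch, cw)
    state = (ch, cw)
# prints: 0 0 64 64 / 64 0 128 64 / 0 64 128 128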
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""

View File

@@ -1,304 +0,0 @@
# This file is vendored from https://github.com/ShieldMnt/invisible-watermark
#
# `invisible-watermark` is MIT licensed as of August 23, 2025, when the code was copied into this repo.
#
# Why we vendored it in:
# `invisible-watermark` has a dependency on `opencv-python`, which conflicts with Invoke's dependency on
# `opencv-contrib-python`. It's easier to copy the code over than complicate the installation process by
# requiring an extra post-install step of removing `opencv-python` and installing `opencv-contrib-python`.
import struct
import uuid
import base64
import cv2
import numpy as np
import pywt
class WatermarkEncoder(object):
def __init__(self, content=b""):
seq = np.array([n for n in content], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
self._wmType = "bytes"
def set_by_ipv4(self, addr):
bits = []
ips = addr.split(".")
for ip in ips:
bits += list(np.unpackbits(np.array([ip % 255], dtype=np.uint8)))
self._watermarks = bits
self._wmLen = len(self._watermarks)
self._wmType = "ipv4"
assert self._wmLen == 32
def set_by_uuid(self, uid):
u = uuid.UUID(uid)
self._wmType = "uuid"
seq = np.array([n for n in u.bytes], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
def set_by_bytes(self, content):
self._wmType = "bytes"
seq = np.array([n for n in content], dtype=np.uint8)
self._watermarks = list(np.unpackbits(seq))
self._wmLen = len(self._watermarks)
def set_by_b16(self, b16):
content = base64.b16decode(b16)
self.set_by_bytes(content)
self._wmType = "b16"
def set_by_bits(self, bits=[]):
self._watermarks = [int(bit) % 2 for bit in bits]
self._wmLen = len(self._watermarks)
self._wmType = "bits"
def set_watermark(self, wmType="bytes", content=""):
if wmType == "ipv4":
self.set_by_ipv4(content)
elif wmType == "uuid":
self.set_by_uuid(content)
elif wmType == "bits":
self.set_by_bits(content)
elif wmType == "bytes":
self.set_by_bytes(content)
elif wmType == "b16":
self.set_by_b16(content)
else:
raise NameError("%s is not supported" % wmType)
def get_length(self):
return self._wmLen
# @classmethod
# def loadModel(cls):
# RivaWatermark.loadModel()
def encode(self, cv2Image, method="dwtDct", **configs):
(r, c, channels) = cv2Image.shape
if r * c < 256 * 256:
raise RuntimeError("image too small, should be larger than 256x256")
if method == "dwtDct":
embed = EmbedMaxDct(self._watermarks, wmLen=self._wmLen, **configs)
return embed.encode(cv2Image)
# elif method == 'dwtDctSvd':
# embed = EmbedDwtDctSvd(self._watermarks, wmLen=self._wmLen, **configs)
# return embed.encode(cv2Image)
# elif method == 'rivaGan':
# embed = RivaWatermark(self._watermarks, self._wmLen)
# return embed.encode(cv2Image)
else:
raise NameError("%s is not supported" % method)
class WatermarkDecoder(object):
def __init__(self, wm_type="bytes", length=0):
self._wmType = wm_type
if wm_type == "ipv4":
self._wmLen = 32
elif wm_type == "uuid":
self._wmLen = 128
elif wm_type == "bytes":
self._wmLen = length
elif wm_type == "bits":
self._wmLen = length
elif wm_type == "b16":
self._wmLen = length
else:
raise NameError("%s is unsupported" % wm_type)
def reconstruct_ipv4(self, bits):
ips = [str(ip) for ip in list(np.packbits(bits))]
return ".".join(ips)
def reconstruct_uuid(self, bits):
nums = np.packbits(bits)
bstr = b""
for i in range(16):
bstr += struct.pack(">B", nums[i])
return str(uuid.UUID(bytes=bstr))
def reconstruct_bits(self, bits):
# return ''.join([str(b) for b in bits])
return bits
def reconstruct_b16(self, bits):
bstr = self.reconstruct_bytes(bits)
return base64.b16encode(bstr)
def reconstruct_bytes(self, bits):
nums = np.packbits(bits)
bstr = b""
for i in range(self._wmLen // 8):
bstr += struct.pack(">B", nums[i])
return bstr
def reconstruct(self, bits):
if len(bits) != self._wmLen:
raise RuntimeError("bits are not matched with watermark length")
if self._wmType == "ipv4":
return self.reconstruct_ipv4(bits)
elif self._wmType == "uuid":
return self.reconstruct_uuid(bits)
elif self._wmType == "bits":
return self.reconstruct_bits(bits)
elif self._wmType == "b16":
return self.reconstruct_b16(bits)
else:
return self.reconstruct_bytes(bits)
def decode(self, cv2Image, method="dwtDct", **configs):
(r, c, channels) = cv2Image.shape
if r * c < 256 * 256:
raise RuntimeError("image too small, should be larger than 256x256")
bits = []
if method == "dwtDct":
embed = EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
bits = embed.decode(cv2Image)
# elif method == 'dwtDctSvd':
# embed = EmbedDwtDctSvd(watermarks=[], wmLen=self._wmLen, **configs)
# bits = embed.decode(cv2Image)
# elif method == 'rivaGan':
# embed = RivaWatermark(watermarks=[], wmLen=self._wmLen, **configs)
# bits = embed.decode(cv2Image)
else:
raise NameError("%s is not supported" % method)
return self.reconstruct(bits)
# @classmethod
# def loadModel(cls):
# RivaWatermark.loadModel()
class EmbedMaxDct(object):
def __init__(self, watermarks=[], wmLen=8, scales=[0, 36, 36], block=4):
self._watermarks = watermarks
self._wmLen = wmLen
self._scales = scales
self._block = block
def encode(self, bgr):
(row, col, channels) = bgr.shape
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
for channel in range(2):
if self._scales[channel] <= 0:
continue
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
self.encode_frame(ca1, self._scales[channel])
yuv[: row // 4 * 4, : col // 4 * 4, channel] = pywt.idwt2((ca1, (v1, h1, d1)), "haar")
bgr_encoded = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
return bgr_encoded
def decode(self, bgr):
(row, col, channels) = bgr.shape
yuv = cv2.cvtColor(bgr, cv2.COLOR_BGR2YUV)
scores = [[] for i in range(self._wmLen)]
for channel in range(2):
if self._scales[channel] <= 0:
continue
ca1, (h1, v1, d1) = pywt.dwt2(yuv[: row // 4 * 4, : col // 4 * 4, channel], "haar")
scores = self.decode_frame(ca1, self._scales[channel], scores)
avgScores = list(map(lambda l: np.array(l).mean(), scores))
bits = np.array(avgScores) * 255 > 127
return bits
def decode_frame(self, frame, scale, scores):
(row, col) = frame.shape
num = 0
for i in range(row // self._block):
for j in range(col // self._block):
block = frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
]
score = self.infer_dct_matrix(block, scale)
# score = self.infer_dct_svd(block, scale)
wmBit = num % self._wmLen
scores[wmBit].append(score)
num = num + 1
return scores
def diffuse_dct_svd(self, block, wmBit, scale):
u, s, v = np.linalg.svd(cv2.dct(block))
s[0] = (s[0] // scale + 0.25 + 0.5 * wmBit) * scale
return cv2.idct(np.dot(u, np.dot(np.diag(s), v)))
def infer_dct_svd(self, block, scale):
u, s, v = np.linalg.svd(cv2.dct(block))
score = 0
score = int((s[0] % scale) > scale * 0.5)
return score
if score >= 0.5:
return 1.0
else:
return 0.0
def diffuse_dct_matrix(self, block, wmBit, scale):
pos = np.argmax(abs(block.flatten()[1:])) + 1
i, j = pos // self._block, pos % self._block
val = block[i][j]
if val >= 0.0:
block[i][j] = (val // scale + 0.25 + 0.5 * wmBit) * scale
else:
val = abs(val)
block[i][j] = -1.0 * (val // scale + 0.25 + 0.5 * wmBit) * scale
return block
def infer_dct_matrix(self, block, scale):
pos = np.argmax(abs(block.flatten()[1:])) + 1
i, j = pos // self._block, pos % self._block
val = block[i][j]
if val < 0:
val = abs(val)
if (val % scale) > 0.5 * scale:
return 1
else:
return 0
def encode_frame(self, frame, scale):
"""
frame is a matrix (M, N)
we get K (watermark bits size) blocks (self._block x self._block)
For i-th block, we encode watermark[i] bit into it
"""
(row, col) = frame.shape
num = 0
for i in range(row // self._block):
for j in range(col // self._block):
block = frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
]
wmBit = self._watermarks[(num % self._wmLen)]
diffusedBlock = self.diffuse_dct_matrix(block, wmBit, scale)
# diffusedBlock = self.diffuse_dct_svd(block, wmBit, scale)
frame[
i * self._block : i * self._block + self._block, j * self._block : j * self._block + self._block
] = diffusedBlock
num = num + 1
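
The core of diffuse_dct_matrix / infer_dct_matrix above is scale quantisation of a single coefficient: snap its magnitude to the centre of a "0" or "1" bucket of width scale, then read the bit back from the remainder. A minimal round-trip on plain numbers (no OpenCV needed; scale 36 matches the default scales=[0, 36, 36]):

def embed_bit(value: float, bit: int, scale: float) -> float:
    # Snap |value| to the centre of a "0" or "1" bucket of width `scale`, keeping the sign.
    sign = -1.0 if value < 0 else 1.0
    return sign * ((abs(value) // scale + 0.25 + 0.5 * bit) * scale)

def read_bit(value: float, scale: float) -> int:
    return 1 if (abs(value) % scale) > 0.5 * scale else 0

for original in (37.3, -5.1, 120.0):
    for bit in (0, 1):
        assert read_bit(embed_bit(original, bit, scale=36), scale=36) == bit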

View File

@@ -6,10 +6,13 @@ configuration variable, that allows the watermarking to be suppressed.
import cv2
import numpy as np
from imwatermark import WatermarkEncoder
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.image_util.imwatermark.vendor import WatermarkEncoder
from invokeai.app.services.config.config_default import get_config
config = get_config()
class InvisibleWatermark:

View File

@@ -9,7 +9,6 @@ import spandrel
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
@@ -494,21 +493,9 @@ class ModelProbe(object):
# scan model
scan_result = pscan.scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {model_name} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.")
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {model_name} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.")
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
# Probing utilities

View File

@@ -6,17 +6,13 @@ import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
logger = InvokeAILogger.get_logger()
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
@@ -83,24 +79,8 @@ class ModelOnDisk:
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path.stem} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(
f"The model {path.stem} is potentially infected by malware. Aborting import."
)
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path.stem} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path.stem} for malware. Aborting import.")
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):

View File

@@ -149,29 +149,13 @@ flux_kontext = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (quantized)",
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea = StarterModel(
name="FLUX.1 Krea dev",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea_quantized = StarterModel(
name="FLUX.1 Krea dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -596,14 +580,13 @@ t2i_sketch_sdxl = StarterModel(
)
# endregion
# region SpandrelImageToImage
animesharp_v4_rcan = StarterModel(
name="2x-AnimeSharpV4_RCAN",
realesrgan_anime = StarterModel(
name="RealESRGAN_x4plus_anime_6B",
base=BaseModelType.Any,
source="https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN.safetensors",
description="A 2x upscaling model (optimized for anime images).",
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
type=ModelType.SpandrelImageToImage,
)
realesrgan_x4 = StarterModel(
name="RealESRGAN_x4plus",
base=BaseModelType.Any,
@@ -749,7 +732,7 @@ STARTER_MODELS: list[StarterModel] = [
t2i_lineart_sdxl,
t2i_sketch_sdxl,
realesrgan_x4,
animesharp_v4_rcan,
realesrgan_anime,
realesrgan_x2,
swinir,
t5_base_encoder,
@@ -760,8 +743,6 @@ STARTER_MODELS: list[StarterModel] = [
llava_onevision,
flux_fill,
cogview4,
flux_krea,
flux_krea_quantized,
]
sd1_bundle: list[StarterModel] = [
@@ -813,7 +794,6 @@ flux_bundle: list[StarterModel] = [
flux_redux,
flux_fill,
flux_kontext_quantized,
flux_krea_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {

View File

@@ -28,7 +28,6 @@ class BaseModelType(str, Enum):
CogView4 = "cogview4"
Imagen3 = "imagen3"
Imagen4 = "imagen4"
Gemini2_5 = "gemini-2.5"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"

View File

@@ -8,12 +8,8 @@ import picklescan.scanner as pscan
import safetensors
import torch
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.taxonomy import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
@@ -63,21 +59,9 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
if scan:
scan_result = pscan.scan_file_path(path)
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {path} is potentially infected by malware. Aborting import.")
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path} for malware. Aborting import.")
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint

View File

@@ -18,25 +18,16 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Check if keys use transformer prefix
transformer_prefix_keys = [
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
transformer_keys_present = all(k in state_dict for k in transformer_prefix_keys)
all_expected_keys_present = all(k in state_dict for k in expected_keys)
# Check if keys use base_model.model prefix
base_model_prefix_keys = [
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"base_model.model.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"base_model.model.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
base_model_keys_present = all(k in state_dict for k in base_model_prefix_keys)
return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(
@@ -58,16 +49,8 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Determine which prefix is used and remove it from all keys.
# Check if any key starts with "base_model.model." prefix
has_base_model_prefix = any(k.startswith("base_model.model.") for k in grouped_state_dict.keys())
if has_base_model_prefix:
# Remove the "base_model.model." prefix from all keys.
grouped_state_dict = {k.replace("base_model.model.", ""): v for k, v in grouped_state_dict.items()}
else:
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19
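
The detection logic above boils down to two checks: every key is in PEFT lora_A/lora_B form, and a handful of expected FLUX keys exist under one of the known prefixes. A compact sketch (a single probe key stands in for the full spot-check list; prefixes are the ones named in the code):

def looks_like_flux_diffusers_lora(keys: set[str]) -> bool:
    peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in keys)
    probe = "single_transformer_blocks.0.attn.to_q.lora_A.weight"
    known_prefixes = ("transformer.", "base_model.model.")
    return peft_format and any(prefix + probe in keys for prefix in known_prefixes)

keys = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
}
print(looks_like_flux_diffusers_lora(keys))  # True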

View File

@@ -20,7 +20,7 @@ def main():
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Initialize FLUX transformer on meta device"):
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]

View File

@@ -33,7 +33,7 @@ def main():
)
# inference_dtype = torch.bfloat16
with log_time("Initialize FLUX transformer on meta device"):
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]

View File

@@ -27,7 +27,7 @@ def main():
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Initialize T5 on meta device"):
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)

View File

@@ -1,117 +0,0 @@
from typing import Literal
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def estimate_vae_working_memory_sd15_sdxl(
operation: Literal["encode", "decode"],
image_tensor: torch.Tensor,
vae: AutoencoderKL | AutoencoderTiny,
tile_size: int | None,
fp32: bool,
) -> int:
"""Estimate the working memory required to encode or decode the given tensor."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if fp32 else 2
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
if tile_size is not None:
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
h = tile_size
w = tile_size
working_memory = h * w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
working_memory = h * w * element_size * scaling_constant
if fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_cogview4(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = h * w * element_size * scaling_constant
print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_flux(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = out_h * out_w * element_size * scaling_constant
print(f"estimate_vae_working_memory_flux: {int(working_memory)}")
return int(working_memory)
def estimate_vae_working_memory_sd3(
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# Encode operations use approximately 50% of the memory required for decode operations
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
element_size = next(vae.parameters()).element_size()
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
# Encoding uses ~45% the working memory as decoding.
scaling_constant = 2200 if operation == "decode" else 1100
working_memory = h * w * element_size * scaling_constant
print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")
return int(working_memory)
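
Each estimate above is a straight product of output pixels, element size, and an empirically determined scaling constant (2200 for decode, 1100 for encode). A quick worked number for a 1024x1024 fp16 decode, using the constants from the code; the helper name is illustrative:

def estimate_decode_working_memory(height_px: int, width_px: int, element_size: int, scaling_constant: int = 2200) -> int:
    return int(height_px * width_px * element_size * scaling_constant)

bytes_needed = estimate_decode_working_memory(1024, 1024, element_size=2)  # fp16
print(f"{bytes_needed / 2**30:.2f} GiB")  # ~4.30 GiB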

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore();
const store = createStore({ driver: { getItem: () => {}, setItem: () => {} }, persistThrottle: 2000 });
$store.set(store);
$baseUrl.set('http://localhost:9090');

View File

@@ -17,7 +17,6 @@ const config: KnipConfig = {
'src/app/store/use-debounced-app-selector.ts',
],
ignoreBinaries: ['only-allow'],
ignoreDependencies: ['magic-string'],
paths: {
'public/*': ['public/*'],
},

View File

@@ -63,7 +63,6 @@
"framer-motion": "^11.10.0",
"i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
@@ -139,7 +138,6 @@
"eslint-plugin-unused-imports": "^4.1.4",
"globals": "^16.3.0",
"knip": "^5.61.3",
"magic-string": "^0.30.17",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",

View File

@@ -80,9 +80,6 @@ importers:
i18next-http-backend:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: 6.2.1
version: 6.2.1
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -291,9 +288,6 @@ importers:
knip:
specifier: ^5.61.3
version: 5.61.3(@types/node@22.16.0)(typescript@5.8.3)
magic-string:
specifier: ^0.30.17
version: 0.30.17
openapi-types:
specifier: ^12.1.3
version: 12.1.3
@@ -2778,9 +2772,6 @@ packages:
typescript:
optional: true
idb-keyval@6.2.1:
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -7275,8 +7266,6 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.1: {}
ieee754@1.2.1: {}
ignore@5.3.2: {}

View File

@@ -1470,6 +1470,7 @@
"ui": {
"tabs": {
"queue": "Warteschlange",
"generation": "Erzeugung",
"gallery": "Galerie",
"models": "Modelle",
"upscaling": "Hochskalierung",

View File

@@ -38,7 +38,6 @@
"deletedImagesCannotBeRestored": "Deleted images cannot be restored.",
"hideBoards": "Hide Boards",
"loading": "Loading...",
"locateInGalery": "Locate in Gallery",
"menuItemAutoAdd": "Auto-add to this Board",
"move": "Move",
"movingImagesToBoard_one": "Moving {{count}} image to board:",
@@ -115,9 +114,6 @@
"t2iAdapter": "T2I Adapter",
"positivePrompt": "Positive Prompt",
"negativePrompt": "Negative Prompt",
"removeNegativePrompt": "Remove Negative Prompt",
"addNegativePrompt": "Add Negative Prompt",
"selectYourModel": "Select Your Model",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
"dontShowMeThese": "Don't show me these",
@@ -614,18 +610,10 @@
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
},
"fitBboxToLayers": {
"title": "Fit Bbox To Layers",
"desc": "Automatically adjust the generation bounding box to fit visible layers"
},
"fitBboxToMasks": {
"title": "Fit Bbox To Masks",
"desc": "Automatically adjust the generation bounding box to fit visible inpaint masks"
},
"toggleBbox": {
"title": "Toggle Bbox Visibility",
"desc": "Hide or show the generation bounding box"
},
"applySegmentAnything": {
"title": "Apply Segment Anything",
"desc": "Apply the current Segment Anything mask.",
@@ -775,7 +763,6 @@
"allPrompts": "All Prompts",
"cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"clipSkip": "$t(parameters.clipSkip)",
"createdBy": "Created By",
"generationMode": "Generation Mode",
"guidance": "Guidance",
@@ -878,9 +865,6 @@
"install": "Install",
"installAll": "Install All",
"installRepo": "Install Repo",
"installBundle": "Install Bundle",
"installBundleMsg1": "Are you sure you want to install the {{bundleName}} bundle?",
"installBundleMsg2": "This bundle will install the following {{count}} models:",
"ipAdapters": "IP Adapters",
"learnMoreAboutSupportedModels": "Learn more about the models we support",
"load": "Load",
@@ -1251,7 +1235,7 @@
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with FLUX Kontext via BFL API",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -1299,7 +1283,6 @@
"remixImage": "Remix Image",
"usePrompt": "Use Prompt",
"useSeed": "Use Seed",
"useClipSkip": "Use CLIP Skip",
"width": "Width",
"gaussianBlur": "Gaussian Blur",
"boxBlur": "Box Blur",
@@ -1381,8 +1364,8 @@
"addedToBoard": "Added to board {{name}}'s assets",
"addedToUncategorized": "Added to board $t(boards.uncategorized)'s assets",
"baseModelChanged": "Base Model Changed",
"baseModelChangedCleared_one": "Updated, cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Updated, cleared or disabled {{count}} incompatible submodels",
"baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
"canceled": "Processing Canceled",
"connected": "Connected to Server",
"imageCopied": "Image Copied",
@@ -1950,11 +1933,8 @@
"zoomToNode": "Zoom to Node",
"nodeFieldTooltip": "To add a node field, click the small plus sign button on the field in the Workflow Editor, or drag the field by its name into the form.",
"addToForm": "Add to Form",
"removeFromForm": "Remove from Form",
"label": "Label",
"showDescription": "Show Description",
"showShuffle": "Show Shuffle",
"shuffle": "Shuffle",
"component": "Component",
"numberInput": "Number Input",
"singleLine": "Single Line",
@@ -2086,8 +2066,6 @@
"asControlLayer": "As $t(controlLayers.controlLayer)",
"asControlLayerResize": "As $t(controlLayers.controlLayer) (Resize)",
"referenceImage": "Reference Image",
"maxRefImages": "Max Ref Images",
"useAsReferenceImage": "Use as Reference Image",
"regionalReferenceImage": "Regional Reference Image",
"globalReferenceImage": "Global Reference Image",
"sendingToCanvas": "Staging Generations on Canvas",
@@ -2196,8 +2174,7 @@
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
"rgNoRegion": "no region drawn",
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill",
"bboxHidden": "Bounding box is hidden (shift+o to toggle)"
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill"
},
"errors": {
"unableToFindImage": "Unable to find image",
@@ -2556,7 +2533,7 @@
},
"ui": {
"tabs": {
"generate": "Generate",
"generation": "Generation",
"canvas": "Canvas",
"workflows": "Workflows",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
@@ -2567,12 +2544,6 @@
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Gallery"
},
"panels": {
"launchpad": "Launchpad",
"workflowEditor": "Workflow Editor",
"imageViewer": "Image Viewer",
"canvas": "Canvas"
},
"launchpad": {
"workflowsTitle": "Go deep with Workflows.",
"upscalingTitle": "Upscale and add detail.",
@@ -2580,28 +2551,6 @@
"generateTitle": "Generate images from text prompts.",
"modelGuideText": "Want to learn what prompts work best for each model?",
"modelGuideLink": "Check out our Model Guide.",
"createNewWorkflowFromScratch": "Create a new Workflow from scratch",
"browseAndLoadWorkflows": "Browse and load existing workflows",
"addStyleRef": {
"title": "Add a Style Reference",
"description": "Add an image to transfer its look."
},
"editImage": {
"title": "Edit Image",
"description": "Add an image to refine."
},
"generateFromText": {
"title": "Generate from Text",
"description": "Enter a prompt and Invoke."
},
"useALayoutImage": {
"title": "Use a Layout Image",
"description": "Add an image to control composition."
},
"generate": {
"canvasCalloutTitle": "Looking to get more control, edit, and iterate on your images?",
"canvasCalloutLink": "Navigate to Canvas for more capabilities."
},
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"learnMoreLink": "Learn more about creating workflows",
@@ -2638,13 +2587,6 @@
"upscaleModel": "Upscale Model",
"model": "Model",
"scale": "Scale",
"creativityAndStructure": {
"title": "Creativity & Structure Defaults",
"conservative": "Conservative",
"balanced": "Balanced",
"creative": "Creative",
"artistic": "Artistic"
},
"helpText": {
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
"styleAdvice": "Upscaling works best with the general style of your image."
@@ -2689,8 +2631,10 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Canvas: Color Picker does not sample alpha, bbox respects aspect ratio lock when resizing shuffle button for number fields in Workflow Builder, hide pixel dimension sliders when using a model that doesn't support them",
"Workflows: Add a Shuffle button to number input fields"
"New setting to send all Canvas generations directly to the Gallery.",
"New Invert Mask (Shift+V) and Fit BBox to Mask (Shift+B) capabilities.",
"Expanded support for Model Thumbnails and configurations.",
"Various other quality of life updates and fixes"
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -399,6 +399,7 @@
"ui": {
"tabs": {
"canvas": "Lienzo",
"generation": "Generación",
"queue": "Cola",
"workflows": "Flujos de trabajo",
"models": "Modelos",

View File

@@ -1820,6 +1820,7 @@
"upscaling": "Agrandissement",
"gallery": "Galerie",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"generation": "Génération",
"workflows": "Workflows",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Modèles",

View File

@@ -128,10 +128,7 @@
"search": "Cerca",
"clear": "Cancella",
"compactView": "Vista compatta",
"fullView": "Vista completa",
"removeNegativePrompt": "Rimuovi prompt negativo",
"addNegativePrompt": "Aggiungi prompt negativo",
"selectYourModel": "Seleziona il modello"
"fullView": "Vista completa"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -413,14 +410,6 @@
"cancelSegmentAnything": {
"title": "Annulla Segment Anything",
"desc": "Annulla l'operazione Segment Anything corrente."
},
"fitBboxToLayers": {
"title": "Adatta il riquadro di delimitazione ai livelli",
"desc": "Regola automaticamente il riquadro di delimitazione della generazione per adattarlo ai livelli visibili"
},
"toggleBbox": {
"title": "Attiva/disattiva la visibilità del riquadro di delimitazione",
"desc": "Nascondi o mostra il riquadro di delimitazione della generazione"
}
},
"workflows": {
@@ -722,10 +711,7 @@
"bundleDescription": "Ogni pacchetto include modelli essenziali per ogni famiglia di modelli e modelli base selezionati per iniziare.",
"browseAll": "Oppure scopri tutti i modelli disponibili:"
},
"launchpadTab": "Rampa di lancio",
"installBundle": "Installa pacchetto",
"installBundleMsg1": "Vuoi davvero installare il pacchetto {{bundleName}}?",
"installBundleMsg2": "Questo pacchetto installerà i seguenti {{count}} modelli:"
"launchpadTab": "Rampa di lancio"
},
"parameters": {
"images": "Immagini",
@@ -812,6 +798,7 @@
"modelIncompatibleScaledBboxWidth": "La larghezza scalata del riquadro è {{width}} ma {{model}} richiede multipli di {{multiple}}",
"modelIncompatibleScaledBboxHeight": "L'altezza scalata del riquadro è {{height}} ma {{model}} richiede multipli di {{multiple}}",
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Accedi alle impostazioni del tuo account per effettuare l'upgrade.",
"fluxKontextMultipleReferenceImages": "È possibile utilizzare solo 1 immagine di riferimento alla volta con Flux Kontext",
"promptExpansionResultPending": "Accetta o ignora il risultato dell'espansione del prompt",
"promptExpansionPending": "Espansione del prompt in corso"
},
@@ -841,8 +828,7 @@
"coherenceMinDenoise": "Min rid. rumore",
"recallMetadata": "Richiama i metadati",
"disabledNoRasterContent": "Disabilitato (nessun contenuto Raster)",
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Visita le <LinkComponent>impostazioni account</LinkComponent> per effettuare l'upgrade.",
"useClipSkip": "Usa CLIP Skip"
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Visita le <LinkComponent>impostazioni account</LinkComponent> per effettuare l'upgrade."
},
"settings": {
"models": "Modelli",
@@ -895,8 +881,8 @@
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine",
"baseModelChangedCleared_one": "Aggiornato, cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Aggiornati, cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"imageUploaded": "Immagine caricata",
@@ -1187,8 +1173,8 @@
"layeringStrategy": "Strategia livelli",
"longestPath": "Percorso più lungo",
"layoutDirection": "Direzione schema",
"layoutDirectionRight": "A destra",
"layoutDirectionDown": "In basso",
"layoutDirectionRight": "Orizzontale",
"layoutDirectionDown": "Verticale",
"alignment": "Allineamento nodi",
"alignmentUL": "In alto a sinistra",
"alignmentDL": "In basso a sinistra",
@@ -1241,8 +1227,7 @@
"updateBoardError": "Errore durante l'aggiornamento della bacheca",
"uncategorizedImages": "Immagini non categorizzate",
"deleteAllUncategorizedImages": "Elimina tutte le immagini non categorizzate",
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate.",
"locateInGalery": "Trova nella Galleria"
"deletedImagesCannotBeRestored": "Le immagini eliminate non possono essere ripristinate."
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@@ -1743,7 +1728,7 @@
"structure": {
"heading": "Struttura",
"paragraphs": [
"La struttura determina quanto l'immagine finale rispecchierà lo schema dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
]
},
"fluxDevLicense": {
@@ -1989,10 +1974,7 @@
"publishInProgress": "Pubblicazione in corso",
"selectingOutputNode": "Selezione del nodo di uscita",
"selectingOutputNodeDesc": "Fare clic su un nodo per selezionarlo come nodo di uscita del flusso di lavoro.",
"errorWorkflowHasUnpublishableNodes": "Il flusso di lavoro ha nodi di estrazione lotto, generatore o metadati",
"showShuffle": "Mostra Mescola",
"shuffle": "Mescola",
"removeFromForm": "Rimuovi dal modulo"
"errorWorkflowHasUnpublishableNodes": "Il flusso di lavoro ha nodi di estrazione lotto, generatore o metadati"
},
"loadMore": "Carica altro",
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
@@ -2473,8 +2455,7 @@
"ipAdapterIncompatibleBaseModel": "modello base dell'immagine di riferimento incompatibile",
"ipAdapterNoImageSelected": "nessuna immagine di riferimento selezionata",
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato",
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill",
"bboxHidden": "Il riquadro di delimitazione è nascosto (Shift+o per attivarlo)"
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill"
},
"pasteTo": "Incolla su",
"pasteToBboxDesc": "Nuovo livello (nel riquadro di delimitazione)",
@@ -2514,12 +2495,11 @@
"off": "Spento"
},
"invertMask": "Inverti maschera",
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere",
"maxRefImages": "Max Immagini di rif.to",
"useAsReferenceImage": "Usa come immagine di riferimento"
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere"
},
"ui": {
"tabs": {
"generation": "Generazione",
"canvas": "Tela",
"workflows": "Flussi di lavoro",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
@@ -2528,8 +2508,7 @@
"queue": "Coda",
"upscaling": "Amplia",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Galleria",
"generate": "Genera"
"gallery": "Galleria"
},
"launchpad": {
"workflowsTitle": "Approfondisci i flussi di lavoro.",
@@ -2577,43 +2556,8 @@
"helpText": {
"promptAdvice": "Durante l'ampliamento, utilizza un prompt che descriva il mezzo e lo stile. Evita di descrivere dettagli specifici del contenuto dell'immagine.",
"styleAdvice": "L'ampliamento funziona meglio con lo stile generale dell'immagine."
},
"creativityAndStructure": {
"title": "Creatività e struttura predefinite",
"conservative": "Conservativo",
"balanced": "Bilanciato",
"creative": "Creativo",
"artistic": "Artistico"
}
},
"createNewWorkflowFromScratch": "Crea un nuovo flusso di lavoro da zero",
"browseAndLoadWorkflows": "Sfoglia e carica i flussi di lavoro esistenti",
"addStyleRef": {
"title": "Aggiungi un riferimento di stile",
"description": "Aggiungi un'immagine per trasferirne l'aspetto."
},
"editImage": {
"title": "Modifica immagine",
"description": "Aggiungi un'immagine da perfezionare."
},
"generateFromText": {
"title": "Genera da testo",
"description": "Inserisci un prompt e genera."
},
"useALayoutImage": {
"description": "Aggiungi un'immagine per controllare la composizione.",
"title": "Usa una immagine guida"
},
"generate": {
"canvasCalloutTitle": "Vuoi avere più controllo, modificare e affinare le tue immagini?",
"canvasCalloutLink": "Per ulteriori funzionalità, vai su Tela."
}
},
"panels": {
"launchpad": "Rampa di lancio",
"workflowEditor": "Editor del flusso di lavoro",
"imageViewer": "Visualizzatore immagini",
"canvas": "Tela"
}
},
"upscaling": {
@@ -2704,8 +2648,10 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Tela: Color Picker non campiona l'alfa, il riquadro di delimitazione rispetta il blocco delle proporzioni quando si ridimensiona il pulsante Mescola per i campi numerici nel generatore di flusso di lavoro, nasconde i cursori delle dimensioni dei pixel quando si utilizza un modello che non li supporta",
"Flussi di lavoro: aggiunto un pulsante Mescola ai campi di input numerici"
"Nuova impostazione per inviare tutte le generazioni della Tela direttamente alla Galleria.",
"Nuove funzionalità Inverti maschera (Maiusc+V) e Adatta il Riquadro di delimitazione alla maschera (Maiusc+B).",
"Supporto esteso per miniature e configurazioni dei modelli.",
"Vari altri aggiornamenti e correzioni per la qualità della vita"
]
},
"system": {

View File

@@ -755,6 +755,7 @@
"noFLUXVAEModelSelected": "FLUX生成にVAEモデルが選択されていません",
"noT5EncoderModelSelected": "FLUX生成にT5エンコーダモデルが選択されていません",
"modelDisabledForTrial": "{{modelName}} を使用した生成はトライアルアカウントではご利用いただけません.アカウント設定にアクセスしてアップグレードしてください。",
"fluxKontextMultipleReferenceImages": "Flux Kontext では一度に 1 つの参照画像しか使用できません",
"promptExpansionPending": "プロンプト拡張が進行中",
"promptExpansionResultPending": "プロンプト拡張結果を受け入れるか破棄してください"
},
@@ -1782,6 +1783,7 @@
"workflows": "ワークフロー",
"models": "モデル",
"gallery": "ギャラリー",
"generation": "生成",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"upscaling": "アップスケーリング",

View File

@@ -1931,6 +1931,7 @@
},
"ui": {
"tabs": {
"generation": "Генерация",
"canvas": "Холст",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Модели",

View File

@@ -55,8 +55,7 @@
"assetsWithCount_other": "{{count}} tài nguyên",
"uncategorizedImages": "Ảnh Chưa Sắp Xếp",
"deleteAllUncategorizedImages": "Xoá Tất Cả Ảnh Chưa Sắp Xếp",
"deletedImagesCannotBeRestored": "Ảnh đã xoá không thể phục hồi lại.",
"locateInGalery": "Vị Trí Ở Thư Viện Ảnh"
"deletedImagesCannotBeRestored": "Ảnh đã xoá không thể phục hồi lại."
},
"gallery": {
"swapImages": "Đổi Hình Ảnh",
@@ -253,10 +252,7 @@
"clear": "Dọn Dẹp",
"compactView": "Chế Độ Xem Gọn",
"fullView": "Chế Độ Xem Đầy Đủ",
"options_withCount_other": "{{count}} thiết lập",
"removeNegativePrompt": "Xóa Lệnh Tiêu Cực",
"addNegativePrompt": "Thêm Lệnh Tiêu Cực",
"selectYourModel": "Chọn Model"
"options_withCount_other": "{{count}} thiết lập"
},
"prompt": {
"addPromptTrigger": "Thêm Trigger Cho Lệnh",
@@ -303,7 +299,7 @@
"pruneTooltip": "Cắt bớt {{item_count}} mục đã hoàn tất",
"pruneSucceeded": "Đã cắt bớt {{item_count}} mục đã hoàn tất khỏi hàng",
"clearTooltip": "Huỷ Và Dọn Dẹp Tất Cả Mục",
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ và Vùng Dựng Canva sẽ được khởi động lại.",
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ.",
"session": "Phiên",
"item": "Mục",
"resumeFailed": "Có Vấn Đề Khi Tiếp Tục Bộ Xử Lý",
@@ -347,14 +343,13 @@
"retrySucceeded": "Mục Đã Thử Lại",
"retryFailed": "Có Vấn Đề Khi Thử Lại Mục",
"retryItem": "Thử Lại Mục",
"credits": "Nguồn",
"cancelAllExceptCurrent": "Huỷ Bỏ Tất Cả Ngoại Trừ Mục Hiện Tại"
"credits": "Nguồn"
},
"hotkeys": {
"canvas": {
"fitLayersToCanvas": {
"title": "Xếp Vừa Layers Vào Canvas",
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer nhìn thấy dược."
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer."
},
"setZoomTo800Percent": {
"desc": "Phóng to canvas lên 800%.",
@@ -478,32 +473,6 @@
"toggleNonRasterLayers": {
"title": "Bật/Tắt Layer Không Thuộc Dạng Raster",
"desc": "Hiện hoặc ẩn tất cả layer không thuộc dạng raster (Layer Điều Khiển Được, Lớp Phủ Inpaint, Chỉ Dẫn Khu Vực)."
},
"invertMask": {
"title": "Đảo Ngược Lớp Phủ",
"desc": "Đảo ngược lớp phủ inpaint được chọn, tạo một lớp phủ mới với độ trong suốt đối nghịch."
},
"fitBboxToMasks": {
"title": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
"desc": "Tự động điểu chỉnh hộp giới hạn tạo sinh vừa vặn vào lớp phủ inpaint nhìn thấy được"
},
"applySegmentAnything": {
"title": "Áp Dụng Segment Anything",
"desc": "Áp dụng lớp phủ Segment Anything hiện tại.",
"key": "enter"
},
"cancelSegmentAnything": {
"title": "Huỷ Segment Anything",
"desc": "Huỷ hoạt động Segment Anything hiện tại.",
"key": "esc"
},
"fitBboxToLayers": {
"title": "Xếp Vừa Hộp Giới Hạn Vào Layer",
"desc": "Tự động điểu chỉnh hộp giới hạn tạo sinh vừa vặn vào layer nhìn thấy được"
},
"toggleBbox": {
"title": "Bật/Tắt Hiển Thị Hộp Giới Hạn",
"desc": "Ẩn hoặc hiện hộp giới hạn tạo sinh"
}
},
"workflows": {
@@ -633,10 +602,6 @@
"clearSelection": {
"desc": "Xoá phần lựa chọn hiện tại nếu có.",
"title": "Xoá Phần Lựa Chọn"
},
"starImage": {
"title": "Dấu/Huỷ Sao Hình Ảnh",
"desc": "Đánh dấu sao hoặc huỷ đánh dấu sao ảnh được chọn."
}
},
"app": {
@@ -696,11 +661,6 @@
"selectModelsTab": {
"desc": "Chọn tab Model (Mô Hình).",
"title": "Chọn Tab Model"
},
"selectGenerateTab": {
"title": "Chọn Tab Tạo Sinh",
"desc": "Chọn tab Tạo Sinh.",
"key": "1"
}
},
"searchHotkeys": "Tìm Phím tắt",
@@ -877,10 +837,7 @@
"stableDiffusion15": "Stable Diffusion 1.5",
"sdxl": "SDXL",
"fluxDev": "FLUX.1 dev"
},
"installBundle": "Tải Xuống Gói",
"installBundleMsg1": "Bạn có chắc chắn muốn tải xuống gói {{bundleName}}?",
"installBundleMsg2": "Gói này sẽ tải xuống {{count}} model sau đây:"
}
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -913,8 +870,7 @@
"recallParameters": "Gợi Nhớ Tham Số",
"scheduler": "Scheduler",
"noMetaData": "Không tìm thấy metadata",
"imageDimensions": "Kích Thước Ảnh",
"clipSkip": "$t(parameters.clipSkip)"
"imageDimensions": "Kích Thước Ảnh"
},
"accordions": {
"generation": {
@@ -1134,23 +1090,7 @@
"unknownField_withName": "Vùng Dữ Liệu Không Rõ \"{{name}}\"",
"unexpectedField_withName": "Sai Vùng Dữ Liệu \"{{name}}\"",
"unknownFieldEditWorkflowToFix_withName": "Workflow chứa vùng dữ liệu không rõ \"{{name}}\".\nHãy biên tập workflow để sửa lỗi.",
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\"",
"layout": {
"autoLayout": "Bố Cục Tự Động",
"layeringStrategy": "Chiến Lược Phân Layer",
"networkSimplex": "Network Simplex",
"longestPath": "Đường Đi Dài Nhất",
"nodeSpacing": "Khoảng Cách Node",
"layerSpacing": "Khoảng Cách Layer",
"layoutDirection": "Hướng Bố Cục",
"layoutDirectionRight": "Phải",
"layoutDirectionDown": "Xuống",
"alignment": "Căn Chỉnh Node",
"alignmentUL": "Trên Cùng Bên Trái",
"alignmentDL": "Dưới Cùng Bên Trái",
"alignmentUR": "Trên Cùng Bên Phải",
"alignmentDR": "Dưới Cùng Bên Phải"
}
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\""
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1657,6 +1597,7 @@
"modelIncompatibleScaledBboxHeight": "Chiều dài hộp giới hạn theo tỉ lệ là {{height}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelIncompatibleScaledBboxWidth": "Chiều rộng hộp giới hạn theo tỉ lệ là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần thiết lập tài khoản để nâng cấp.",
"fluxKontextMultipleReferenceImages": "Chỉ có thể dùng 1 Ảnh Mẫu cùng lúc với Flux Kontext",
"promptExpansionPending": "Trong quá trình mở rộng lệnh",
"promptExpansionResultPending": "Hãy chấp thuận hoặc huỷ bỏ kết quả mở rộng lệnh của bạn"
},
@@ -1722,8 +1663,7 @@
"upscaling": "Upscale",
"tileSize": "Kích Thước Khối",
"disabledNoRasterContent": "Đã Tắt (Không Có Nội Dung Dạng Raster)",
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần <LinkComponent>thiết lập tài khoản</LinkComponent> để nâng cấp.",
"useClipSkip": "Dùng CLIP Skip"
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần <LinkComponent>thiết lập tài khoản</LinkComponent> để nâng cấp."
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -2214,8 +2154,7 @@
"rgReferenceImagesNotSupported": "Ảnh Mẫu Khu Vực không được hỗ trợ cho model cơ sở được chọn",
"rgAutoNegativeNotSupported": "Tự Động Đảo Chiều không được hỗ trợ cho model cơ sở được chọn",
"rgNoRegion": "không có khu vực được vẽ",
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill",
"bboxHidden": "Hộp giới hạn đang ẩn (shift+o để bật/tắt)"
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill"
},
"pasteTo": "Dán Vào",
"pasteToAssets": "Tài Nguyên",
@@ -2253,11 +2192,7 @@
"off": "Tắt",
"switchOnStart": "Khi Bắt Đầu",
"switchOnFinish": "Khi Kết Thúc"
},
"fitBboxToMasks": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
"invertMask": "Đảo Ngược Lớp Phủ",
"maxRefImages": "Ảnh Mẫu Tối Đa",
"useAsReferenceImage": "Dùng Làm Ảnh Mẫu"
}
},
"stylePresets": {
"negativePrompt": "Lệnh Tiêu Cực",
@@ -2419,28 +2354,20 @@
"noValidLayerAdapters": "Không có Layer Adaper Phù Hợp",
"promptGenerationStarted": "Trình tạo sinh lệnh khởi động",
"uploadAndPromptGenerationFailed": "Thất bại khi tải lên ảnh để tạo sinh lệnh",
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại.",
"maskInverted": "Đã Đảo Ngược Lớp Phủ",
"maskInvertFailed": "Thất Bại Khi Đảo Ngược Lớp Phủ",
"noVisibleMasks": "Không Có Lớp Phủ Đang Hiển Thị",
"noVisibleMasksDesc": "Tạo hoặc bật ít nhất một lớp phủ inpaint để đảo ngược",
"noInpaintMaskSelected": "Không Có Lớp Phủ Inpant Được Chọn",
"noInpaintMaskSelectedDesc": "Chọn một lớp phủ inpaint để đảo ngược",
"invalidBbox": "Hộp Giới Hạn Không Hợp Lệ",
"invalidBboxDesc": "Hợp giới hạn có kích thước không hợp lệ"
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại."
},
"ui": {
"tabs": {
"gallery": "Thư Viện Ảnh",
"models": "Models",
"generation": "Generation (Máy Tạo Sinh)",
"upscaling": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)",
"canvas": "Canvas (Vùng Ảnh)",
"upscalingTab": "$t(common.tab) $t(ui.tabs.upscaling)",
"modelsTab": "$t(common.tab) $t(ui.tabs.models)",
"queue": "Queue (Hàng Đợi)",
"workflows": "Workflow (Luồng Làm Việc)",
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)",
"generate": "Tạo Sinh"
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)"
},
"launchpad": {
"workflowsTitle": "Đi sâu hơn với Workflow.",
@@ -2488,43 +2415,8 @@
"promptAdvice": "Khi upscale, dùng lệnh để mô tả phương thức và phong cách. Tránh mô tả các chi tiết cụ thể trong ảnh.",
"styleAdvice": "Upscale thích hợp nhất cho phong cách chung của ảnh."
},
"scale": "Kích Thước",
"creativityAndStructure": {
"title": "Độ Sáng Tạo & Cấu Trúc Mặc Định",
"conservative": "Bảo toàn",
"balanced": "Cân bằng",
"creative": "Sáng tạo",
"artistic": "Thẩm mỹ"
}
},
"createNewWorkflowFromScratch": "Tạo workflow mới từ đầu",
"browseAndLoadWorkflows": "Duyệt và tải workflow có sẵn",
"addStyleRef": {
"title": "Thêm Phong Cách Mẫu",
"description": "Thêm ảnh để chuyển đổi diện mạo của nó."
},
"editImage": {
"title": "Biên Tập Ảnh",
"description": "Thêm ảnh để chỉnh sửa."
},
"generateFromText": {
"title": "Tạo Sinh Từ Chữ",
"description": "Nhập lệnh vào và Kích Hoạt."
},
"useALayoutImage": {
"title": "Dùng Bố Cục Ảnh",
"description": "Thêm ảnh để điều khiển bố cục."
},
"generate": {
"canvasCalloutTitle": "Đang tìm cách để điều khiển, chỉnh sửa, và làm lại ảnh?",
"canvasCalloutLink": "Vào Canvas cho nhiều tính năng hơn."
"scale": "Kích Thước"
}
},
"panels": {
"launchpad": "Launchpad",
"workflowEditor": "Trình Biên Tập Workflow",
"imageViewer": "Trình Xem Ảnh",
"canvas": "Canvas"
}
},
"workflows": {
@@ -2639,10 +2531,7 @@
"publishingValidationRunInProgress": "Quá trình kiểm tra tính hợp lệ đang diễn ra.",
"selectingOutputNodeDesc": "Bấm vào node để biến nó thành node đầu ra của workflow.",
"selectingOutputNode": "Chọn node đầu ra",
"errorWorkflowHasUnpublishableNodes": "Workflow có lô node, node sản sinh, hoặc node tách metadata",
"removeFromForm": "Xóa Khỏi Vùng Nhập",
"showShuffle": "Hiện Xáo Trộn",
"shuffle": "Xáo Trộn"
"errorWorkflowHasUnpublishableNodes": "Workflow có lô node, node sản sinh, hoặc node tách metadata"
},
"yourWorkflows": "Workflow Của Bạn",
"browseWorkflows": "Khám Phá Workflow",
@@ -2699,8 +2588,9 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Misc QoL: Bật/Tắt hiển thị hộp giới hạn, đánh dấu node bị lỗi, chặn lỗi thêm node vào vùng nhập nhiều lần, khả năng đọc lại metadata của CLIP Skip",
"Giảm lượng tiêu thụ VRAM cho các ảnh mẫu Kontext và mã hóa VAE"
"Tạo sinh ảnh nhanh hơn với Launchpad và thẻ Tạo Sinh đã cơ bản hoá.",
"Biên tập với lệnh bằng Flux Kontext Dev.",
"Xuất ra file PSD, ẩn số lượng lớn lớp phủ, sắp xếp model & ảnh — tất cả cho một giao diện đã thiết kế lại để chuyên điều khiển."
]
},
"upsell": {

View File

@@ -1772,6 +1772,7 @@
},
"ui": {
"tabs": {
"generation": "生成",
"queue": "队列",
"canvas": "画布",
"upscaling": "放大中",

View File

@@ -2,8 +2,8 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { useClearStorage } from 'app/contexts/clear-storage-context';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { AppContent } from 'features/ui/components/AppContent';
@@ -21,12 +21,13 @@ interface Props {
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
return false;
}, []);
}, [clearStorage]);
return (
<ThemeLocaleProvider>

View File

@@ -1,11 +1,12 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { ClearStorageProvider } from 'app/contexts/clear-storage-context';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { addStorageListeners } from 'app/store/enhancers/reduxRemember/driver';
import { buildStorageApi } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -36,7 +37,7 @@ import {
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { ToastConfig } from 'features/toast/toast';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useState } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
@@ -71,7 +72,14 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
storagePersistDebounce?: number;
storageConfig?: {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
getItem: (key: string) => Promise<any>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
persistThrottle: number;
};
}
const InvokeAIUI = ({
@@ -98,11 +106,8 @@ const InvokeAIUI = ({
loggingOverrides,
onClickGoToModelManager,
whatsNew,
storagePersistDebounce = 300,
storageConfig,
}: Props) => {
const [store, setStore] = useState<ReturnType<typeof createStore> | undefined>(undefined);
const [didRehydrate, setDidRehydrate] = useState(false);
useLayoutEffect(() => {
/*
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
@@ -314,38 +319,44 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const storage = useMemo(() => buildStorageApi(storageConfig), [storageConfig]);
useEffect(() => {
const onRehydrated = () => {
setDidRehydrate(true);
const storageCleanup = storage.registerListeners();
return () => {
storageCleanup();
};
const store = createStore({ persist: true, persistDebounce: storagePersistDebounce, onRehydrated });
setStore(store);
}, [storage]);
const store = useMemo(() => {
return createStore({
driver: storage.reduxRememberDriver,
persistThrottle: storageConfig?.persistThrottle ?? 2000,
});
}, [storage.reduxRememberDriver, storageConfig?.persistThrottle]);
useEffect(() => {
$store.set(store);
if (import.meta.env.MODE === 'development') {
window.$store = $store;
}
const removeStorageListeners = addStorageListeners();
return () => {
removeStorageListeners();
setStore(undefined);
$store.set(undefined);
if (import.meta.env.MODE === 'development') {
window.$store = undefined;
}
};
}, [storagePersistDebounce]);
if (!store || !didRehydrate) {
return <Loading />;
}
}, [store]);
return (
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
<ClearStorageProvider value={storage.clearStorage}>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</ClearStorageProvider>
</React.StrictMode>
);
};

View File
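The InvokeAIUI diff above adds an optional storageConfig prop (getItem/setItem/clear plus persistThrottle). The following is a minimal sketch, not code from this PR, of a localStorage-backed implementation an embedding app might pass; the key prefix is illustrative and the JSON round-trip is an assumption, since the value shape handed to setItem ultimately depends on the serialize/unserialize functions configured in store.ts.

// Hedged sketch only - not part of this PR. A localStorage-backed storageConfig matching
// the prop shape added to InvokeAIUI above. The 'invoke-client-state/' prefix is illustrative.
const storageConfig = {
  getItem: async (key: string) => {
    const raw = window.localStorage.getItem(`invoke-client-state/${key}`);
    return raw === null ? undefined : JSON.parse(raw);
  },
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  setItem: async (key: string, value: any) => {
    window.localStorage.setItem(`invoke-client-state/${key}`, JSON.stringify(value));
    return value;
  },
  clear: async () => {
    for (const k of Object.keys(window.localStorage)) {
      if (k.startsWith('invoke-client-state/')) {
        window.localStorage.removeItem(k);
      }
    }
  },
  persistThrottle: 2000,
};

// Usage: <InvokeAIUI storageConfig={storageConfig} />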

@@ -0,0 +1,10 @@
import { createContext, useContext } from 'react';
const ClearStorageContext = createContext<() => void>(() => {});
export const ClearStorageProvider = ClearStorageContext.Provider;
export const useClearStorage = () => {
const context = useContext(ClearStorageContext);
return context;
};

View File
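The new clear-storage-context module above is consumed by App.tsx earlier in this diff; as a quick illustration, any component rendered below ClearStorageProvider can do the same. The component below is hypothetical and not part of the PR.

import { useClearStorage } from 'app/contexts/clear-storage-context';

// Illustrative consumer only. Outside a provider, the hook returns the no-op default defined above.
const ResetClientStateButton = () => {
  const clearStorage = useClearStorage();
  return <button onClick={() => clearStorage()}>Reset client state</button>;
};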

@@ -93,7 +93,5 @@ export const configureLogging = (
localStorage.setItem('ROARR_FILTER', filter);
}
const styleOutput = localStorage.getItem('ROARR_STYLE_OUTPUT') === 'false' ? false : true;
ROARR.write = createLogWriter({ styleOutput });
ROARR.write = createLogWriter();
};

View File

@@ -1,209 +1,243 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $authToken } from 'app/store/nanostores/authToken';
import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { UseStore } from 'idb-keyval';
import { createStore as idbCreateStore, del as idbDel, get as idbGet } from 'idb-keyval';
import type { Driver } from 'redux-remember';
import { serializeError } from 'serialize-error';
import { buildV1Url, getBaseUrl } from 'services/api';
import type { JsonObject } from 'type-fest';
import type { Driver as ReduxRememberDriver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
const log = logger('system');
const getUrl = (endpoint: 'get_by_key' | 'set_by_key' | 'delete', key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const buildOSSServerBackedDriver = (): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
const path = buildV1Url(`client_state/${$queueId.get()}/${endpoint}`, query);
const url = `${baseUrl}/${path}`;
return url;
};
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicitly denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key. If the value to
// be persisted is the same as the last persisted value, we can skip the network request.
const lastPersistedState = new Map<string, unknown>();
const getHeaders = () => {
const headers = new Headers();
const authToken = $authToken.get();
const projectId = $projectId.get();
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
};
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicitly denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key in this map.
// If the value to be persisted is the same as the last persisted value, we will skip the network request.
const lastPersistedState = new Map<string, string | undefined>();
// As of v6.3.0, we use server-backed storage for client state. This replaces the previous IndexedDB-based storage,
// which was implemented using `idb-keyval`.
//
// To facilitate a smooth transition, we implement a migration strategy that attempts to retrieve values from IndexedDB
// and persist them to the new server-backed storage. This is done on a best-effort basis.
// These constants were used in the previous IndexedDB-based storage implementation.
const IDB_DB_NAME = 'invoke';
const IDB_STORE_NAME = 'invoke-store';
const IDB_STORAGE_PREFIX = '@@invokeai-';
// Lazy store creation
let _idbKeyValStore: UseStore | null = null;
const getIdbKeyValStore = () => {
if (_idbKeyValStore === null) {
_idbKeyValStore = idbCreateStore(IDB_DB_NAME, IDB_STORE_NAME);
}
return _idbKeyValStore;
};
const getIdbKey = (key: string) => {
return `${IDB_STORAGE_PREFIX}${key}`;
};
const getItem = async (key: string) => {
try {
const url = getUrl('get_by_key', key);
const headers = getHeaders();
const res = await fetch(url, { method: 'GET', headers });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
const getUrl = (key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const value = await res.json();
const path = buildAppInfoUrl('client_state', query);
const url = `${baseUrl}/${path}`;
return url;
};
// Best-effort migration from IndexedDB to the new storage system
log.trace({ key, value }, 'Server-backed storage value retrieved');
if (!value) {
const idbKey = getIdbKey(key);
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
// It's a bit tricky to query IndexedDB directly to check whether a value exists, so we use `idb-keyval` to do it.
// Thing is, `idb-keyval` requires you to create a store to query it. End result - we are creating a store
// even if we don't use it for anything besides checking if the key is present.
const idbKeyValStore = getIdbKeyValStore();
const idbValue = await idbGet(idbKey, idbKeyValStore);
if (idbValue) {
log.debug(
{ key, idbKey, idbValue },
'No value in server-backed storage, but found value in IndexedDB - attempting migration'
);
await idbDel(idbKey, idbKeyValStore);
await setItem(key, idbValue);
log.debug({ key, idbKey, idbValue }, 'Migration successful');
return idbValue;
const url = getUrl(key);
const res = await fetch(url, { method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch (error) {
// Just log if IndexedDB retrieval fails - this is a best-effort migration.
log.debug(
{ key, idbKey, error: serializeError(error) } as JsonObject,
'Error checking for or migrating from IndexedDB'
);
const text = await res.text();
if (!lastPersistedState.get(key)) {
lastPersistedState.set(key, text);
}
return JSON.parse(text);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping persist for key "${key}" as value is unchanged.`);
return value;
}
const url = getUrl(key);
const headers = new Headers({
'Content-Type': 'application/json',
});
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl();
const res = await fetch(url, { method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
lastPersistedState.set(key, value);
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
return value;
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
const setItem = async (key: string, value: string) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(
{ key, last: lastPersistedState.get(key), next: value },
`Skipping persist for ${key} as value is unchanged`
);
return value;
}
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const url = getUrl('set_by_key', key);
const headers = getHeaders();
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const resultValue = await res.json();
lastPersistedState.set(key, resultValue);
return resultValue;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const buildCustomDriver = (api: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}): {
reduxRememberDriver: ReduxRememberDriver;
clearStorage: () => Promise<void>;
registerListeners: () => () => void;
} => {
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
let persistRefCount = 0;
export const reduxRememberDriver: Driver = { getItem, setItem };
// See the comment in `buildOSSServerBackedDriver` for an explanation of this variable.
const lastPersistedState = new Map<string, unknown>();
export const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl('delete');
const headers = getHeaders();
const res = await fetch(url, { method: 'POST', headers });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
const reduxRememberDriver: ReduxRememberDriver = {
getItem: async (key) => {
try {
log.trace(`Getting client state for key "${key}"`);
return await api.getItem(key);
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
persistRefCount++;
export const addStorageListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
if (lastPersistedState.get(key) === value) {
log.trace(`Skipping setting client state for key "${key}" as value is unchanged`);
return value;
}
log.trace(`Setting client state for key "${key}", ${value}`);
await api.setItem(key, value);
lastPersistedState.set(key, value);
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
},
};
const clearStorage = async () => {
try {
persistRefCount++;
log.trace('Clearing client state');
await api.clear();
} catch {
log.error('Failed to clear client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
const registerListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};
return { reduxRememberDriver, clearStorage, registerListeners };
};
export const buildStorageApi = (api?: {
getItem: (key: string) => Promise<any>;
setItem: (key: string, value: any) => Promise<any>;
clear: () => Promise<void>;
}) => {
if (api) {
return buildCustomDriver(api);
} else {
return buildOSSServerBackedDriver();
}
};

View File
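The field-omission logic referenced in the comments above lives in the serialize function in store.ts and is not reproduced in this diff. The following is only an illustrative sketch of the idea — 'canvas' and 'ephemeral' are placeholder names — showing why the drivers keep a lastPersistedState map: redux-remember still calls setItem whenever the raw slice object changes, so the driver compares the serialized value and skips the write when it is unchanged.

// Illustrative sketch, not the actual store.ts serializer. A per-slice serialize function can
// drop fields that should never be persisted; when only such a field changes, the serialized
// output is identical to the last persisted value and the driver's lastPersistedState check
// short-circuits the network request.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const serialize = (data: any, key: string): string => {
  if (key === 'canvas') {
    // 'ephemeral' is a placeholder field name
    const { ephemeral, ...persisted } = data;
    return JSON.stringify(persisted);
  }
  return JSON.stringify(data);
};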

@@ -33,9 +33,8 @@ export class StorageError extends Error {
}
}
const log = logger('system');
export const errorHandler = (err: PersistError | RehydrateError) => {
const log = logger('system');
if (err instanceof PersistError) {
log.error({ error: serializeError(err) }, 'Problem persisting state');
} else if (err instanceof RehydrateError) {

View File

@@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
@@ -12,7 +12,6 @@ import {
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/parameters/types/constants';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@@ -23,7 +22,6 @@ import {
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxReduxModelConfig,
isGemini2_5ModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -46,13 +44,13 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
let modelsUpdatedDisabledOrCleared = 0;
let modelsCleared = 0;
// handle incompatible loras
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBase) {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: false }));
modelsUpdatedDisabledOrCleared += 1;
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
});
@@ -60,57 +58,52 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
const { vae } = state.params;
if (vae && vae.base !== newBase) {
dispatch(vaeSelected(null));
modelsUpdatedDisabledOrCleared += 1;
modelsCleared += 1;
}
if (SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
let newGlobalRefImageModel = null;
let newGlobalRefImageModel = null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'gemini-2.5') {
const gemini2_5Models = allRefImageModels.filter(isGemini2_5ModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(gemini2_5Models);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsUpdatedDisabledOrCleared += 1;
}
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsCleared += 1;
}
}
@@ -135,17 +128,17 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
modelConfig: newRegionalRefImageModel,
})
);
modelsUpdatedDisabledOrCleared += 1;
modelsCleared += 1;
}
}
}
if (modelsUpdatedDisabledOrCleared > 0) {
if (modelsCleared > 0) {
toast({
id: 'BASE_MODEL_CHANGED',
title: t('toast.baseModelChanged'),
description: t('toast.baseModelChangedCleared', {
count: modelsUpdatedDisabledOrCleared,
count: modelsCleared,
}),
status: 'warning',
});

View File

@@ -1,5 +1,5 @@
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createAction, createListenerMiddleware } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createListenerMiddleware } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
@@ -40,15 +40,14 @@ import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { REMEMBER_REHYDRATED, rememberEnhancer, rememberReducer } from 'redux-remember';
import type { Driver, SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { reduxRememberDriver } from './enhancers/reduxRemember/driver';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
@@ -128,10 +127,9 @@ const unserialize: UnserializeFunction = (data, key) => {
let state;
try {
const initialState = getInitialState();
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState));
const stripped = pick(deepClone(data), keys(initialState));
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
@@ -143,7 +141,7 @@ const unserialize: UnserializeFunction = (data, key) => {
log.debug(
{
persistedData: parsed as JsonObject,
persistedData: data as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
},
@@ -184,8 +182,8 @@ const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (options?: { persist?: boolean; persistDebounce?: number; onRehydrated?: () => void }) => {
const store = configureStore({
export const createStore = (reduxRememberOptions: { driver: Driver; persistThrottle: number }) =>
configureStore({
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
@@ -197,23 +195,19 @@ export const createStore = (options?: { persist?: boolean; persistDebounce?: num
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
// .concat(getDebugLoggerMiddleware({ withDiff: true, withNextState: true }))
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const enhancers = getDefaultEnhancers();
if (options?.persist) {
return enhancers.prepend(
rememberEnhancer(reduxRememberDriver, PERSISTED_KEYS, {
persistDebounce: options?.persistDebounce ?? 2000,
serialize,
unserialize,
prefix: '',
errorHandler,
})
);
} else {
return enhancers;
}
return enhancers.prepend(
rememberEnhancer(reduxRememberOptions.driver, PERSISTED_KEYS, {
persistThrottle: reduxRememberOptions.persistThrottle,
serialize,
unserialize,
prefix: '',
errorHandler,
})
);
},
devTools: {
actionSanitizer,
@@ -228,18 +222,6 @@ export const createStore = (options?: { persist?: boolean; persistDebounce?: num
},
});
// Once-off listener to support waiting for rehydration before rendering the app
startAppListening({
actionCreator: createAction(REMEMBER_REHYDRATED),
effect: (action, { unsubscribe }) => {
unsubscribe();
options?.onRehydrated?.();
},
});
return store;
};
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */

View File
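Putting the driver and store changes together, the new wiring is roughly as follows. This is a sketch rather than code from the PR; the 2000 ms throttle simply mirrors the default used in InvokeAIUI.tsx above.

// Wiring sketch using the exports shown in the diffs above.
const storage = buildStorageApi(); // no argument -> OSS server-backed driver
const store = createStore({
  driver: storage.reduxRememberDriver,
  persistThrottle: 2000,
});
// Warns on unload while a persist request is still in flight.
const unregisterListeners = storage.registerListeners();
// storage.clearStorage is what ClearStorageProvider hands down to the UI.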

@@ -58,7 +58,6 @@ const zNumericalParameterConfig = z.object({
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
export type NumericalParameterConfig = z.infer<typeof zNumericalParameterConfig>;
/**
* Configuration options for the InvokeAI UI.

View File

@@ -1,9 +1,9 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { allEntitiesDeleted, inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
@@ -11,10 +11,9 @@ import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
export const SessionMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const tab = useAppSelector(selectActiveTab);
const resetCanvasLayers = useCallback(() => {
dispatch(allEntitiesDeleted());
dispatch(canvasReset());
dispatch(inpaintMaskAdded({ isSelected: true, isBookmarked: true }));
$canvasManager.get()?.stage.fitBboxToStage();
}, [dispatch]);
@@ -23,16 +22,12 @@ export const SessionMenuItems = memo(() => {
}, [dispatch]);
return (
<>
{tab === 'canvas' && (
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>
)}
{(tab === 'canvas' || tab === 'generate') && (
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetGenerationSettings}>
{t('controlLayers.resetGenerationSettings')}
</MenuItem>
)}
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetGenerationSettings}>
{t('controlLayers.resetGenerationSettings')}
</MenuItem>
</>
);
});

View File

@@ -1,5 +0,0 @@
const randomFloat = (min: number, max: number): number => {
return Math.random() * (max - min + Number.EPSILON) + min;
};
export default randomFloat;

View File

@@ -8,7 +8,6 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -27,8 +26,7 @@ const selectIsModalOpen = createSelector(
const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const currentBoardId = useAppSelector(selectSelectedBoardId);
const [selectedBoardId, setSelectedBoardId] = useState<string | null>();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const isModalOpen = useAppSelector(selectIsModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
@@ -37,19 +35,15 @@ const ChangeBoardModal = () => {
const { t } = useTranslation();
const options = useMemo<ComboboxOption[]>(() => {
return [{ label: t('boards.uncategorized'), value: 'none' }]
.concat(
(boards ?? [])
.map((board) => ({
label: board.board_name,
value: board.board_id,
}))
.sort((a, b) => a.label.localeCompare(b.label))
)
.filter((board) => board.value !== currentBoardId);
}, [boards, currentBoardId, t]);
return [{ label: t('boards.uncategorized'), value: 'none' }].concat(
(boards ?? []).map((board) => ({
label: board.board_name,
value: board.board_id,
}))
);
}, [boards, t]);
const value = useMemo(() => options.find((o) => o.value === selectedBoardId), [options, selectedBoardId]);
const value = useMemo(() => options.find((o) => o.value === selectedBoard), [options, selectedBoard]);
const handleClose = useCallback(() => {
dispatch(changeBoardReset());
@@ -57,26 +51,27 @@ const ChangeBoardModal = () => {
}, [dispatch]);
const handleChangeBoard = useCallback(() => {
if (!imagesToChange.length || !selectedBoardId) {
if (!imagesToChange.length || !selectedBoard) {
return;
}
if (selectedBoardId === 'none') {
if (selectedBoard === 'none') {
removeImagesFromBoard({ image_names: imagesToChange });
} else {
addImagesToBoard({
image_names: imagesToChange,
board_id: selectedBoardId,
board_id: selectedBoard,
});
}
setSelectedBoard(null);
dispatch(changeBoardReset());
}, [addImagesToBoard, dispatch, imagesToChange, removeImagesFromBoard, selectedBoardId]);
}, [addImagesToBoard, dispatch, imagesToChange, removeImagesFromBoard, selectedBoard]);
const onChange = useCallback<ComboboxOnChange>((v) => {
if (!v) {
return;
}
setSelectedBoardId(v.value);
setSelectedBoard(v.value);
}, []);
return (
@@ -94,6 +89,7 @@ const ChangeBoardModal = () => {
{t('boards.movingImagesToBoard', {
count: imagesToChange.length,
})}
:
</Text>
<FormControl isDisabled={isFetching}>
<Combobox

View File

@@ -1,24 +0,0 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsBboxVisibility = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const isBboxHidden = useStore(canvasManager.tool.tools.bbox.$isBboxHidden);
if (!isBboxHidden) {
return null;
}
return (
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.warnings.bboxHidden')}</AlertTitle>
</Alert>
);
});
CanvasAlertsBboxVisibility.displayName = 'CanvasAlertsBboxVisibility';

View File

@@ -1,20 +1,15 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
@@ -34,10 +29,7 @@ export const RefImageImage = memo(
dndTargetData,
}: Props<T>) => {
const { t } = useTranslation();
const store = useAppStore();
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
@@ -56,20 +48,6 @@ export const RefImageImage = memo(
[onChangeImage]
);
const recallSizeAndOptimize = useCallback(() => {
if (!imageDTO || (tab === 'canvas' && isStaging)) {
return;
}
const { width, height } = imageDTO;
if (tab === 'canvas') {
store.dispatch(bboxSizeRecalled({ width, height }));
store.dispatch(bboxSizeOptimized());
} else if (tab === 'generate') {
store.dispatch(sizeRecalled({ width, height }));
store.dispatch(sizeOptimized());
}
}, [imageDTO, isStaging, store, tab]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
@@ -91,14 +69,6 @@ export const RefImageImage = memo(
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={recallSizeAndOptimize}
icon={<PiRulerBold size={16} />}
tooltip={t('parameters.useSize')}
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />

View File

@@ -63,7 +63,6 @@ RefImageList.displayName = 'RefImageList';
const dndTargetData = addGlobalReferenceImageDndTarget.getData();
const MaxRefImages = memo(() => {
const { t } = useTranslation();
return (
<Button
position="relative"
@@ -76,7 +75,7 @@ const MaxRefImages = memo(() => {
borderRadius="base"
isDisabled
>
{t('controlLayers.maxRefImages')}
Max Ref Images
</Button>
);
});
@@ -84,7 +83,6 @@ MaxRefImages.displayName = 'MaxRefImages';
const AddRefImageDropTargetAndButton = memo(() => {
const { dispatch, getState } = useAppStore();
const { t } = useTranslation();
const tab = useAppSelector(selectActiveTab);
const uploadOptions = useMemo(
@@ -116,7 +114,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
leftIcon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
>
{t('controlLayers.referenceImage')}
Reference Image
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
</Button>

View File

@@ -10,19 +10,13 @@ import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
Gemini2_5ModelConfig,
IPAdapterModelConfig,
} from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (
modelConfig:
| IPAdapterModelConfig
| FLUXReduxModelConfig
| ChatGPT4oModelConfig
| FLUXKontextModelConfig
| Gemini2_5ModelConfig
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig
) => void;
};
@@ -34,13 +28,7 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const _onChangeModel = useCallback(
(
modelConfig:
| IPAdapterModelConfig
| FLUXReduxModelConfig
| ChatGPT4oModelConfig
| FLUXKontextModelConfig
| Gemini2_5ModelConfig
| null
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null
) => {
if (!modelConfig) {
return;
@@ -51,14 +39,7 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
);
const getIsDisabled = useCallback(
(
model:
| IPAdapterModelConfig
| FLUXReduxModelConfig
| ChatGPT4oModelConfig
| FLUXKontextModelConfig
| Gemini2_5ModelConfig
): boolean => {
(model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => {
return !areBasesCompatibleForRefImage(mainModelConfig, model);
},
[mainModelConfig]

View File

@@ -12,7 +12,7 @@ import {
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
import { getImageDTOSafe } from 'services/api/endpoints/images';
@@ -71,8 +71,8 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
},
onAccept: (item, imageDTO) => {
const bboxRect = selectBboxRect(store.getState());
const { x, y } = bboxRect;
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y, width, height } = bboxRect;
const imageObject = imageNameToImageObject(imageDTO.image_name, { width, height });
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
const overrides: Partial<CanvasRasterLayerState> = {
position: { x, y },

View File

@@ -15,7 +15,6 @@ import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/u
import { useCanvasFilterHotkey } from 'features/controlLayers/hooks/useCanvasFilterHotkey';
import { useCanvasInvertMaskHotkey } from 'features/controlLayers/hooks/useCanvasInvertMaskHotkey';
import { useCanvasResetLayerHotkey } from 'features/controlLayers/hooks/useCanvasResetLayerHotkey';
import { useCanvasToggleBboxHotkey } from 'features/controlLayers/hooks/useCanvasToggleBboxHotkey';
import { useCanvasToggleNonRasterLayersHotkey } from 'features/controlLayers/hooks/useCanvasToggleNonRasterLayersHotkey';
import { useCanvasTransformHotkey } from 'features/controlLayers/hooks/useCanvasTransformHotkey';
import { useCanvasUndoRedoHotkeys } from 'features/controlLayers/hooks/useCanvasUndoRedoHotkeys';
@@ -32,7 +31,6 @@ export const CanvasToolbar = memo(() => {
useCanvasFilterHotkey();
useCanvasInvertMaskHotkey();
useCanvasToggleNonRasterLayersHotkey();
useCanvasToggleBboxHotkey();
return (
<Flex w="full" gap={2} alignItems="center" px={2}>

View File

@@ -1,8 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiResizeBold } from 'react-icons/pi';
@@ -11,23 +9,9 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const isBusy = useCanvasIsBusy();
const isCanvasFocused = useIsRegionFocused('canvas');
const onClick = useCallback(() => {
canvasManager.tool.tools.bbox.fitToLayers();
canvasManager.stage.fitLayersToStage();
}, [canvasManager.tool.tools.bbox, canvasManager.stage]);
useRegisteredHotkeys({
id: 'fitBboxToLayers',
category: 'canvas',
callback: () => {
canvasManager.tool.tools.bbox.fitToLayers();
canvasManager.stage.fitLayersToStage();
},
options: { enabled: isCanvasFocused && !isBusy, preventDefault: true },
dependencies: [isCanvasFocused, isBusy],
});
}, [canvasManager.tool.tools.bbox]);
return (
<IconButton

View File

@@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { AppGetState } from 'app/store/store';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
@@ -16,11 +16,7 @@ import {
rgRefImageAdded,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase, selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
selectCanvasSlice,
selectEntity,
selectSelectedEntityIdentifier,
} from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasRegionalGuidanceState,
@@ -28,7 +24,6 @@ import type {
ControlLoRAConfig,
ControlNetConfig,
FluxKontextReferenceImageConfig,
Gemini2_5ReferenceImageConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
@@ -36,7 +31,6 @@ import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialFluxKontextReferenceImage,
initialGemini2_5ReferenceImage,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
@@ -78,11 +72,7 @@ export const selectDefaultControlAdapter = createSelector(
export const getDefaultRefImageConfig = (
getState: AppGetState
):
| IPAdapterConfig
| ChatGPT4oReferenceImageConfig
| FluxKontextReferenceImageConfig
| Gemini2_5ReferenceImageConfig => {
): IPAdapterConfig | ChatGPT4oReferenceImageConfig | FluxKontextReferenceImageConfig => {
const state = getState();
const mainModelConfig = selectMainModelConfig(state);
@@ -103,12 +93,6 @@ export const getDefaultRefImageConfig = (
return config;
}
if (base === 'gemini-2.5') {
const config = deepClone(initialGemini2_5ReferenceImage);
config.model = zModelIdentifierField.parse(mainModelConfig);
return config;
}
// Otherwise, find the first compatible IP Adapter model.
const modelConfig = ipAdapterModelConfigs.find((m) => m.base === base);
@@ -152,49 +136,37 @@ export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState):
export const useAddControlLayer = () => {
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const selectedControlLayer =
selectedEntityIdentifier?.type === 'control_layer' ? selectedEntityIdentifier.id : undefined;
const func = useCallback(() => {
const overrides = { controlAdapter: deepClone(initialControlNet) };
dispatch(controlLayerAdded({ isSelected: true, overrides, addAfter: selectedControlLayer }));
}, [dispatch, selectedControlLayer]);
dispatch(controlLayerAdded({ isSelected: true, overrides }));
}, [dispatch]);
return func;
};
export const useAddRasterLayer = () => {
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const selectedRasterLayer =
selectedEntityIdentifier?.type === 'raster_layer' ? selectedEntityIdentifier.id : undefined;
const func = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true, addAfter: selectedRasterLayer }));
}, [dispatch, selectedRasterLayer]);
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
return func;
};
export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const selectedInpaintMask =
selectedEntityIdentifier?.type === 'inpaint_mask' ? selectedEntityIdentifier.id : undefined;
const func = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true, addAfter: selectedInpaintMask }));
}, [dispatch, selectedInpaintMask]);
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
return func;
};
export const useAddRegionalGuidance = () => {
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const selectedRegionalGuidance =
selectedEntityIdentifier?.type === 'regional_guidance' ? selectedEntityIdentifier.id : undefined;
const func = useCallback(() => {
dispatch(rgAdded({ isSelected: true, addAfter: selectedRegionalGuidance }));
}, [dispatch, selectedRegionalGuidance]);
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
return func;
};
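The `addAfter` variant of these hooks passes the currently selected entity's id through to the reducer, which then splices the new layer in directly after that entity instead of appending it to the end of the list. A minimal standalone sketch of that insert-after behavior (the generic `insertAfter` helper is illustrative, not part of the codebase):

```ts
// Mirrors the `findIndex(...) + 1` splice used by the reducers: when `addAfter`
// is given, the new entity lands right after it; otherwise it is appended.
const insertAfter = <T extends { id: string }>(entities: T[], entity: T, addAfter?: string): T[] => {
  const index = addAfter ? entities.findIndex((e) => e.id === addAfter) + 1 : entities.length;
  const next = entities.slice();
  next.splice(index, 0, entity);
  return next;
};

// Example: inserting after the selected raster layer.
const layers = [{ id: 'raster_layer_a' }, { id: 'raster_layer_b' }];
const reordered = insertAfter(layers, { id: 'raster_layer_c' }, 'raster_layer_a');
// reordered: raster_layer_a, raster_layer_c, raster_layer_b
```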

View File

@@ -1,18 +0,0 @@
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useCallback } from 'react';
export const useCanvasToggleBboxHotkey = () => {
const canvasManager = useCanvasManager();
const handleToggleBboxVisibility = useCallback(() => {
canvasManager.tool.tools.bbox.toggleBboxVisibility();
}, [canvasManager]);
useRegisteredHotkeys({
id: 'toggleBbox',
category: 'canvas',
callback: handleToggleBboxVisibility,
dependencies: [handleToggleBboxVisibility],
});
};

View File

@@ -3,7 +3,6 @@ import {
selectIsChatGPT4o,
selectIsCogView4,
selectIsFluxKontext,
selectIsGemini2_5,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
@@ -20,22 +19,21 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isImagen4 = useAppSelector(selectIsImagen4);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
const isChatGPT4o = useAppSelector(selectIsChatGPT4o);
const isGemini2_5 = useAppSelector(selectIsGemini2_5);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o && !isGemini2_5;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o && !isGemini2_5;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o && !isGemini2_5;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o && !isGemini2_5;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o, isGemini2_5]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -372,7 +372,6 @@ export class CanvasCompositorModule extends CanvasModuleBase {
position: { x: Math.floor(rect.x), y: Math.floor(rect.y) },
},
mergedEntitiesToDelete: deleteMergedEntities ? entityIdentifiers.map(mapId) : [],
addAfter: entityIdentifiers.map(mapId).at(-1),
};
switch (type) {

View File

@@ -214,9 +214,6 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
const isVisible = this.parent.konva.layer.visible();
const isCached = this.konva.objectGroup.isCached();
// We should also never cache if the entity has no dimensions. Konva will log an error to console like this:
// Konva error: Can not cache the node. Width or height of the node equals 0. Caching is skipped.
if (isVisible && (force || !isCached)) {
this.log.trace('Caching object group');
this.konva.objectGroup.clearCache();

View File

@@ -482,24 +482,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
// "contain" means that the entity should be scaled to fit within the bbox, but it should not exceed the bbox.
const scale = Math.min(scaleX, scaleY);
// Calculate the scaled dimensions
const scaledWidth = width * scale;
const scaledHeight = height * scale;
// Calculate centered position
const centerX = rect.x + (rect.width - scaledWidth) / 2;
const centerY = rect.y + (rect.height - scaledHeight) / 2;
// Round to grid and clamp to valid bounds
const roundedX = gridSize > 1 ? roundToMultiple(centerX, gridSize) : centerX;
const roundedY = gridSize > 1 ? roundToMultiple(centerY, gridSize) : centerY;
const x = clamp(roundedX, rect.x, rect.x + rect.width - scaledWidth);
const y = clamp(roundedY, rect.y, rect.y + rect.height - scaledHeight);
// Center the shape within the bounding box
const offsetX = (rect.width - width * scale) / 2;
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x,
y,
x: clamp(roundToMultiple(rect.x + offsetX, gridSize), rect.x, rect.x + rect.width),
y: clamp(roundToMultiple(rect.y + offsetY, gridSize), rect.y, rect.y + rect.height),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -524,32 +513,16 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const scaleX = rect.width / width;
const scaleY = rect.height / height;
// "cover" means the entity should cover the entire bbox, potentially overflowing
// "cover" is the same as "contain", but we choose the larger scale to cover the shape
const scale = Math.max(scaleX, scaleY);
// Calculate the scaled dimensions
const scaledWidth = width * scale;
const scaledHeight = height * scale;
// Calculate position - center only if entity exceeds bbox
let x = rect.x;
let y = rect.y;
// If scaled width exceeds bbox width, center horizontally
if (scaledWidth > rect.width) {
const centerX = rect.x + (rect.width - scaledWidth) / 2;
x = gridSize > 1 ? roundToMultiple(centerX, gridSize) : centerX;
}
// If scaled height exceeds bbox height, center vertically
if (scaledHeight > rect.height) {
const centerY = rect.y + (rect.height - scaledHeight) / 2;
y = gridSize > 1 ? roundToMultiple(centerY, gridSize) : centerY;
}
// Center the shape within the bounding box
const offsetX = (rect.width - width * scale) / 2;
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x,
y,
x: roundToMultiple(rect.x + offsetX, gridSize),
y: roundToMultiple(rect.y + offsetY, gridSize),
scaleX: scale,
scaleY: scale,
rotation: 0,
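Both fit modes share the same structure: derive a uniform scale from the two axis ratios ("contain" takes the smaller, "cover" the larger), center the scaled entity, then round to the snap grid and clamp where the entity must stay inside the rect. A standalone sketch of the centered-and-clamped variant shown above, with simple local stand-ins for `roundToMultiple` and `clamp`:

```ts
type Rect = { x: number; y: number; width: number; height: number };

const roundToMultiple = (v: number, multiple: number) => Math.round(v / multiple) * multiple;
const clamp = (v: number, min: number, max: number) => Math.min(Math.max(v, min), max);

// "contain": scale the entity to fit entirely inside `rect`, centered, snapped to the
// grid, and clamped so the scaled entity cannot leave the rect.
const fitContain = (size: { width: number; height: number }, rect: Rect, gridSize = 1) => {
  const scale = Math.min(rect.width / size.width, rect.height / size.height);
  const scaledWidth = size.width * scale;
  const scaledHeight = size.height * scale;
  const centerX = rect.x + (rect.width - scaledWidth) / 2;
  const centerY = rect.y + (rect.height - scaledHeight) / 2;
  const roundedX = gridSize > 1 ? roundToMultiple(centerX, gridSize) : centerX;
  const roundedY = gridSize > 1 ? roundToMultiple(centerY, gridSize) : centerY;
  return {
    x: clamp(roundedX, rect.x, rect.x + rect.width - scaledWidth),
    y: clamp(roundedY, rect.y, rect.y + rect.height - scaledHeight),
    scale,
  };
};

// "cover": the larger axis scale is used so the entity fills the rect, centering
// only along the axis that overflows.
const fitCover = (size: { width: number; height: number }, rect: Rect, gridSize = 1) => {
  const scale = Math.max(rect.width / size.width, rect.height / size.height);
  const scaledWidth = size.width * scale;
  const scaledHeight = size.height * scale;
  let x = rect.x;
  let y = rect.y;
  if (scaledWidth > rect.width) {
    const centerX = rect.x + (rect.width - scaledWidth) / 2;
    x = gridSize > 1 ? roundToMultiple(centerX, gridSize) : centerX;
  }
  if (scaledHeight > rect.height) {
    const centerY = rect.y + (rect.height - scaledHeight) / 2;
    y = gridSize > 1 ? roundToMultiple(centerY, gridSize) : centerY;
  }
  return { x, y, scale };
};
```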

View File

@@ -115,7 +115,7 @@ export abstract class CanvasModuleBase {
* ```
*/
destroy: () => void = () => {
this.log.debug('Destroying module');
this.log('Destroying module');
};
/**

View File

@@ -2,7 +2,6 @@ import { objectEquals } from '@observ33r/object-equals';
import { Mutex } from 'async-mutex';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
@@ -11,21 +10,12 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
import type { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
import type { CanvasStagingAreaModule } from 'features/controlLayers/konva/CanvasStagingAreaModule';
import { getKonvaNodeDebugAttrs, loadImage } from 'features/controlLayers/konva/util';
import type { CanvasImageState, Dimensions } from 'features/controlLayers/store/types';
import type { CanvasImageState } from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
import type { Logger } from 'roarr';
import type { JsonObject } from 'roarr/dist/types';
import { getImageDTOSafe } from 'services/api/endpoints/images';
type CanvasObjectImageConfig = {
usePhysicalDimensions: boolean;
};
const DEFAULT_CONFIG: CanvasObjectImageConfig = {
usePhysicalDimensions: false,
};
export class CanvasObjectImage extends CanvasModuleBase {
readonly type = 'object_image';
readonly id: string;
@@ -40,9 +30,6 @@ export class CanvasObjectImage extends CanvasModuleBase {
readonly log: Logger;
state: CanvasImageState;
config: CanvasObjectImageConfig;
konva: {
group: Konva.Group;
placeholder: { group: Konva.Group; rect: Konva.Rect; text: Konva.Text };
@@ -60,8 +47,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
| CanvasEntityBufferObjectRenderer
| CanvasStagingAreaModule
| CanvasSegmentAnythingModule
| CanvasEntityFilterer,
config = DEFAULT_CONFIG
| CanvasEntityFilterer
) {
super();
this.id = state.id;
@@ -69,7 +55,6 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.manager = parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.config = config;
this.log.debug({ state }, 'Creating module');
@@ -131,10 +116,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(
t('controlLayers.unableToLoadImage', 'Unable to load image'),
parseify(imageElementResult.error)
);
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));
return;
}
@@ -157,10 +139,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
const imageElementResult = await withResultAsync(() => loadImage(dataURL, false));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(
t('controlLayers.unableToLoadImage', 'Unable to load image'),
parseify(imageElementResult.error)
);
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));
return;
}
@@ -169,8 +148,8 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.updateImageElement();
};
onFailedToLoadImage = (message: string, error?: JsonObject) => {
this.log.error({ image: this.state.image, error }, message);
onFailedToLoadImage = (message: string) => {
this.log({ image: this.state.image }, message);
this.konva.image?.visible(false);
this.isLoading = false;
this.isError = true;
@@ -178,22 +157,9 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.konva.placeholder.group.visible(true);
};
getDimensions = (): Dimensions => {
if (this.config.usePhysicalDimensions && this.imageElement) {
return {
width: this.imageElement.width,
height: this.imageElement.height,
};
}
return {
width: this.state.image.width,
height: this.state.image.height,
};
};
updateImageElement = () => {
if (this.imageElement) {
const { width, height } = this.getDimensions();
const { width, height } = this.state.image;
if (this.konva.image) {
this.log.trace('Updating Konva image attrs');
@@ -230,6 +196,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
this.log.trace({ state }, 'Updating image');
const { image } = state;
const { width, height } = image;
if (force || (!objectEquals(this.state, state) && !this.isLoading)) {
const release = await this.mutex.acquire();
@@ -245,7 +212,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
}
}
this.konva.image?.setAttrs(this.getDimensions());
this.konva.image?.setAttrs({ width, height });
this.state = state;
return true;
}

View File

@@ -230,16 +230,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
if (imageSrc) {
const image = this._getImageFromSrc(imageSrc, width, height);
if (!this.image) {
this.image = new CanvasObjectImage({ id: 'staging-area-image', type: 'image', image }, this, {
// Some models do not make guarantees about their output dimensions. This flag allows the staged images to
// render at their real dimensions, instead of the bbox size.
//
// When the image source is an image name, it is a final output image. In that case, we should use its
// physical dimensions. Otherwise, if it is a dataURL, that means it is a progress image. These come in at
// a smaller resolution and need to be stretched to fill the bbox, so we do not use the physical
// dimensions in that case.
usePhysicalDimensions: imageSrc.type === 'imageName',
});
this.image = new CanvasObjectImage({ id: 'staging-area-image', type: 'image', image }, this);
await this.image.update(this.image.state, true);
this.konva.group.add(this.image.konva.group);
} else if (this.image.isLoading || this.image.isError) {
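The flag ties the dimension source to the image source: staged final images (referenced by name) render at the dimensions of the decoded image element, while progress images (data URLs) keep the state dimensions so they stretch to fill the bbox. A condensed sketch of that selection, using simplified stand-ins for the module's types:

```ts
type Dimensions = { width: number; height: number };
type ImageSrc = { type: 'imageName'; name: string } | { type: 'dataURL'; dataURL: string };

// Progress images arrive at a reduced resolution and must be stretched to the bbox,
// so only name-referenced (final) images use the physical element dimensions.
const shouldUsePhysicalDimensions = (src: ImageSrc) => src.type === 'imageName';

const getRenderDimensions = (
  usePhysicalDimensions: boolean,
  imageElement: HTMLImageElement | null,
  stateDimensions: Dimensions
): Dimensions => {
  if (usePhysicalDimensions && imageElement) {
    return { width: imageElement.width, height: imageElement.height };
  }
  return stateDimensions;
};
```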

View File

@@ -231,7 +231,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
/**
* Sets the drawing color, pushing state to redux.
*/
setColor = (color: Partial<RgbaColor>) => {
setColor = (color: RgbaColor) => {
return this.store.dispatch(settingsColorChanged(color));
};
@@ -319,14 +319,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
getPositionGridSize = (): number => {
const snapToGrid = this.getSettings().snapToGrid;
if (!snapToGrid) {
const overrideSnap = this.$ctrlKey.get() || this.$metaKey.get();
if (overrideSnap) {
const useFine = this.$shiftKey.get();
if (useFine) {
return 8;
}
return 64;
}
return 1;
}
const useFine = this.$ctrlKey.get() || this.$metaKey.get();

View File

@@ -30,6 +30,7 @@ const ALL_ANCHORS: string[] = [
'bottom-center',
'bottom-right',
];
const CORNER_ANCHORS: string[] = ['top-left', 'top-right', 'bottom-left', 'bottom-right'];
const NO_ANCHORS: string[] = [];
/**
@@ -65,11 +66,6 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
*/
$aspectRatioBuffer = atom(1);
/**
* Buffer to store the visibility of the bbox.
*/
$isBboxHidden = atom(false);
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
@@ -195,9 +191,6 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
// Update on busy state changes
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
// Listen for stage changes to update the bbox's visibility
this.subscriptions.add(this.$isBboxHidden.listen(this.render));
}
// This is a noop. The cursor is changed when the cursor enters or leaves the bbox.
@@ -213,15 +206,13 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
};
/**
* Renders the bbox.
* Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
*/
render = () => {
const tool = this.manager.tool.$tool.get();
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
this.konva.group.visible(!this.$isBboxHidden.get());
// We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
// If the manager is busy, we disable listening so the bbox cannot be interacted with.
this.konva.group.listening(tool === 'bbox' && !this.manager.$isBusy.get());
@@ -343,23 +334,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
let width = roundToMultipleMin(this.konva.proxyRect.width() * this.konva.proxyRect.scaleX(), gridSize);
let height = roundToMultipleMin(this.konva.proxyRect.height() * this.konva.proxyRect.scaleY(), gridSize);
// When resizing the bbox using the transformer, we may need to do some extra math to maintain the current aspect
// ratio. Need to check a few things to determine if we should be maintaining the aspect ratio or not.
let shouldMaintainAspectRatio = false;
if (alt) {
// If alt is held, we are doing center-anchored transforming. In this case, maintaining aspect ratio is rather
// complicated.
shouldMaintainAspectRatio = false;
} else if (this.manager.stateApi.getBbox().aspectRatio.isLocked) {
// When the aspect ratio is locked, holding shift means we SHOULD NOT maintain the aspect ratio
shouldMaintainAspectRatio = !shift;
} else {
// When the aspect ratio is not locked, holding shift means we SHOULD maintain aspect ratio
shouldMaintainAspectRatio = shift;
}
if (shouldMaintainAspectRatio) {
// If shift is held and we are resizing from a corner, retain aspect ratio - needs special handling. We skip this
// if alt/opt is held - this requires math too big for my brain.
if (shift && CORNER_ANCHORS.includes(anchor) && !alt) {
// Fit the bbox to the last aspect ratio
let fittedWidth = Math.sqrt(width * height * this.$aspectRatioBuffer.get());
let fittedHeight = fittedWidth / this.$aspectRatioBuffer.get();
@@ -400,7 +377,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
// Update the aspect ratio buffer whenever the shift key is not held - this allows for a nice UX where you can start
// a transform, get the right aspect ratio, then hold shift to lock it in.
if (!shouldMaintainAspectRatio) {
if (!shift) {
this.$aspectRatioBuffer.set(bboxRect.width / bboxRect.height);
}
};
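The corner-resize branch above fits the in-progress bbox to the buffered ratio without changing its area: with ratio r, `fittedWidth = sqrt(width * height * r)` and `fittedHeight = fittedWidth / r` keep the width-height product constant while forcing the ratio to exactly r. A small sketch of that calculation (the real handler also rounds the result to grid multiples and re-anchors the dragged corner):

```ts
// Fit a width/height pair to a target aspect ratio while preserving its area.
const fitToAspectRatio = (width: number, height: number, ratio: number) => {
  const fittedWidth = Math.sqrt(width * height * ratio);
  const fittedHeight = fittedWidth / ratio;
  return { width: fittedWidth, height: fittedHeight };
};

// Example: a 300x200 drag with a buffered ratio of 1 (square) becomes ~245x245;
// the area (60000 px^2) is unchanged and the ratio is exactly 1.
const fitted = fitToAspectRatio(300, 200, 1);
```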
@@ -501,8 +478,4 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
this.subscriptions.clear();
this.konva.group.destroy();
};
toggleBboxVisibility = () => {
this.$isBboxHidden.set(!this.$isBboxHidden.get());
};
}

View File

@@ -289,14 +289,6 @@ export class CanvasColorPickerToolModule extends CanvasModuleBase {
this.manager.stage.setCursor('none');
};
getCanPick = () => {
if (this.manager.stage.getIsDragging()) {
return false;
}
return true;
};
/**
* Renders the color picker tool preview on the canvas.
*/
@@ -306,11 +298,6 @@ export class CanvasColorPickerToolModule extends CanvasModuleBase {
return;
}
if (!this.getCanPick()) {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
@@ -419,21 +406,11 @@ export class CanvasColorPickerToolModule extends CanvasModuleBase {
};
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
if (!this.getCanPick()) {
this.setVisibility(false);
return;
}
const { a: _, ...color } = this.$colorUnderCursor.get();
this.manager.stateApi.setColor(color);
const color = this.$colorUnderCursor.get();
this.manager.stateApi.setColor({ ...color, a: color.a / 255 });
};
onStagePointerMove = (_e: KonvaEventObject<PointerEvent>) => {
if (!this.getCanPick()) {
this.setVisibility(false);
return;
}
this.syncColorUnderCursor();
};
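The two pointer-up variants differ in how the sampled color reaches the store: one strips the alpha channel and dispatches a partial color, the other normalizes the 0-255 alpha read from the canvas pixel buffer into the 0-1 range the settings color uses. A minimal sketch of both options, with a simplified stand-in for `RgbaColor`:

```ts
type RgbaColor = { r: number; g: number; b: number; a: number };

// Canvas pixel data reports every channel, including alpha, as 0-255; the settings
// color stores alpha as 0-1, so divide before dispatching.
const normalizeSampledColor = (sampled: RgbaColor): RgbaColor => ({
  ...sampled,
  a: sampled.a / 255,
});

// Alternative shown above: leave the stored alpha untouched by stripping it and
// dispatching only the RGB channels as a partial update.
const withoutAlpha = ({ a: _a, ...rgb }: RgbaColor) => rgb;
```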

View File

@@ -164,7 +164,7 @@ export class CanvasToolModule extends CanvasModuleBase {
const selectedEntityAdapter = this.manager.stateApi.getSelectedEntityAdapter();
if (this.manager.stage.getIsDragging()) {
stage.setCursor('grabbing');
this.tools.view.syncCursorStyle();
} else if (tool === 'view') {
this.tools.view.syncCursorStyle();
} else if (segmentingAdapter) {

View File

@@ -134,8 +134,8 @@ const slice = createSlice({
settingsEraserWidthChanged: (state, action: PayloadAction<CanvasSettingsState['eraserWidth']>) => {
state.eraserWidth = Math.round(action.payload);
},
settingsColorChanged: (state, action: PayloadAction<Partial<CanvasSettingsState['color']>>) => {
state.color = { ...state.color, ...action.payload };
settingsColorChanged: (state, action: PayloadAction<CanvasSettingsState['color']>) => {
state.color = action.payload;
},
settingsInvertScrollForToolWidthChanged: (
state,

View File

@@ -72,14 +72,12 @@ import {
CHATGPT_ASPECT_RATIOS,
DEFAULT_ASPECT_RATIO_CONFIG,
FLUX_KONTEXT_ASPECT_RATIOS,
GEMINI_2_5_ASPECT_RATIOS,
getEntityIdentifier,
getInitialCanvasState,
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isFLUXReduxConfig,
isGemini2_5AspectRatioID,
isImagenAspectRatioID,
isIPAdapterConfig,
zCanvasState,
@@ -113,16 +111,12 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}>
) => {
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [], addAfter } = action.payload;
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [] } = action.payload;
const entityState = getRasterLayerState(id, overrides);
const index = addAfter
? state.rasterLayers.entities.findIndex((e) => e.id === addAfter) + 1
: state.rasterLayers.entities.length;
state.rasterLayers.entities.splice(index, 0, entityState);
state.rasterLayers.entities.push(entityState);
if (mergedEntitiesToDelete.length > 0) {
state.rasterLayers.entities = state.rasterLayers.entities.filter(
@@ -145,7 +139,6 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}) => ({
payload: { ...payload, id: getPrefixedId('raster_layer') },
}),
@@ -279,17 +272,13 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}>
) => {
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [], addAfter } = action.payload;
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [] } = action.payload;
const entityState = getControlLayerState(id, overrides);
const index = addAfter
? state.controlLayers.entities.findIndex((e) => e.id === addAfter) + 1
: state.controlLayers.entities.length;
state.controlLayers.entities.splice(index, 0, entityState);
state.controlLayers.entities.push(entityState);
if (mergedEntitiesToDelete.length > 0) {
state.controlLayers.entities = state.controlLayers.entities.filter(
@@ -311,7 +300,6 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}) => ({
payload: { ...payload, id: getPrefixedId('control_layer') },
}),
@@ -582,17 +570,13 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}>
) => {
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [], addAfter } = action.payload;
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [] } = action.payload;
const entityState = getRegionalGuidanceState(id, overrides);
const index = addAfter
? state.regionalGuidance.entities.findIndex((e) => e.id === addAfter) + 1
: state.regionalGuidance.entities.length;
state.regionalGuidance.entities.splice(index, 0, entityState);
state.regionalGuidance.entities.push(entityState);
if (mergedEntitiesToDelete.length > 0) {
state.regionalGuidance.entities = state.regionalGuidance.entities.filter(
@@ -614,7 +598,6 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}) => ({
payload: { ...payload, id: getPrefixedId('regional_guidance') },
}),
@@ -891,17 +874,13 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}>
) => {
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [], addAfter } = action.payload;
const { id, overrides, isSelected, isBookmarked, mergedEntitiesToDelete = [] } = action.payload;
const entityState = getInpaintMaskState(id, overrides);
const index = addAfter
? state.inpaintMasks.entities.findIndex((e) => e.id === addAfter) + 1
: state.inpaintMasks.entities.length;
state.inpaintMasks.entities.splice(index, 0, entityState);
state.inpaintMasks.entities.push(entityState);
if (mergedEntitiesToDelete.length > 0) {
state.inpaintMasks.entities = state.inpaintMasks.entities.filter(
@@ -923,7 +902,6 @@ const slice = createSlice({
isSelected?: boolean;
isBookmarked?: boolean;
mergedEntitiesToDelete?: string[];
addAfter?: string;
}) => ({
payload: { ...payload, id: getPrefixedId('inpaint_mask') },
}),
@@ -1113,15 +1091,6 @@ const slice = createSlice({
syncScaledSize(state);
},
bboxSizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.bbox.modelBase);
state.bbox.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.bbox.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.id = 'Free';
state.bbox.aspectRatio.isLocked = true;
},
bboxAspectRatioLockToggled: (state) => {
state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked;
syncScaledSize(state);
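The size-recall reducer snaps the recalled dimensions down to the model's grid, enforces a 64px floor, and then recomputes the aspect ratio from the snapped values. A worked standalone sketch (the grid size is a plain parameter here; in the slice it comes from `getGridSize`, and the 64 in the example is only an illustrative grid value):

```ts
const roundDownToMultiple = (v: number, multiple: number) => Math.floor(v / multiple) * multiple;

// Snap recalled dimensions down to the grid, never below 64, then recompute the ratio.
const recallSize = (width: number, height: number, gridSize: number) => {
  const w = Math.max(roundDownToMultiple(width, gridSize), 64);
  const h = Math.max(roundDownToMultiple(height, gridSize), 64);
  return { width: w, height: h, aspectRatio: w / h };
};

// Example: recalling 1000x700 with a grid of 64 yields 960x640 (ratio 1.5).
const recalled = recallSize(1000, 700, 64);
```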
@@ -1146,12 +1115,6 @@ const slice = createSlice({
state.bbox.rect.height = height;
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'gemini-2.5' && isGemini2_5AspectRatioID(id)) {
const { width, height } = GEMINI_2_5_ASPECT_RATIOS[id];
state.bbox.rect.width = width;
state.bbox.rect.height = height;
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
const { width, height } = FLUX_KONTEXT_ASPECT_RATIOS[id];
state.bbox.rect.width = width;
@@ -1277,33 +1240,25 @@ const slice = createSlice({
newEntity.name = `${newEntity.name} (Copy)`;
}
switch (newEntity.type) {
case 'raster_layer': {
case 'raster_layer':
newEntity.id = getPrefixedId('raster_layer');
const newEntityIndex = state.rasterLayers.entities.findIndex((e) => e.id === entityIdentifier.id) + 1;
state.rasterLayers.entities.splice(newEntityIndex, 0, newEntity);
state.rasterLayers.entities.push(newEntity);
break;
}
case 'control_layer': {
case 'control_layer':
newEntity.id = getPrefixedId('control_layer');
const newEntityIndex = state.controlLayers.entities.findIndex((e) => e.id === entityIdentifier.id) + 1;
state.controlLayers.entities.splice(newEntityIndex, 0, newEntity);
state.controlLayers.entities.push(newEntity);
break;
}
case 'regional_guidance': {
case 'regional_guidance':
newEntity.id = getPrefixedId('regional_guidance');
for (const refImage of newEntity.referenceImages) {
refImage.id = getPrefixedId('regional_guidance_ip_adapter');
}
const newEntityIndex = state.regionalGuidance.entities.findIndex((e) => e.id === entityIdentifier.id) + 1;
state.regionalGuidance.entities.splice(newEntityIndex, 0, newEntity);
state.regionalGuidance.entities.push(newEntity);
break;
}
case 'inpaint_mask': {
case 'inpaint_mask':
newEntity.id = getPrefixedId('inpaint_mask');
const newEntityIndex = state.inpaintMasks.entities.findIndex((e) => e.id === entityIdentifier.id) + 1;
state.inpaintMasks.entities.splice(newEntityIndex, 0, newEntity);
state.inpaintMasks.entities.push(newEntity);
break;
}
}
state.selectedEntityIdentifier = getEntityIdentifier(newEntity);
@@ -1664,7 +1619,6 @@ export const {
entityArrangedToBack,
entityOpacityChanged,
entitiesReordered,
allEntitiesDeleted,
allEntitiesOfTypeIsHiddenToggled,
allNonRasterLayersIsHiddenToggled,
// bbox
@@ -1672,7 +1626,6 @@ export const {
bboxScaledWidthChanged,
bboxScaledHeightChanged,
bboxScaleMethodChanged,
bboxSizeRecalled,
bboxWidthChanged,
bboxHeightChanged,
bboxAspectRatioLockToggled,

View File

@@ -11,26 +11,15 @@ import {
CHATGPT_ASPECT_RATIOS,
DEFAULT_ASPECT_RATIO_CONFIG,
FLUX_KONTEXT_ASPECT_RATIOS,
GEMINI_2_5_ASPECT_RATIOS,
getInitialParamsState,
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isGemini2_5AspectRatioID,
isImagenAspectRatioID,
zParamsState,
} from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import {
API_BASE_MODELS,
CLIP_SKIP_MAP,
SUPPORTS_ASPECT_RATIO_BASE_MODELS,
SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS,
SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS,
SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS,
SUPPORTS_REF_IMAGES_BASE_MODELS,
SUPPORTS_SEED_BASE_MODELS,
} from 'features/parameters/types/constants';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import type {
ParameterCanvasCoherenceMode,
ParameterCFGRescaleMultiplier,
@@ -118,15 +107,14 @@ const slice = createSlice({
return;
}
if (API_BASE_MODELS.includes(model.base)) {
state.dimensions.aspectRatio.isLocked = true;
state.dimensions.aspectRatio.value = 1;
state.dimensions.aspectRatio.id = '1:1';
state.dimensions.rect.width = 1024;
state.dimensions.rect.height = 1024;
// Clamp CLIP skip layer count to the bounds of the new model
if (model.base === 'sdxl') {
// We don't support user-defined CLIP skip for SDXL because it doesn't do anything useful
state.clipSkip = 0;
} else {
const { maxClip } = CLIP_SKIP_MAP[model.base];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}
applyClipSkip(state, model, state.clipSkip);
},
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
// null is a valid VAE!
@@ -182,7 +170,7 @@ const slice = createSlice({
state.vaePrecision = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
applyClipSkip(state, state.model, action.payload);
state.clipSkip = action.payload;
},
shouldUseCpuNoiseChanged: (state, action: PayloadAction<boolean>) => {
state.shouldUseCpuNoise = action.payload;
@@ -193,6 +181,15 @@ const slice = createSlice({
negativePromptChanged: (state, action: PayloadAction<ParameterNegativePrompt>) => {
state.negativePrompt = action.payload;
},
positivePrompt2Changed: (state, action: PayloadAction<string>) => {
state.positivePrompt2 = action.payload;
},
negativePrompt2Changed: (state, action: PayloadAction<string>) => {
state.negativePrompt2 = action.payload;
},
shouldConcatPromptsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldConcatPrompts = action.payload;
},
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
const result = zParamsState.shape.refinerModel.safeParse(action.payload);
if (!result.success) {
@@ -244,15 +241,6 @@ const slice = createSlice({
},
//#region Dimensions
sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.model?.base);
state.dimensions.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.dimensions.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.id = 'Free';
state.dimensions.aspectRatio.isLocked = true;
},
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
@@ -309,12 +297,6 @@ const slice = createSlice({
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.isLocked = true;
} else if (state.model?.base === 'gemini-2.5' && isGemini2_5AspectRatioID(id)) {
const { width, height } = GEMINI_2_5_ASPECT_RATIOS[id];
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.isLocked = true;
} else if (state.model?.base === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
const { width, height } = FLUX_KONTEXT_ASPECT_RATIOS[id];
state.dimensions.rect.width = width;
@@ -384,46 +366,17 @@ const slice = createSlice({
},
});
const applyClipSkip = (state: { clipSkip: number }, model: ParameterModel | null, clipSkip: number) => {
if (model === null) {
return;
}
const maxClip = getModelMaxClipSkip(model);
state.clipSkip = clamp(clipSkip, 0, maxClip);
};
const hasModelClipSkip = (model: ParameterModel | null) => {
if (model === null) {
return false;
}
return getModelMaxClipSkip(model) > 0;
};
const getModelMaxClipSkip = (model: ParameterModel) => {
if (model.base === 'sdxl') {
// We don't support user-defined CLIP skip for SDXL because it doesn't do anything useful
return 0;
}
return CLIP_SKIP_MAP[model.base].maxClip;
};
const resetState = (state: ParamsState): ParamsState => {
// When a new session is requested, we need to keep the current model selections, plus dependent state
// like VAE precision. Everything else gets reset to default.
const oldState = deepClone(state);
const newState = getInitialParamsState();
newState.dimensions = oldState.dimensions;
newState.model = oldState.model;
newState.vae = oldState.vae;
newState.fluxVAE = oldState.fluxVAE;
newState.vaePrecision = oldState.vaePrecision;
newState.t5EncoderModel = oldState.t5EncoderModel;
newState.clipEmbedModel = oldState.clipEmbedModel;
newState.refinerModel = oldState.refinerModel;
newState.model = state.model;
newState.vae = state.vae;
newState.fluxVAE = state.fluxVAE;
newState.vaePrecision = state.vaePrecision;
newState.t5EncoderModel = state.t5EncoderModel;
newState.clipEmbedModel = state.clipEmbedModel;
newState.refinerModel = state.refinerModel;
return newState;
};
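The helpers above centralize the CLIP-skip rule: SDXL always pins the value to 0, and every other base clamps the requested value into [0, maxClip] for that base. A compact sketch of the same rule (the shape and values of `CLIP_SKIP_MAP` below are assumed for illustration, based only on the `CLIP_SKIP_MAP[model.base].maxClip` access above):

```ts
const clamp = (v: number, min: number, max: number) => Math.min(Math.max(v, min), max);

// Assumed shape with illustrative values; only the `maxClip` field is relied upon.
const CLIP_SKIP_MAP: Record<string, { maxClip: number }> = {
  'sd-1': { maxClip: 12 },
  'sd-2': { maxClip: 24 },
};

const getMaxClipSkip = (base: string): number => {
  // SDXL does not support a user-defined CLIP skip, so its max is 0.
  if (base === 'sdxl') {
    return 0;
  }
  return CLIP_SKIP_MAP[base]?.maxClip ?? 0;
};

// Example: a requested clipSkip of 5 is clamped to 0 on sdxl and kept at 5 on sd-1.
const clipSkipForSdxl = clamp(5, 0, getMaxClipSkip('sdxl')); // 0
const clipSkipForSd1 = clamp(5, 0, getMaxClipSkip('sd-1')); // 5
```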
@@ -461,6 +414,9 @@ export const {
shouldUseCpuNoiseChanged,
positivePromptChanged,
negativePromptChanged,
positivePrompt2Changed,
negativePrompt2Changed,
shouldConcatPromptsChanged,
refinerModelChanged,
setRefinerSteps,
setRefinerCFGScale,
@@ -471,7 +427,6 @@ export const {
modelChanged,
// Dimensions
sizeRecalled,
widthChanged,
heightChanged,
aspectRatioLockToggled,
@@ -493,7 +448,8 @@ export const paramsSliceConfig: SliceConfig<typeof slice> = {
};
export const selectParamsSlice = (state: RootState) => state.params;
const createParamsSelector = <T>(selector: Selector<ParamsState, T>) => createSelector(selectParamsSlice, selector);
export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
createSelector(selectParamsSlice, selector);
export const selectBase = createParamsSelector((params) => params.model?.base);
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
@@ -502,6 +458,7 @@ export const selectIsSD3 = createParamsSelector((params) => params.model?.base =
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsFluxKontextApi = createParamsSelector((params) => params.model?.base === 'flux-kontext');
export const selectIsFluxKontext = createParamsSelector((params) => {
if (params.model?.base === 'flux-kontext') {
return true;
@@ -512,7 +469,6 @@ export const selectIsFluxKontext = createParamsSelector((params) => {
return false;
});
export const selectIsChatGPT4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectIsGemini2_5 = createParamsSelector((params) => params.model?.base === 'gemini-2.5');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);
@@ -529,8 +485,7 @@ export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
export const selectGuidance = createParamsSelector((params) => params.guidance);
export const selectSteps = createParamsSelector((params) => params.steps);
export const selectCFGRescaleMultiplier = createParamsSelector((params) => params.cfgRescaleMultiplier);
export const selectCLIPSkip = createParamsSelector((params) => params.clipSkip);
export const selectHasModelCLIPSkip = createParamsSelector((params) => hasModelClipSkip(params.model));
export const selectCLIPSKip = createParamsSelector((params) => params.clipSkip);
export const selectCanvasCoherenceEdgeSize = createParamsSelector((params) => params.canvasCoherenceEdgeSize);
export const selectCanvasCoherenceMinDenoise = createParamsSelector((params) => params.canvasCoherenceMinDenoise);
export const selectCanvasCoherenceMode = createParamsSelector((params) => params.canvasCoherenceMode);
@@ -548,33 +503,12 @@ export const selectNegativePrompt = createParamsSelector((params) => params.nega
export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? '');
export const selectHasNegativePrompt = createParamsSelector((params) => params.negativePrompt !== null);
export const selectModelSupportsNegativePrompt = createSelector(
selectModel,
(model) => !!model && SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsSeed = createSelector(
selectModel,
(model) => !!model && SUPPORTS_SEED_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsRefImages = createSelector(
selectModel,
(model) => !!model && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsAspectRatio = createSelector(
selectModel,
(model) => !!model && SUPPORTS_ASPECT_RATIO_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsPixelDimensions = createSelector(
selectModel,
(model) => !!model && SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS.includes(model.base)
);
export const selectIsApiBaseModel = createSelector(
selectModel,
(model) => !!model && API_BASE_MODELS.includes(model.base)
);
export const selectModelSupportsOptimizedDenoising = createSelector(
selectModel,
(model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
[selectIsFLUX, selectIsChatGPT4o, selectIsFluxKontext],
(isFLUX, isChatGPT4o, isFluxKontext) => !isFLUX && !isChatGPT4o && !isFluxKontext
);
export const selectPositivePrompt2 = createParamsSelector((params) => params.positivePrompt2);
export const selectNegativePrompt2 = createParamsSelector((params) => params.negativePrompt2);
export const selectShouldConcatPrompts = createParamsSelector((params) => params.shouldConcatPrompts);
export const selectScheduler = createParamsSelector((params) => params.scheduler);
export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);

View File

@@ -26,7 +26,6 @@ import {
initialChatGPT4oReferenceImage,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialGemini2_5ReferenceImage,
initialIPAdapter,
} from './util';
@@ -137,16 +136,6 @@ const slice = createSlice({
return;
}
if (entity.config.model.base === 'gemini-2.5') {
// Switching to Gemini 2.5 Flash Preview (nano banana) ref image
entity.config = {
...initialGemini2_5ReferenceImage,
image: entity.config.image,
model: entity.config.model,
};
return;
}
if (
entity.config.model.base === 'flux-kontext' ||
(entity.config.model.base === 'flux' && entity.config.model.name?.toLowerCase().includes('kontext'))

View File

@@ -1,5 +1,6 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import {
zParameterCanvasCoherenceMode,
@@ -14,7 +15,9 @@ import {
zParameterMaskBlurMethod,
zParameterModel,
zParameterNegativePrompt,
zParameterNegativeStylePromptSDXL,
zParameterPositivePrompt,
zParameterPositiveStylePromptSDXL,
zParameterPrecision,
zParameterScheduler,
zParameterSDXLRefinerModel,
@@ -264,13 +267,6 @@ const zChatGPT4oReferenceImageConfig = z.object({
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zGemini2_5ReferenceImageConfig = z.object({
type: z.literal('gemini_2_5_reference_image'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
});
export type Gemini2_5ReferenceImageConfig = z.infer<typeof zGemini2_5ReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
@@ -293,7 +289,6 @@ export const zRefImageState = z.object({
zFLUXReduxConfig,
zChatGPT4oReferenceImageConfig,
zFluxKontextReferenceImageConfig,
zGemini2_5ReferenceImageConfig,
]),
});
export type RefImageState = z.infer<typeof zRefImageState>;
@@ -306,15 +301,10 @@ export const isFLUXReduxConfig = (config: RefImageState['config']): config is FL
export const isChatGPT4oReferenceImageConfig = (
config: RefImageState['config']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
export const isFluxKontextReferenceImageConfig = (
config: RefImageState['config']
): config is FluxKontextReferenceImageConfig => config.type === 'flux_kontext_reference_image';
export const isGemini2_5ReferenceImageConfig = (
config: RefImageState['config']
): config is Gemini2_5ReferenceImageConfig => config.type === 'gemini_2_5_reference_image';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
@@ -424,6 +414,8 @@ export const zLoRA = z.object({
});
export type LoRA = z.infer<typeof zLoRA>;
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
export const zAspectRatioID = z.enum(['Free', '21:9', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16', '9:21']);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
@@ -460,14 +452,6 @@ export const CHATGPT_ASPECT_RATIOS: Record<ChatGPT4oAspectRatio, Dimensions> = {
'2:3': { width: 1024, height: 1536 },
} as const;
export const zGemini2_5AspectRatioID = z.enum(['1:1']);
type Gemini2_5AspectRatio = z.infer<typeof zGemini2_5AspectRatioID>;
export const isGemini2_5AspectRatioID = (v: unknown): v is Gemini2_5AspectRatio =>
zGemini2_5AspectRatioID.safeParse(v).success;
export const GEMINI_2_5_ASPECT_RATIOS: Record<Gemini2_5AspectRatio, Dimensions> = {
'1:1': { width: 1024, height: 1024 },
} as const;
export const zFluxKontextAspectRatioID = z.enum(['21:9', '16:9', '4:3', '1:1', '3:4', '9:16', '9:21']);
type FluxKontextAspectRatio = z.infer<typeof zFluxKontextAspectRatioID>;
export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFluxKontextAspectRatioID> =>
@@ -512,8 +496,6 @@ const zBboxState = z.object({
});
const zDimensionsState = z.object({
// TODO(psyche): There is no concept of x/y coords for the dimensions state here... It's just width and height.
// Remove the extraneous data.
rect: z.object({
x: z.number().int(),
y: z.number().int(),
@@ -555,6 +537,9 @@ export const zParamsState = z.object({
shouldUseCpuNoise: z.boolean(),
positivePrompt: zParameterPositivePrompt,
negativePrompt: zParameterNegativePrompt,
positivePrompt2: zParameterPositiveStylePromptSDXL,
negativePrompt2: zParameterNegativeStylePromptSDXL,
shouldConcatPrompts: z.boolean(),
refinerModel: zParameterSDXLRefinerModel.nullable(),
refinerSteps: z.number(),
refinerCFGScale: z.number(),
@@ -602,6 +587,9 @@ export const getInitialParamsState = (): ParamsState => ({
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: null,
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
@@ -678,12 +666,7 @@ export const getInitialRefImagesState = (): RefImagesState => ({
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: z.discriminatedUnion('type', [
zIPAdapterConfig,
zFLUXReduxConfig,
zChatGPT4oReferenceImageConfig,
zGemini2_5ReferenceImageConfig,
]),
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
});
export const zCanvasMetadata = z.object({

View File

@@ -10,9 +10,9 @@ import type {
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
Dimensions,
FluxKontextReferenceImageConfig,
FLUXReduxConfig,
Gemini2_5ReferenceImageConfig,
ImageWithDims,
IPAdapterConfig,
RefImageState,
@@ -38,6 +38,22 @@ export const imageDTOToImageObject = (imageDTO: ImageDTO, overrides?: Partial<Ca
};
};
export const imageNameToImageObject = (
imageName: string,
dimensions: Dimensions,
overrides?: Partial<CanvasImageState>
): CanvasImageState => {
return {
id: getPrefixedId('image'),
type: 'image',
image: {
image_name: imageName,
...dimensions,
},
...overrides,
};
};
export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO): ImageWithDims => ({
image_name,
width,
@@ -89,11 +105,6 @@ export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
image: null,
model: null,
};
export const initialGemini2_5ReferenceImage: Gemini2_5ReferenceImageConfig = {
type: 'gemini_2_5_reference_image',
image: null,
model: null,
};
export const initialFluxKontextReferenceImage: FluxKontextReferenceImageConfig = {
type: 'flux_kontext_reference_image',
image: null,

View File

@@ -27,7 +27,6 @@ export const DndImageIcon = memo((props: Props) => {
return (
<IconButton
onClick={onClick}
tooltip={tooltip}
aria-label={tooltip}
icon={icon}
variant="link"

View File

@@ -53,7 +53,6 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
color={isSelected ? 'base.100' : 'base.300'}
onDoubleClick={editable.startEditing}
cursor="text"
noOfLines={1}
>
{editable.value}
</Text>

Some files were not shown because too many files have changed in this diff.