Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-15 19:58:13 -05:00

Compare commits: 36 commits, v5.10.0dev ... psyche/fea
| SHA1 |
|---|
| 8a8f4c593f |
| 29c78f0e5e |
| 501534e2e1 |
| 50c7318004 |
| 7f14597012 |
| dbe68b364f |
| 0c7aa85a5c |
| 703e1c8001 |
| b056c93ea3 |
| 4289241943 |
| 51f5abf5f9 |
| e59fa59ad7 |
| 2407cb64b3 |
| 70f704ab44 |
| b786032b89 |
| e8cc06cc92 |
| 8e6c56c93d |
| 69d4ee7f93 |
| 567fd3e0da |
| 0b8f88e554 |
| 60f0c4bf99 |
| 900ec92ef1 |
| 2594768479 |
| 91ab81eca9 |
| b20c745c6e |
| e41a37bca0 |
| 9ca44f27a5 |
| b9ddf67853 |
| afe088045f |
| 09ca61a962 |
| dd69a96c03 |
| 4a54e594d0 |
| 936ed1960a |
| 9fac7986c7 |
| e4b603f44e |
| 7edfe6edcf |
@@ -1,11 +1,9 @@
*
!invokeai
!pyproject.toml
!uv.lock
!docker/docker-entrypoint.sh
!LICENSE

**/dist
**/node_modules
**/__pycache__
**/*.egg-info
**/*.egg-info
8 changes: .github/CODEOWNERS (vendored)
@@ -2,11 +2,11 @@
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku

# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu

# nodes
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku

# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername

@@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername

# generation, model management, postprocessing
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku

# front ends
/invokeai/frontend/CLI @lstein @hipsterusername
2 changes: .github/workflows/build-container.yml (vendored)
@@ -97,8 +97,6 @@ jobs:
          context: .
          file: docker/Dockerfile
          platforms: ${{ env.PLATFORMS }}
          build-args: |
            GPU_DRIVER=${{ matrix.gpu-driver }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
2 changes: .github/workflows/build-installer.yml (vendored)
@@ -17,7 +17,7 @@ jobs:
      - name: setup python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
          python-version: '3.10'
          cache: pip
          cache-dependency-path: pyproject.toml
@@ -1,6 +1,77 @@
# syntax=docker/dockerfile:1.4

#### Web UI ------------------------------------
## Builder stage

FROM library/ubuntu:24.04 AS builder

ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
    apt update && apt-get install -y \
    build-essential \
    git

# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/

ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"

ARG GPU_DRIVER=cuda
# unused but available
ARG BUILDPLATFORM

# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
    mkdir -p ${INVOKEAI_SRC} && \
    chmod -R a+w /opt && \
    mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
USER ubuntu

# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    uv python install ${PYTHON_VERSION}

WORKDIR ${INVOKEAI_SRC}

# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=invokeai/version,target=invokeai/version \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
    fi && \
    uv sync --no-install-project

# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .

RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then \
        UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
    fi && \
    uv sync


#### Build the Web UI ------------------------------------

FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
@@ -14,89 +85,69 @@ RUN --mount=type=cache,target=/pnpm/store \
    pnpm install --frozen-lockfile
RUN npx vite build

## Backend ---------------------------------------
#### Runtime stage ---------------------------------------

FROM library/ubuntu:24.04
FROM library/ubuntu:24.04 AS runtime

ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
    --mount=type=cache,target=/var/lib/apt \
    apt update && apt install -y --no-install-recommends \
    ca-certificates \
    git \
    gosu \
    libglib2.0-0 \
    libgl1 \
    libglx-mesa0 \
    build-essential \
    libopencv-dev \
    libstdc++-10-dev
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

ENV \
    PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    VIRTUAL_ENV=/opt/venv \
    INVOKEAI_SRC=/opt/invokeai \
    PYTHON_VERSION=3.12 \
    UV_PYTHON=3.12 \
    UV_COMPILE_BYTECODE=1 \
    UV_MANAGED_PYTHON=1 \
    UV_LINK_MODE=copy \
    UV_PROJECT_ENVIRONMENT=/opt/venv \
    UV_INDEX="https://download.pytorch.org/whl/cu124" \
    INVOKEAI_ROOT=/invokeai \
    INVOKEAI_HOST=0.0.0.0 \
    INVOKEAI_PORT=9090 \
    PATH="/opt/venv/bin:$PATH" \
    CONTAINER_UID=${CONTAINER_UID:-1000} \
    CONTAINER_GID=${CONTAINER_GID:-1000}
RUN apt update && apt install -y --no-install-recommends \
    git \
    curl \
    vim \
    tmux \
    ncdu \
    iotop \
    bzip2 \
    gosu \
    magic-wormhole \
    libglib2.0-0 \
    libgl1 \
    libglx-mesa0 \
    build-essential \
    libopencv-dev \
    libstdc++-10-dev && \
    apt-get clean && apt-get autoclean

ARG GPU_DRIVER=cuda
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
ENV INVOKEAI_PORT=9090
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}

# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root

# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
    uv python install ${PYTHON_VERSION} && \
    # chmod --recursive a+rX /root/.local/share/uv/python
    chmod 711 /root

WORKDIR ${INVOKEAI_SRC}

# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    # this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
    --mount=type=bind,source=invokeai/version,target=invokeai/version \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
    fi && \
    uv sync --frozen

# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
# --link requires buildkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist

# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" && \
    ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
    ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"

WORKDIR ${INVOKEAI_SRC}

# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"

RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}

COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]

# --link requires buildkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist

# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
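The per-platform index selection in the RUN steps above is worth spelling out. Below is a minimal Python sketch of the same rule, for illustration only: the function name is ours, and the URLs are copied from the Dockerfile (which pins rocm6.1 in the builder stage and rocm6.2 in the runtime stage).

```python
def pytorch_wheel_index(target_platform: str, gpu_driver: str) -> str:
    """Mirror the Dockerfile's UV_INDEX selection logic (illustrative sketch)."""
    # There are no arm64 + CUDA pytorch builds, so arm64 always falls back to CPU wheels.
    if target_platform == "linux/arm64" or gpu_driver == "cpu":
        return "https://download.pytorch.org/whl/cpu"
    if gpu_driver == "rocm":
        return "https://download.pytorch.org/whl/rocm6.1"  # the runtime stage pins rocm6.2
    # x86_64 + CUDA is the default, set via ENV UV_INDEX.
    return "https://download.pytorch.org/whl/cu124"


assert pytorch_wheel_index("linux/amd64", "cuda").endswith("cu124")
assert pytorch_wheel_index("linux/arm64", "cuda").endswith("cpu")
```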
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:

```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```

6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
3. Create a virtual environment in that directory:

    ```sh
    uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
    uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
    ```

    This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
    This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.

4. Activate the virtual environment:

@@ -64,7 +64,7 @@ The following commands vary depending on the version of Invoke being installed a

5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).

6. Determine the package specifier to use when installing. This is a performance optimization.
6. Determine the package package specifier to use when installing. This is a performance optimization.

    - If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
    - If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.

@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
8. Install the `invokeai` package. Substitute the package specifier and version.

    ```sh
    uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
    uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
    ```

    If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:

    ```sh
    uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
    uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
    ```

9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less

You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).

Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.

Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).

@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --

=== "Windows"

    - Install python with [an official installer].
    - Install python 3.11 with [an official installer].
    - The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
    - You may need to install [Microsoft Visual C++ Redistributable].

=== "macOS"

    - Install python with [an official installer].
    - Install python 3.11 with [an official installer].
    - If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
    - If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.

=== "Linux"

    - Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
    - Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
    - You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`

## Drivers
@@ -37,13 +37,7 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningFieldData,
    FLUXConditioningInfo,
    SD3ConditioningInfo,
    SDXLConditioningInfo,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__

@@ -107,25 +101,10 @@ class ApiDependencies:
        images = ImageService()
        invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
        tensors = ObjectSerializerForwardCache(
            ObjectSerializerDisk[torch.Tensor](
                output_folder / "tensors",
                safe_globals=[torch.Tensor],
                ephemeral=True,
            ),
            max_cache_size=0,
            ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
        )
        conditioning = ObjectSerializerForwardCache(
            ObjectSerializerDisk[ConditioningFieldData](
                output_folder / "conditioning",
                safe_globals=[
                    ConditioningFieldData,
                    BasicConditioningInfo,
                    SDXLConditioningInfo,
                    FLUXConditioningInfo,
                    SD3ConditioningInfo,
                ],
                ephemeral=True,
            ),
            ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
        )
        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
        model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
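The `safe_globals` lists being dropped in this diff are allow-lists of classes that `ObjectSerializerDisk` presumably forwards to PyTorch's restricted loader. A minimal sketch of the underlying PyTorch mechanism, assuming torch >= 2.4 (`MyConditioning` is a hypothetical stand-in class, not part of InvokeAI):

```python
import torch
from dataclasses import dataclass


@dataclass
class MyConditioning:  # hypothetical custom class stored inside a saved object
    strength: float


torch.save({"cond": MyConditioning(strength=0.75)}, "obj.pt")

# Under weights_only=True, torch.load refuses to unpickle arbitrary classes;
# custom types must be allow-listed first, or loading raises an UnpicklingError.
torch.serialization.add_safe_globals([MyConditioning])
loaded = torch.load("obj.pt", weights_only=True)
print(loaded["cond"].strength)
```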
@@ -1,10 +1,13 @@
from typing import Optional
import json
from typing import Any, Optional

from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.fields import BoardField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
    QUEUE_ITEM_STATUS,
@@ -23,6 +26,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
    SessionQueueItemDTO,
    SessionQueueStatus,
)
from invokeai.app.services.shared.compose_pydantic_model import compose_model_from_fields
from invokeai.app.services.shared.pagination import CursorPaginatedResults

session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -35,10 +39,15 @@ class SessionQueueAndProcessorStatus(BaseModel):
    processor: SessionProcessorStatus


class ValidationRunData(BaseModel):
    workflow_id: str = Field(description="The id of the workflow being published.")
    input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
    output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
class SimpleModelIdentifer(BaseModel):
    id: str = Field(description="The model id")


model_field_overrides = {ModelIdentifierField: (SimpleModelIdentifer, Field(description="The model identifier"))}


def model_field_filter(field_type: type[Any]) -> bool:
    return field_type not in {BoardField, Optional[BoardField]}


@session_queue_router.post(
@@ -52,13 +61,52 @@ async def enqueue_batch(
    queue_id: str = Path(description="The queue id to perform this operation on"),
    batch: Batch = Body(description="Batch to process"),
    prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
    validation_run_data: Optional[ValidationRunData] = Body(
        default=None,
        description="The validation run data to use for this batch. This is only used if this is a validation run.",
    is_api_validation_run: bool = Body(
        default=False,
        description="Whether or not this is a validation run.",
    ),
    api_input_fields: Optional[list[FieldIdentifier]] = Body(
        default=None, description="The fields that were used as input to the API"
    ),
    api_output_fields: Optional[list[FieldIdentifier]] = Body(
        default=None, description="The fields that were used as output from the API"
    ),
) -> EnqueueBatchResult:
    """Processes a batch and enqueues the output graphs for execution."""

    if is_api_validation_run:
        session_count = batch.get_session_count()
        assert session_count == 1, "API validation run only supports single session batches"

        if api_input_fields:
            composed_model = compose_model_from_fields(
                g=batch.graph,
                field_identifiers=api_input_fields,
                composed_model_class_name="APIInputModel",
                model_field_overrides=model_field_overrides,
                model_field_filter=model_field_filter,
            )
            json_schema = composed_model.model_json_schema(mode="validation")
            print("API Input Model")
            print(json.dumps(json_schema))

        if api_output_fields:
            composed_model = compose_model_from_fields(
                g=batch.graph,
                field_identifiers=api_output_fields,
                composed_model_class_name="APIOutputModel",
            )
            json_schema = composed_model.model_json_schema(mode="validation")
            print("API Output Model")
            print(json.dumps(json_schema))

        print("graph")
        print(batch.graph.model_dump_json())

        if batch.workflow is not None:
            print("workflow")
            print(batch.workflow.model_dump_json())

    return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
        queue_id=queue_id, batch=batch, prepend=prepend
    )
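`compose_model_from_fields` builds a pydantic model from selected graph fields and then dumps its JSON schema. The dynamic-model half of that pattern can be shown with plain pydantic v2; `create_model` and the field names below are illustrative, not InvokeAI's implementation:

```python
import json

from pydantic import Field, create_model

# Compose a model dynamically from a field list, then emit its JSON schema --
# the same shape of operation the endpoint performs for API input/output fields.
APIInputModel = create_model(
    "APIInputModel",
    prompt=(str, Field(description="A text input field")),
    steps=(int, Field(default=30, description="A numeric input field")),
)

json_schema = APIInputModel.model_json_schema(mode="validation")
print(json.dumps(json_schema, indent=2))
```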
@@ -1,4 +1,5 @@
import io
import random
import traceback
from typing import Optional

@@ -24,6 +25,37 @@ from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import
IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])

ids = {
    "6614752a-0420-4d81-98fc-e110069d4f38": random.choice([True, False]),
    "default_5e8b008d-c697-45d0-8883-085a954c6ace": random.choice([True, False]),
    "4b2b297a-0d47-4f43-8113-ebbf3f403089": random.choice([True, False]),
    "d0ce602a-049e-4368-97ae-977b49eed042": random.choice([True, False]),
    "f170a187-fd74-40b8-ba9c-00de173ea4b9": random.choice([True, False]),
    "default_f96e794f-eb3e-4d01-a960-9b4e43402bcf": random.choice([True, False]),
    "default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88": random.choice([True, False]),
    "default_dec5a2e9-f59c-40d9-8869-a056751d79b8": random.choice([True, False]),
    "default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd": random.choice([True, False]),
    "default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f": random.choice([True, False]),
    "default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1": random.choice([True, False]),
    "default_7dde3e36-d78f-4152-9eea-00ef9c8124ed": random.choice([True, False]),
    "default_444fe292-896b-44fd-bfc6-c0b5d220fffc": random.choice([True, False]),
    "default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2": random.choice([True, False]),
    "acae7e87-070b-4999-9074-c5b593c86618": random.choice([True, False]),
    "3008fc77-1521-49c7-ba95-94c5a4508d1d": random.choice([True, False]),
    "default_686bb1d0-d086-4c70-9fa3-2f600b922023": random.choice([True, False]),
    "36905c46-e768-4dc3-8ecd-e55fe69bf03c": random.choice([True, False]),
    "7c3e4951-183b-40ef-a890-28eef4d50097": random.choice([True, False]),
    "7a053b2f-64e4-4152-80e9-296006e77131": random.choice([True, False]),
    "27d4f1be-4156-46e9-8d22-d0508cd72d4f": random.choice([True, False]),
    "e881dc06-70d2-438f-b007-6f3e0c3c0e78": random.choice([True, False]),
    "265d2244-a1d7-495c-a2eb-88217f5eae37": random.choice([True, False]),
    "caebcbc7-2bf0-41c4-b553-106b585fddda": random.choice([True, False]),
    "a7998705-474e-417d-bd37-a2a9480beedf": random.choice([True, False]),
    "554d94b5-94b3-4d8e-8aed-51ebfc9deea5": random.choice([True, False]),
    "e6898540-c1bc-408b-b944-c1e242cddbcd": random.choice([True, False]),
    "363b0960-ab2c-4902-8df3-f592d6194bb3": random.choice([True, False]),
}


@workflows_router.get(
    "/i/{workflow_id}",
@@ -39,6 +71,8 @@ async def get_workflow(
    try:
        thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
        workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
        workflow.is_published = ids.get(workflow_id, False)
        workflow.workflow.is_published = ids.get(workflow_id, False)
        return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
    except WorkflowNotFoundError:
        raise HTTPException(status_code=404, detail="Workflow not found")
@@ -110,7 +144,7 @@ async def list_workflows(
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
    """Gets a page of workflows"""
    workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
    workflows = ApiDependencies.invoker.services.workflow_records.get_many(
    workflow_record_list_items = ApiDependencies.invoker.services.workflow_records.get_many(
        order_by=order_by,
        direction=direction,
        page=page,
@@ -121,19 +155,21 @@
        has_been_opened=has_been_opened,
        is_published=is_published,
    )
    for workflow in workflows.items:
    for item in workflow_record_list_items.items:
        data = item.model_dump()
        data["is_published"] = ids.get(item.workflow_id, False)
        workflows_with_thumbnails.append(
            WorkflowRecordListItemWithThumbnailDTO(
                thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
                **workflow.model_dump(),
                thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(item.workflow_id),
                **data,
            )
        )
    return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
        items=workflows_with_thumbnails,
        total=workflows.total,
        page=workflows.page,
        pages=workflows.pages,
        per_page=workflows.per_page,
        total=workflow_record_list_items.total,
        page=workflow_record_list_items.page,
        pages=workflow_record_list_items.pages,
        per_page=workflow_record_list_items.per_page,
    )
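The `PaginatedResults[...]` construction at the end of `list_workflows` is a parameterized generic pydantic model. A self-contained sketch of that pattern, with `Paginated` and `WorkflowItem` as hypothetical stand-ins rather than InvokeAI's actual types:

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

ItemT = TypeVar("ItemT")


class Paginated(BaseModel, Generic[ItemT]):  # stand-in for PaginatedResults
    items: list[ItemT]
    total: int
    page: int
    pages: int
    per_page: int


class WorkflowItem(BaseModel):  # stand-in for the workflow list-item DTO
    workflow_id: str
    is_published: bool = False


# Parameterizing the generic gives a concrete model that validates its items.
page = Paginated[WorkflowItem](
    items=[WorkflowItem(workflow_id="abc123", is_published=True)],
    total=1,
    page=1,
    pages=1,
    per_page=10,
)
print(page.model_dump_json())
```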
@@ -1,128 +0,0 @@
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
from typing import List, Union

from pydantic import BaseModel, Field, field_validator, model_validator

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    OutputField,
    UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.util import np_to_pil, pil_to_np


class ControlField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v):
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self


@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Outputs
    control: ControlField = OutputField(description=FieldDescriptions.control)


@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
    )
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v):
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> ControlOutput:
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )


@invocation(
    "heuristic_resize",
    title="Heuristic Resize",
    tags=["image", "controlnet"],
    category="image",
    version="1.0.1",
    classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
    """Resize an image using a heuristic method. Preserves edge maps."""

    image: ImageField = InputField(description="The image to resize")
    width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
    height: int = InputField(default=512, ge=1, description="The height to resize to (px)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")
        np_img = pil_to_np(image)
        np_resized = heuristic_resize(np_img, (self.width, self.height))
        resized = np_to_pil(np_resized)
        image_dto = context.images.save(image=resized)
        return ImageOutput.build(image_dto)
716 changes: invokeai/app/invocations/controlnet_image_processors.py (new file)
@@ -0,0 +1,716 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (
|
||||
ContentShuffleDetector,
|
||||
LeresDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width=image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height=image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_image_processor",
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_image_processor",
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_image_processor",
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_image_processor",
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"normalbae_image_processor",
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle_image_processor",
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@invocation(
|
||||
"zoe_depth_image_processor",
    title="Zoe (Depth) Processor",
    tags=["controlnet", "zoe", "depth"],
    category="controlnet",
    version="1.2.3",
    classification=Classification.Deprecated,
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Zoe depth processing to image"""

    def run_processor(self, image: Image.Image) -> Image.Image:
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


@invocation(
    "mediapipe_face_processor",
    title="Mediapipe Face Processor",
    tags=["controlnet", "mediapipe", "face"],
    category="controlnet",
    version="1.2.4",
    classification=Classification.Deprecated,
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
    """Applies mediapipe face processing to image"""

    max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
    detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image) -> Image.Image:
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(
            image,
            max_faces=self.max_faces,
            min_confidence=self.min_confidence,
            image_resolution=self.image_resolution,
            detect_resolution=self.detect_resolution,
        )
        return processed_image


@invocation(
    "leres_image_processor",
    title="Leres (Depth) Processor",
    tags=["controlnet", "leres", "depth"],
    category="controlnet",
    version="1.2.3",
    classification=Classification.Deprecated,
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
    """Applies leres processing to image"""

    thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
    thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
    boost: bool = InputField(default=False, description="Whether to use boost mode")
    detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image) -> Image.Image:
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "tile_image_processor",
    title="Tile Resample Processor",
    tags=["controlnet", "tile"],
    category="controlnet",
    version="1.2.3",
    classification=Classification.Deprecated,
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
    """Tile resampler processor"""

    # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # never used?
        down_sampling_rate=1.0,
    ):
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        H, W, C = np_img.shape
        H = int(float(H) / float(down_sampling_rate))
        W = int(float(W) / float(down_sampling_rate))
        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
        return np_img

    def run_processor(self, image: Image.Image) -> Image.Image:
        np_img = np.array(image, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate,
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


@invocation(
    "segment_anything_processor",
    title="Segment Anything Processor",
    tags=["controlnet", "segmentanything"],
    category="controlnet",
    version="1.2.4",
    classification=Classification.Deprecated,
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image) -> Image.Image:
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(
            np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
        )
        return processed_image


class SamDetectorReproducibleColors(SamDetector):
    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    # base class show_anns() method randomizes colors,
    # which seems to also lead to non-reproducible image generation
    # so using ADE20k color palette instead
    def show_anns(self, anns: List[Dict]):
        if len(anns) == 0:
            return
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
        h, w = anns[0]["segmentation"].shape
        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann["segmentation"]
            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)
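The reproducibility trick in `SamDetectorReproducibleColors` comes down to assigning each mask a palette entry by its area rank rather than at random. A minimal sketch of that idea, using a tiny stand-in palette in place of `ade_palette()`:

```python
import numpy as np

palette = [(120, 120, 120), (180, 120, 120), (6, 230, 230)]  # stand-in for ade_palette()

# Masks sorted by area; rank i deterministically maps to palette[i % len(palette)].
anns = [
    {"area": 90, "segmentation": np.ones((4, 4), bool)},
    {"area": 10, "segmentation": np.zeros((4, 4), bool)},
]
for i, ann in enumerate(sorted(anns, key=lambda a: a["area"], reverse=True)):
    print(i, palette[i % len(palette)])  # same input order -> same colors, every run
```

Because the color depends only on sort order, repeated runs over the same image yield an identical control image, which in turn keeps downstream generation reproducible.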

@invocation(
    "color_map_image_processor",
    title="Color Map Processor",
    tags=["controlnet"],
    category="controlnet",
    version="1.2.3",
    classification=Classification.Deprecated,
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
    """Generates a color map from the provided image"""

    color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)

    def run_processor(self, image: Image.Image) -> Image.Image:
        np_image = np.array(image, dtype=np.uint8)
        height, width = np_image.shape[:2]

        width_tile_size = min(self.color_map_tile_size, width)
        height_tile_size = min(self.color_map_tile_size, height)

        color_map = cv2.resize(
            np_image,
            (width // width_tile_size, height // height_tile_size),
            interpolation=cv2.INTER_CUBIC,
        )
        color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
        color_map = Image.fromarray(color_map)
        return color_map


DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
    "large": "LiheYoung/depth-anything-large-hf",
    "base": "LiheYoung/depth-anything-base-hf",
    "small": "LiheYoung/depth-anything-small-hf",
    "small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}


@invocation(
    "depth_anything_image_processor",
    title="Depth Anything Processor",
    tags=["controlnet", "depth", "depth anything"],
    category="controlnet",
    version="1.1.3",
    classification=Classification.Deprecated,
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
    """Generates a depth map based on the Depth Anything algorithm"""

    model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
        default="small_v2", description="The size of the depth model to use"
    )
    resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image) -> Image.Image:
        def load_depth_anything(model_path: Path):
            depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
            assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
            return DepthAnythingPipeline(depth_anything_pipeline)

        with self._context.models.load_remote_model(
            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
        ) as depth_anything_detector:
            assert isinstance(depth_anything_detector, DepthAnythingPipeline)
            depth_map = depth_anything_detector.generate_depth(image)

            # Resizing to user target specified size
            new_height = int(image.size[1] * (self.resolution / image.size[0]))
            depth_map = depth_map.resize((self.resolution, new_height))

            return depth_map


@invocation(
    "dw_openpose_image_processor",
    title="DW Openpose Image Processor",
    tags=["controlnet", "dwpose", "openpose"],
    category="controlnet",
    version="1.1.1",
    classification=Classification.Deprecated,
)
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Generates an openpose pose from an image using DWPose"""

    draw_body: bool = InputField(default=True)
    draw_face: bool = InputField(default=False)
    draw_hands: bool = InputField(default=False)
    image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image) -> Image.Image:
        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

        dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
        processed_image = dw_openpose(
            image,
            draw_face=self.draw_face,
            draw_hands=self.draw_hands,
            draw_body=self.draw_body,
            resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "heuristic_resize",
    title="Heuristic Resize",
    tags=["image", "controlnet"],
    category="image",
    version="1.0.1",
    classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
    """Resize an image using a heuristic method. Preserves edge maps."""

    image: ImageField = InputField(description="The image to resize")
    width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
    height: int = InputField(default=512, ge=1, description="The height to resize to (px)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")
        np_img = pil_to_np(image)
        np_resized = heuristic_resize(np_img, (self.width, self.height))
        resized = np_to_pil(np_resized)
        image_dto = context.images.save(image=resized)
        return ImageOutput.build(image_dto)
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet import ControlField
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
    ConditioningField,
    DenoiseMaskField,
@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2


@invocation(
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, "RGB")

        onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
        onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
        onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
        onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())

        loaded_session_det = context.models.load_local_model(
            onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
            onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
        )
        loaded_session_pose = context.models.load_local_model(
            onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
            onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
        )

        with loaded_session_det as session_det, loaded_session_pose as session_pose:
            assert isinstance(session_det, ort.InferenceSession)
            assert isinstance(session_pose, ort.InferenceSession)
            detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
            detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
            detected_image = detector.run(
                image,
                draw_face=self.draw_face,
@@ -14,7 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
    invocation,
    invocation_output,
)
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
from invokeai.app.invocations.controlnet_image_processors import ControlField, ControlNetInvocation
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
@@ -9,7 +9,7 @@ from pydantic import field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet import ControlField
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
    ConditioningField,
@@ -21,16 +21,10 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
    """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.

    :param output_dir: The folder where the serialized objects will be stored
    :param safe_globals: A list of types to be added to the safe globals for torch serialization
    :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
    """

    def __init__(
        self,
        output_dir: Path,
        safe_globals: list[type],
        ephemeral: bool = False,
    ) -> None:
    def __init__(self, output_dir: Path, ephemeral: bool = False):
        super().__init__()
        self._ephemeral = ephemeral
        self._base_output_dir = output_dir
@@ -48,8 +42,6 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
        self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
        self.__obj_class_name: Optional[str] = None

        torch.serialization.add_safe_globals(safe_globals) if safe_globals else None

    def load(self, name: str) -> T:
        file_path = self._get_path(name)
        try:
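For context on the `safe_globals` parameter being removed above: it maps onto torch's allowlist for weights-only deserialization. A minimal sketch of the round trip it enabled, assuming a hypothetical payload type:

```python
from dataclasses import dataclass

import torch


@dataclass
class MyPayload:  # hypothetical stand-in for e.g. ConditioningFieldData
    values: list


# Without this registration, torch.load(..., weights_only=True) refuses to unpickle MyPayload.
torch.serialization.add_safe_globals([MyPayload])

torch.save(MyPayload(values=[1, 2, 3]), "payload.pt")
restored = torch.load("payload.pt", weights_only=True)
print(restored.values)  # [1, 2, 3]
```

This is also why the comment in `ConditioningFieldData` (later in this diff) warns that any new conditioning type must be registered as a safe global wherever `ObjectSerializerDisk` is instantiated.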
@@ -33,7 +33,12 @@ class SessionQueueBase(ABC):
        pass

    @abstractmethod
    def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
    def enqueue_batch(
        self,
        queue_id: str,
        batch: Batch,
        prepend: bool,
    ) -> Coroutine[Any, Any, EnqueueBatchResult]:
        """Enqueues all permutations of a batch for execution."""
        pass
@@ -157,6 +157,28 @@ class Batch(BaseModel):
        v.validate_self()
        return v

    def get_session_count(self) -> int:
        """
        Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
        creating them, as is done in `create_session_nfv_tuples()`.

        The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
        many were _actually_ created (which may be less due to the maximum number of sessions).

        If the session count has already been calculated, return the cached value.
        """
        if not self.data:
            return self.runs
        data = []
        for batch_datum_list in self.data:
            to_zip = []
            for batch_datum in batch_datum_list:
                batch_data_items = range(len(batch_datum.items))
                to_zip.append(batch_data_items)
            data.append(list(zip(*to_zip, strict=True)))
        data_product = list(product(*data))
        return len(data_product) * self.runs

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
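The counting logic in `get_session_count` reduces to "zip within a datum list, product across datum lists, times runs". A small worked example, assuming two zipped collections of length 3 in one group and one collection of length 2 in another:

```python
from itertools import product

runs = 2
# Each inner list is one batch_datum_list; each number is len(batch_datum.items).
item_counts = [[3, 3], [2]]

data = []
for group in item_counts:
    to_zip = [range(n) for n in group]
    data.append(list(zip(*to_zip, strict=True)))  # zipped collections advance together

# zip yields 3 tuples for the first group and 2 for the second; the product gives
# 3 * 2 = 6 combinations, and 2 runs doubles that to 12 sessions.
print(len(list(product(*data))) * runs)  # 12
```

Note that `strict=True` makes mismatched collection lengths within a group a hard error rather than a silent truncation.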
@@ -247,10 +269,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
        default=False,
        description="Whether this queue item is an API validation run.",
    )
    published_workflow_id: Optional[str] = Field(
        default=None,
        description="The ID of the published workflow associated with this queue item",
    )
    api_input_fields: Optional[list[FieldIdentifier]] = Field(
        default=None, description="The fields that were used as input to the API"
    )
@@ -556,28 +574,6 @@ def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str
        count += 1


def calc_session_count(batch: Batch) -> int:
    """
    Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
    creating them, as is done in `create_session_nfv_tuples()`.

    The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
    many were _actually_ created (which may be less due to the maximum number of sessions).
    """
    # TODO: Should this be a class method on Batch?
    if not batch.data:
        return batch.runs
    data = []
    for batch_datum_list in batch.data:
        to_zip = []
        for batch_datum in batch_datum_list:
            batch_data_items = range(len(batch_datum.items))
            to_zip.append(batch_data_items)
        data.append(list(zip(*to_zip, strict=True)))
    data_product = list(product(*data))
    return len(data_product) * batch.runs


ValueToInsertTuple: TypeAlias = tuple[
    str,  # queue_id
    str,  # session (as stringified JSON)
@@ -28,7 +28,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
    SessionQueueItemNotFoundError,
    SessionQueueStatus,
    ValueToInsertTuple,
    calc_session_count,
    prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -118,7 +117,8 @@ class SqliteSessionQueue(SessionQueueBase):
        if prepend:
            priority = self._get_highest_priority(queue_id) + 1

        requested_count = calc_session_count(batch)
        requested_count = batch.get_session_count()

        values_to_insert = prepare_values_to_insert(
            queue_id=queue_id,
            batch=batch,
204 invokeai/app/services/shared/compose_pydantic_model.py Normal file
@@ -0,0 +1,204 @@
from copy import deepcopy
from typing import Any, Callable, TypeAlias, get_args

from pydantic import BaseModel, ConfigDict, create_model
from pydantic.fields import FieldInfo

from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier
from invokeai.app.services.shared.graph import Graph

DictOfFieldsMetadata: TypeAlias = dict[str, tuple[type[Any], FieldInfo]]


class ComposedFieldMetadata(BaseModel):
    node_id: str
    field_name: str
    field_type_class_name: str


def dedupe_field_name(field_metadata: DictOfFieldsMetadata, field_name: str) -> str:
    """Given a field name, return a name that is not already in the field metadata.

    If the field name is not in the field metadata, return the field name.

    If the field name is in the field metadata, generate a new name by appending an underscore and integer to the
    field name, starting with 2.
    """

    if field_name not in field_metadata:
        return field_name

    i = 2
    while True:
        new_field_name = f"{field_name}_{i}"
        if new_field_name not in field_metadata:
            return new_field_name
        i += 1


def compose_model_from_fields(
    g: Graph,
    field_identifiers: list[FieldIdentifier],
    composed_model_class_name: str = "ComposedModel",
    model_field_overrides: dict[type[Any], tuple[type[Any], FieldInfo]] | None = None,
    model_field_filter: Callable[[type[Any]], bool] | None = None,
) -> type[BaseModel]:
    """Given a graph and a list of field identifiers, create a new pydantic model composed of the fields of the nodes in the graph.

    The resultant model can be used to validate a JSON payload that contains the fields of the nodes in the graph, or to generate an
    OpenAPI schema for the model.

    Args:
        g: The graph containing the nodes whose fields will be composed into the new model.
        field_identifiers: A list of FieldIdentifier instances, each representing a field on a node in the graph.
        composed_model_class_name: The name of the composed model class.
        model_field_overrides: A dictionary mapping type annotations to tuples of (new_type_annotation, new_field_info).
            This can be used to override the type annotation and field info of a field in the composed model. For example,
            if `ModelIdentifierField` should be replaced by a string, the dictionary would look like this:
            ```python
            {ModelIdentifierField: (str, Field(description="The model id."))}
            ```
        model_field_filter: A function that takes a type annotation and returns True if the field should be included in the composed model.
            If None, all fields will be included. For example, to omit `BoardField` fields, the filter would look like this:
            ```python
            def model_field_filter(field_type: type[Any]) -> bool:
                return field_type not in {BoardField}
            ```
            Optional fields - or any other complex field types like unions - must be explicitly included in the filter. For example,
            to omit `BoardField` _and_ `Optional[BoardField]`:
            ```python
            def model_field_filter(field_type: type[Any]) -> bool:
                return field_type not in {BoardField, Optional[BoardField]}
            ```
            Note that the filter is applied to the type annotation of the field, not the field itself.

    Example usage:
    ```python
    # Create some nodes.
    add_node = AddInvocation()
    sub_node = SubtractInvocation()
    color_node = ColorInvocation()

    # Create a graph with the nodes.
    g = Graph(
        nodes={
            add_node.id: add_node,
            sub_node.id: sub_node,
            color_node.id: color_node,
        }
    )

    # Select the fields to compose.
    fields_to_compose = [
        FieldIdentifier(node_id=add_node.id, field_name="a"),
        FieldIdentifier(node_id=sub_node.id, field_name="a"),  # this will be deduped to "a_2"
        FieldIdentifier(node_id=add_node.id, field_name="b"),
        FieldIdentifier(node_id=color_node.id, field_name="color"),
    ]

    # Compose the model from the fields.
    composed_model = compose_model_from_fields(g, fields_to_compose, composed_model_class_name="ComposedModel")

    # Generate the OpenAPI schema for the model.
    json_schema = composed_model.model_json_schema(mode="validation")
    ```
    """

    # Temp storage for the composed fields. Pydantic needs a type annotation and instance of FieldInfo to create a model.
    field_metadata: DictOfFieldsMetadata = {}
    model_field_overrides = model_field_overrides or {}

    # The list of required fields. This is used to ensure the composed model's fields retain their required state.
    required: list[str] = []

    for field_identifier in field_identifiers:
        node_id = field_identifier.node_id
        field_name = field_identifier.field_name

        # Pull the node instance from the graph so we can introspect it.
        node_instance = g.nodes[node_id]

        if field_identifier.kind == "input":
            # Get the class of the node. This will be a BaseInvocation subclass, e.g. AddInvocation, DenoiseLatentsInvocation, etc.
            pydantic_model = type(node_instance)
        else:
            # Otherwise use the type of the node's output class. This will be a BaseInvocationOutput subclass, e.g. IntegerOutput, ImageOutput, etc.
            pydantic_model = type(node_instance).get_output_annotation()

        # Get the FieldInfo instance for the field. For example:
        # a: int = Field(..., description="The first number to add.")
        #          ^^^^^ The return value of this Field call is the FieldInfo instance (Field is a function).
        og_field_info = pydantic_model.model_fields[field_name]

        # Get the type annotation of the field. For example:
        # a: int = Field(..., description="The first number to add.")
        #    ^^^ this is the type annotation
        og_field_type = og_field_info.annotation

        # Apparently pydantic allows fields without type annotations. We don't support that.
        assert og_field_type is not None, (
            f"{field_identifier.kind.capitalize()} field {field_name} on node {node_id} has no type annotation."
        )

        # Now that we have the type annotation, we can apply the filter to see if we should include the field in the composed model.
        if model_field_filter and not model_field_filter(og_field_type):
            continue

        # Ok, we want this type of field. Retrieve any overrides for the field type. This is a dictionary mapping
        # type annotations to tuples of (override_type_annotation, override_field_info).
        (override_field_type, override_field_info) = model_field_overrides.get(og_field_type, (None, None))

        # The override tuple's first element is the new type annotation, if it exists.
        composed_field_type = override_field_type if override_field_type is not None else og_field_type

        # Create a deep copy of the FieldInfo instance (or the override, if it exists) so we can modify it without
        # affecting the original. This is important because we are going to modify the FieldInfo instance and
        # don't want to affect the original model's schema.
        composed_field_info = deepcopy(override_field_info if override_field_info is not None else og_field_info)

        json_schema_extra = og_field_info.json_schema_extra if isinstance(og_field_info.json_schema_extra, dict) else {}

        # The field's original required state is stored in the json_schema_extra dict. For more information about why,
        # see the definition of `InputField` in invokeai/app/invocations/fields.py.
        #
        # Add the field to the required list if it is required, which we will use when creating the composed model.
        if json_schema_extra.get("orig_required", False):
            required.append(field_name)

        # Invocation fields have some extra metadata, used by the UI to render the field in the frontend. This data is
        # included in the OpenAPI schema for each field. For example, we add a "ui_order" field, which the UI uses to
        # sort fields when rendering them.
        #
        # The composed model's OpenAPI schema should not have this information. It should only have a standard OpenAPI
        # schema for the field. We need to strip out the UI-specific metadata from the FieldInfo instance before adding
        # it to the composed model.
        #
        # We will replace this metadata with some custom metadata:
        # - node_id: The id of the node that this field belongs to.
        # - field_name: The name of the field on the node.
        # - original_data_type: The original data type of the field.

        field_type_class = get_args(og_field_type)[0] if hasattr(og_field_type, "__args__") else og_field_type
        field_type_class_name = field_type_class.__name__

        composed_field_metadata = ComposedFieldMetadata(
            node_id=node_id,
            field_name=field_name,
            field_type_class_name=field_type_class_name,
        )

        composed_field_info.json_schema_extra = {
            "composed_field_extra": composed_field_metadata.model_dump(),
        }

        # Dedupe the field name if necessary.
        final_field_name = dedupe_field_name(field_metadata, field_name)

        # Store the field metadata.
        field_metadata.update({final_field_name: (composed_field_type, composed_field_info)})

    # Splat in the composed fields to create the new model. There are type errors here because create_model's kwargs are not typed,
    # and for some reason pydantic's ConfigDict doesn't like lists in `json_schema_extra`. Anyways, the inputs here are correct.
    return create_model(
        composed_model_class_name,
        **field_metadata,
        __config__=ConfigDict(json_schema_extra={"required": required}),
    )
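As a quick illustration of the dedupe behavior above: the field metadata dict acts as the registry of taken names, and `dedupe_field_name` only ever consults its keys. A sketch with placeholder values:

```python
field_metadata: dict[str, tuple] = {}


def claim(name: str) -> str:
    final = dedupe_field_name(field_metadata, name)
    field_metadata[final] = (int, None)  # value shape doesn't matter; only the keys are consulted
    return final


print(claim("a"))  # a
print(claim("a"))  # a_2
print(claim("a"))  # a_3
print(claim("b"))  # b
```

The suffix starts at 2, so the second `a` becomes `a_2`, matching the `FieldIdentifier` example in the docstring.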
@@ -65,6 +65,9 @@ def apply_monkeypatches() -> None:

    import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

    if torch.backends.mps.is_available():
        import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


def register_mime_types() -> None:
    """Register additional mime types for windows."""
@@ -5,14 +5,62 @@ import huggingface_hub
import numpy as np
import onnxruntime as ort
import torch
from controlnet_aux.util import resize_image
from PIL import Image

from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
from invokeai.backend.image_util.util import np_to_pil
from invokeai.backend.util.devices import TorchDevice

DWPOSE_MODELS = {
    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}


def draw_pose(
    pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
    H: int,
    W: int,
    draw_face: bool = True,
    draw_body: bool = True,
    draw_hands: bool = True,
    resolution: int = 512,
) -> Image.Image:
    bodies = pose["bodies"]
    faces = pose["faces"]
    hands = pose["hands"]

    assert isinstance(bodies, dict)
    candidate = bodies["candidate"]

    assert isinstance(bodies, dict)
    subset = bodies["subset"]

    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

    if draw_body:
        canvas = draw_bodypose(canvas, candidate, subset)

    if draw_hands:
        assert isinstance(hands, np.ndarray)
        canvas = draw_handpose(canvas, hands)

    if draw_face:
        assert isinstance(faces, np.ndarray)
        canvas = draw_facepose(canvas, faces)  # type: ignore

    dwpose_image: Image.Image = resize_image(
        canvas,
        resolution,
    )
    dwpose_image = Image.fromarray(dwpose_image)

    return dwpose_image


class DWOpenposeDetector:
    """
@@ -20,6 +68,62 @@ class DWOpenposeDetector:
    Credits: https://github.com/IDEA-Research/DWPose
    """

    def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
        self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)

    def __call__(
        self,
        image: Image.Image,
        draw_face: bool = False,
        draw_body: bool = True,
        draw_hands: bool = False,
        resolution: int = 512,
    ) -> Image.Image:
        np_image = np.array(image)
        H, W, C = np_image.shape

        with torch.no_grad():
            candidate, subset = self.pose_estimation(np_image)
            nums, keys, locs = candidate.shape
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)
            body = candidate[:, :18].copy()
            body = body.reshape(nums * 18, locs)
            score = subset[:, :18]
            for i in range(len(score)):
                for j in range(len(score[i])):
                    if score[i][j] > 0.3:
                        score[i][j] = int(18 * i + j)
                    else:
                        score[i][j] = -1

            un_visible = subset < 0.3
            candidate[un_visible] = -1

            # foot = candidate[:, 18:24]

            faces = candidate[:, 24:92]

            hands = candidate[:, 92:113]
            hands = np.vstack([hands, candidate[:, 113:]])

            bodies = {"candidate": body, "subset": score}
            pose = {"bodies": bodies, "hands": hands, "faces": faces}

            return draw_pose(
                pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
            )


class DWOpenposeDetector2:
    """
    Code from the original implementation of the DW Openpose Detector.
    Credits: https://github.com/IDEA-Research/DWPose

    This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
    and managed by the model manager.
    """

    hf_repo_id = "yzd-v/DWPose"
    hf_filename_onnx_det = "yolox_l.onnx"
    hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
@@ -109,7 +213,7 @@ class DWOpenposeDetector:
            bodies = {"candidate": body, "subset": score}
            pose = {"bodies": bodies, "hands": hands, "faces": faces}

            return DWOpenposeDetector.draw_pose(
            return DWOpenposeDetector2.draw_pose(
                pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
            )
@@ -3,6 +3,7 @@
import math

import cv2
import matplotlib
import numpy as np
import numpy.typing as npt

@@ -126,13 +127,11 @@ def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
            x2 = int(x2 * W)
            y2 = int(y2 * H)
            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                hsv_color = np.array([[[ie / float(len(edges)) * 180, 255, 255]]], dtype=np.uint8)
                rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0, 0]
                cv2.line(
                    canvas,
                    (x1, y1),
                    (x2, y2),
                    rgb_color.tolist(),
                    matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
                    thickness=2,
                )
44 invokeai/backend/image_util/dw_openpose/wholebody.py Normal file
@@ -0,0 +1,44 @@
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
# Modified pathing to suit Invoke


from pathlib import Path

import numpy as np
import onnxruntime as ort

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.util.devices import TorchDevice

config = get_config()


class Wholebody:
    def __init__(self, onnx_det: Path, onnx_pose: Path):
        device = TorchDevice.choose_torch_device()

        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

    def __call__(self, oriImg):
        det_result = inference_detector(self.session_det, oriImg)
        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)

        keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
        # compute neck joint
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # neck score when visualizing pred
        neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
        mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
        openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
        new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]

        return keypoints, scores
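The neck-joint synthesis in `__call__` above is easy to sanity-check in isolation. A sketch with one fake detection and made-up coordinates (the real code slices `2:4`, which on a three-column array is just the score column):

```python
import numpy as np

# One detection, 17 COCO keypoints, (x, y, score) per keypoint; 5 and 6 are the shoulders.
kpts = np.zeros((1, 17, 3))
kpts[0, 5] = [0.4, 0.2, 0.9]  # left shoulder (hypothetical coordinates)
kpts[0, 6] = [0.6, 0.2, 0.8]  # right shoulder

neck = np.mean(kpts[:, [5, 6]], axis=1)  # shoulder midpoint
print(neck[0, :2])  # [0.5 0.2]

# The neck counts as visible only if both shoulders clear the 0.3 score threshold.
neck[:, 2] = np.logical_and(kpts[:, 5, 2] > 0.3, kpts[:, 6, 2] > 0.3).astype(int)

# Inserting the synthesized neck at index 17 gives the 18-point OpenPose layout.
kpts18 = np.insert(kpts, 17, neck, axis=1)
print(kpts18.shape)  # (1, 18, 3)
```

The subsequent index permutation (`mmpose_idx` into `openpose_idx`) then shuffles columns so the drawing utilities see OpenPose's joint ordering.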
@@ -69,9 +69,6 @@ class SD3ConditioningInfo:

@dataclass
class ConditioningFieldData:
    # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
    # invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be
    # unable to deserialize the object and will raise an error.
    conditionings: (
        List[BasicConditioningInfo]
        | List[SDXLConditioningInfo]
245 invokeai/backend/util/mps_fixes.py Normal file
@@ -0,0 +1,245 @@
import math

import diffusers
import torch

if torch.backends.mps.is_available():
    torch.empty = torch.zeros


_torch_layer_norm = torch.nn.functional.layer_norm


def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    if input.device.type == "mps" and input.dtype == torch.float16:
        input = input.float()
        if weight is not None:
            weight = weight.float()
        if bias is not None:
            bias = bias.float()
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
    else:
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)


torch.nn.functional.layer_norm = new_layer_norm


_torch_tensor_permute = torch.Tensor.permute


def new_torch_tensor_permute(input, *dims):
    result = _torch_tensor_permute(input, *dims)
    if input.device == "mps" and input.dtype == torch.float16:
        result = result.contiguous()
    return result


torch.Tensor.permute = new_torch_tensor_permute


_torch_lerp = torch.lerp


def new_torch_lerp(input, end, weight, *, out=None):
    if input.device.type == "mps" and input.dtype == torch.float16:
        input = input.float()
        end = end.float()
        if isinstance(weight, torch.Tensor):
            weight = weight.float()
        if out is not None:
            out_fp32 = torch.zeros_like(out, dtype=torch.float32)
        else:
            out_fp32 = None
        result = _torch_lerp(input, end, weight, out=out_fp32)
        if out is not None:
            out.copy_(out_fp32.half())
            del out_fp32
        return result.half()

    else:
        return _torch_lerp(input, end, weight, out=out)


torch.lerp = new_torch_lerp


_torch_interpolate = torch.nn.functional.interpolate


def new_torch_interpolate(
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
    antialias=False,
):
    if input.device.type == "mps" and input.dtype == torch.float16:
        return _torch_interpolate(
            input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
        ).half()
    else:
        return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)


torch.nn.functional.interpolate = new_torch_interpolate

# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor


class ChunkedSlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        assert isinstance(slice_size, int)
        slice_size = 1  # TODO: maybe implement chunking in batches too when enough memory
        self.slice_size = slice_size
        self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if self.slice_size != 1 or attn.upcast_attention:
            return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        chunk_tmp_tensor = torch.empty(
            self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            self.get_attention_scores_chunked(
                attn,
                query_slice,
                key_slice,
                attn_mask_slice,
                hidden_states[start_idx:end_idx],
                value[start_idx:end_idx],
                chunk_tmp_tensor,
            )

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
        # batch size = 1
        assert query.shape[0] == 1
        assert key.shape[0] == 1
        assert value.shape[0] == 1
        assert hidden_states.shape[0] == 1

        # dtype = query.dtype
        if attn.upcast_attention:
            query = query.float()
            key = key.float()

        # out_item_size = query.dtype.itemsize
        # if attn.upcast_attention:
        #     out_item_size = torch.float32.itemsize
        out_item_size = query.element_size()
        if attn.upcast_attention:
            out_item_size = 4

        chunk_size = 2**29

        out_size = query.shape[1] * key.shape[1] * out_item_size
        chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
        chunk_step = max(1, int(query.shape[1] / chunks_count))

        key = key.transpose(-1, -2)

        def _get_chunk_view(tensor, start, length):
            if start + length > tensor.shape[1]:
                length = tensor.shape[1] - start
            # print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
            return tensor[:, start : start + length]

        for chunk_pos in range(0, query.shape[1], chunk_step):
            if attention_mask is not None:
                torch.baddbmm(
                    _get_chunk_view(attention_mask, chunk_pos, chunk_step),
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=1,
                    alpha=attn.scale,
                    out=chunk,
                )
            else:
                torch.baddbmm(
                    torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=0,
                    alpha=attn.scale,
                    out=chunk,
                )
            chunk = chunk.softmax(dim=-1)
            torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))

        # del chunk


diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
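The chunk sizing in `get_attention_scores_chunked` caps the attention-score buffer at roughly 2^29 bytes per slice. A worked example of that arithmetic for a plausible SD-scale shape (4096 query tokens, 4096 key tokens, fp16):

```python
import math

query_tokens = 4096
key_tokens = 4096
out_item_size = 2  # fp16 bytes; 4 if upcast_attention forces fp32

chunk_size = 2**29  # ~512 MiB budget for one score matrix

out_size = query_tokens * key_tokens * out_item_size  # 33,554,432 bytes (~32 MiB)
chunks_count = min(query_tokens, math.ceil((out_size - 1) / chunk_size))  # 1: fits the budget
chunk_step = max(1, int(query_tokens / chunks_count))  # 4096: all rows in one pass

print(chunks_count, chunk_step)  # 1 4096
```

Only when the score matrix would exceed the budget (very large images) does `chunks_count` grow past 1, in which case each `baddbmm` call handles a `chunk_step`-row stripe of the query.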
@@ -62,7 +62,7 @@
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.6.1",
    "@roarr/browser-log-writer": "^1.3.0",
    "@xyflow/react": "^12.5.3",
    "@xyflow/react": "^12.5.1",
    "async-mutex": "^0.5.0",
    "chakra-react-select": "^4.9.2",
    "cmdk": "^1.0.0",
@@ -162,6 +162,5 @@
  },
  "engines": {
    "pnpm": "8"
  },
  "packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
  }
}
14 invokeai/frontend/web/pnpm-lock.yaml generated
@@ -36,8 +36,8 @@ dependencies:
      specifier: ^1.3.0
      version: 1.3.0
    '@xyflow/react':
      specifier: ^12.5.3
      version: 12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
      specifier: ^12.5.1
      version: 12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
    async-mutex:
      specifier: ^0.5.0
      version: 0.5.0
@@ -3951,8 +3951,8 @@ packages:
    resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
    dev: false

  /@xyflow/react@12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-saovy/aQRoW8qQoIqMFUtmC3F6oEV7n6+J1pVbhSG45NI/hOFvK0qozsIPKqX5Va6lGQnkl/o53NHLja3NiweQ==}
  /@xyflow/react@12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-jMKQVqGwCz0x6pUyvxTIuCMbyehfua7CfEEWDj29zQSHigQpCy0/5d8aOmZrqK4cwur/pVHLQomT6Rm10gXfHg==}
    peerDependencies:
      react: '>=17'
      react-dom: '>=17'
@@ -9123,8 +9123,8 @@ packages:
      react: 18.3.1
    dev: false

  /use-sync-external-store@1.5.0(react@18.3.1):
    resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==}
  /use-sync-external-store@1.4.0(react@18.3.1):
    resolution: {integrity: sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==}
    peerDependencies:
      react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
    dependencies:
@@ -9592,5 +9592,5 @@ packages:
    dependencies:
      '@types/react': 18.3.11
      react: 18.3.1
      use-sync-external-store: 1.5.0(react@18.3.1)
      use-sync-external-store: 1.4.0(react@18.3.1)
    dev: false
@@ -116,10 +116,7 @@
    "combinatorial": "Kombinatorisch",
    "saveChanges": "Änderungen speichern",
    "error_withCount_one": "{{count}} Fehler",
    "error_withCount_other": "{{count}} Fehler",
    "value": "Wert",
    "label": "Label",
    "systemInformation": "Systeminformationen"
    "error_withCount_other": "{{count}} Fehler"
  },
  "gallery": {
    "galleryImageSize": "Bildgröße",
@@ -698,10 +695,7 @@
    "guidance": "Führung",
    "coherenceMode": "Modus",
    "recallMetadata": "Metadaten abrufen",
    "gaussianBlur": "Gaußsche Unschärfe",
    "sendToUpscale": "An Hochskalieren senden",
    "useCpuNoise": "CPU-Rauschen verwenden",
    "sendToCanvas": "An Leinwand senden"
    "gaussianBlur": "Gaußsche Unschärfe"
  },
  "settings": {
    "displayInProgress": "Zwischenbilder anzeigen",
@@ -1334,8 +1328,7 @@
    "loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
    "loadingTemplates": "Lade {{name}}",
    "missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
    "missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten",
    "showEdgeLabelsHelp": "Beschriftungen an Kanten anzeigen, um die verknüpften Knoten zu kennzeichnen"
    "missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
  },
  "hrf": {
    "enableHrf": "Korrektur für hohe Auflösungen",
@@ -1706,7 +1706,6 @@
    "noRecentWorkflows": "No Recent Workflows",
    "private": "Private",
    "shared": "Shared",
    "published": "Published",
    "browseWorkflows": "Browse Workflows",
    "deselectAll": "Deselect All",
    "recommended": "Recommended For You",
@@ -1814,9 +1813,7 @@
    "publishedWorkflowIsLocked": "Published workflow is locked",
    "publishingValidationRun": "Publishing Validation Run",
    "publishingValidationRunInProgress": "Publishing validation run in progress.",
    "publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow.",
    "selectingOutputNode": "Selecting output node",
    "selectingOutputNodeDesc": "Click a node to select it as the workflow's output node."
    "publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow."
  }
},
"controlLayers": {
@@ -115,8 +115,7 @@
    "error_withCount_many": "{{count}} errori",
    "error_withCount_other": "{{count}} errori",
    "value": "Valore",
    "label": "Etichetta",
    "systemInformation": "Informazioni di sistema"
    "label": "Etichetta"
  },
  "gallery": {
    "galleryImageSize": "Dimensione dell'immagine",
@@ -716,8 +715,7 @@
    "collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
    "collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
    "collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
    "collectionEmpty": "raccolta vuota",
    "batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde"
    "collectionEmpty": "raccolta vuota"
  },
  "useCpuNoise": "Usa la CPU per generare rumore",
  "iterations": "Iterazioni",
@@ -2367,9 +2365,8 @@
    "watchRecentReleaseVideos": "Guarda i video su questa versione",
    "watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
    "items": [
      "Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
      "FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela.",
      "LLaVA OneVision VLLM: supporto beta nei flussi di lavoro."
      "Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
      "FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
    ]
  },
  "system": {
@@ -237,10 +237,7 @@
    "row": "Hàng",
    "board": "Bảng",
    "saveChanges": "Lưu Thay Đổi",
    "error_withCount_other": "{{count}} lỗi",
    "value": "Giá Trị",
    "label": "Nhãn Tên",
    "systemInformation": "Thông Tin Hệ Thống"
    "error_withCount_other": "{{count}} lỗi"
  },
  "prompt": {
    "addPromptTrigger": "Thêm Prompt Trigger",
@@ -2303,10 +2300,7 @@
    "minimum": "Tối Thiểu",
    "maximum": "Tối Đa",
    "containerRowLayout": "Hộp Chứa (bố cục hàng)",
    "containerColumnLayout": "Hộp Chứa (bố cục cột)",
    "resetOptions": "Tải Lại Lựa Chọn",
    "addOption": "Thêm Lựa Chọn",
    "dropdown": "Danh Sách Thả Xuống"
    "containerColumnLayout": "Hộp Chứa (bố cục cột)"
  },
  "yourWorkflows": "Workflow Của Bạn",
  "browseWorkflows": "Khám Phá Workflow",
@@ -2322,8 +2316,7 @@
    "view": "Xem",
    "deselectAll": "Huỷ Chọn Tất Cả",
    "noRecentWorkflows": "Không Có Workflows Gần Đây",
    "recommended": "Có Thể Bạn Sẽ Cần",
    "emptyStringPlaceholder": "<xâu ký tự trống>"
    "recommended": "Có Thể Bạn Sẽ Cần"
  },
  "upscaling": {
    "missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
@@ -2359,9 +2352,8 @@
    "watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
    "watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
    "items": [
      "Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
      "FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas.",
      "LLaVA OneVision VLLM: Hỗ trợ phiên bản Beta trong Workflow."
      "Workflow: Thư Viện Workflow mới và đã được cải tiến.",
      "FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
    ]
  },
  "upsell": {
@@ -28,7 +28,8 @@ export type AppFeature =
  | 'starterModels'
  | 'hfToken'
  | 'retryQueueItem'
  | 'cancelAndClearAll';
  | 'cancelAndClearAll'
  | 'deployWorkflow';
/**
 * A disable-able Stable Diffusion feature
 */
@@ -74,7 +75,6 @@ export type AppConfig = {
  allowPrivateBoards: boolean;
  allowPrivateStylePresets: boolean;
  allowClientSideUpload: boolean;
  allowPublishWorkflows: boolean;
  disabledTabs: TabName[];
  disabledFeatures: AppFeature[];
  disabledSDFeatures: SDFeature[];
@@ -49,11 +49,7 @@ export const useGalleryHotkeys = () => {
  useRegisteredHotkeys({
    id: 'galleryNavLeft',
    category: 'gallery',
    callback: (e) => {
      // Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
      if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
        return;
      }
    callback: () => {
      if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
        goPrev('arrow');
        return;
@@ -75,11 +71,7 @@ export const useGalleryHotkeys = () => {
  useRegisteredHotkeys({
    id: 'galleryNavRight',
    category: 'gallery',
    callback: (e) => {
      // Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
      if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
        return;
      }
    callback: () => {
      if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
        goNext('arrow');
        return;
@@ -3,11 +3,7 @@ import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
import {
  $isInPublishFlow,
  $isSelectingOutputNode,
  useIsValidationRunInProgress,
} from 'features/nodes/components/sidePanel/workflow/publish';
import { $isInPublishFlow, useIsValidationRunInProgress } from 'features/nodes/components/sidePanel/workflow/publish';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
@@ -18,7 +14,6 @@ export const TopLeftPanel = memo(() => {
  const isInPublishFlow = useStore($isInPublishFlow);
  const isPublished = useAppSelector(selectWorkflowIsPublished);
  const isValidationRunInProgress = useIsValidationRunInProgress();
  const isSelectingOutputNode = useStore($isSelectingOutputNode);

  const { t } = useTranslation();
  return (
@@ -39,16 +34,11 @@ export const TopLeftPanel = memo(() => {
          {t('workflows.builder.publishingValidationRunInProgress')}
        </AlertDescription>
      )}
      {isInPublishFlow && !isValidationRunInProgress && !isSelectingOutputNode && (
      {isInPublishFlow && !isValidationRunInProgress && (
        <AlertDescription whiteSpace="pre-wrap">
          {t('workflows.builder.workflowLockedDuringPublishing')}
        </AlertDescription>
      )}
      {isInPublishFlow && !isValidationRunInProgress && isSelectingOutputNode && (
        <AlertDescription whiteSpace="pre-wrap">
          {t('workflows.builder.selectingOutputNodeDesc')}
        </AlertDescription>
      )}
      {isPublished && (
        <AlertDescription whiteSpace="pre-wrap">
          {t('workflows.builder.workflowLockedPublished')}
@@ -67,7 +67,7 @@ type NodeFieldDndData = {
fieldName: string;
fieldTemplate: FieldInputTemplate;
};
const buildNodeFieldDndData = (
export const buildNodeFieldDndData = (
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate

@@ -1,4 +1,3 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import {
Button,
ButtonGroup,
@@ -39,12 +38,12 @@ import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowLineRightBold, PiLightningFill, PiXBold } from 'react-icons/pi';
import { PiLightningFill, PiSignOutBold, PiXBold } from 'react-icons/pi';
import { serializeError } from 'serialize-error';
import { assert } from 'tsafe';

@@ -54,6 +53,7 @@ export const PublishWorkflowPanelContent = memo(() => {
return (
<Flex flexDir="column" gap={2} h="full">
<ButtonGroup isAttached={false} size="sm" variant="ghost">
<SelectOutputNodeButton />
<Spacer />
<CancelPublishButton />
<PublishWorkflowButton />
@@ -68,41 +68,38 @@ export const PublishWorkflowPanelContent = memo(() => {
</Flex>
);
});
PublishWorkflowPanelContent.displayName = 'PublishWorkflowPanelContent';
PublishWorkflowPanelContent.displayName = 'DeployWorkflowPanelContent';

const OutputFields = memo(() => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);

return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Flex alignItems="center">
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
<Spacer />
<SelectOutputNodeButton variant="link" size="sm" />
</Flex>

<Divider />
{!outputNodeId && (
if (!outputNodeId) {
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold" color="error.300">
{t('workflows.builder.noOutputNodeSelected')}
</Text>
)}
{outputNodeId && <OutputFieldsContent outputNodeId={outputNodeId} />}
</Flex>
);
</Flex>
);
}

return <OutputFieldsContent outputNodeId={outputNodeId} />;
});
OutputFields.displayName = 'OutputFields';

const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
const { t } = useTranslation();
const outputFieldNames = useOutputFieldNames(outputNodeId);

return (
<>
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
<Divider />
{outputFieldNames.map((fieldName) => (
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
))}
</>
</Flex>
);
});
OutputFieldsContent.displayName = 'OutputFieldsContent';
@@ -155,7 +152,7 @@ const UnpublishableInputFields = memo(() => {
});
UnpublishableInputFields.displayName = 'UnpublishableInputFields';

const SelectOutputNodeButton = memo((props: ButtonProps) => {
const SelectOutputNodeButton = memo(() => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
@@ -164,18 +161,8 @@ const SelectOutputNodeButton = memo((props: ButtonProps) => {
$isSelectingOutputNode.set(true);
}, []);
return (
<Button
leftIcon={<PiArrowLineRightBold />}
isDisabled={isSelectingOutputNode}
tooltip={isSelectingOutputNode ? t('workflows.builder.selectingOutputNodeDesc') : undefined}
onClick={onClick}
{...props}
>
{isSelectingOutputNode
? t('workflows.builder.selectingOutputNode')
: outputNodeId
? t('workflows.builder.changeOutputNode')
: t('workflows.builder.selectOutputNode')}
<Button leftIcon={<PiSignOutBold />} isDisabled={isSelectingOutputNode} onClick={onClick}>
{outputNodeId ? t('workflows.builder.changeOutputNode') : t('workflows.builder.selectOutputNode')}
</Button>
);
});
@@ -205,7 +192,6 @@ const PublishWorkflowButton = memo(() => {
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const inputs = usePublishInputs();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);

const projectUrl = useStore($projectUrl);

@@ -254,11 +240,9 @@ const PublishWorkflowButton = memo(() => {
<Button
leftIcon={<PiLightningFill />}
isDisabled={
!allowPublishWorkflows ||
!isReadyToEnqueue ||
!isWorkflowSaved ||
hasBatchOrGeneratorNodes ||
!isReadyToDoValidationRun ||
!isReadyToEnqueue ||
hasBatchOrGeneratorNodes ||
!(outputNodeId !== null && !isSelectingOutputNode)
}
onClick={onClick}
@@ -323,7 +307,7 @@ NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';

export const StartPublishFlowButton = memo(() => {
const { t } = useTranslation();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
@@ -347,7 +331,7 @@ export const StartPublishFlowButton = memo(() => {
leftIcon={<PiLightningFill />}
variant="ghost"
size="sm"
isDisabled={!allowPublishWorkflows || !isReadyToEnqueue || !isWorkflowSaved || hasBatchOrGeneratorNodes}
isDisabled={!deployWorkflowIsEnabled || !isWorkflowSaved || hasBatchOrGeneratorNodes}
>
{t('workflows.builder.publish')}
</Button>

@@ -10,7 +10,7 @@ export const LockedWorkflowIcon = memo(() => {
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
<IconButton
size="sm"
cursor="not-allowed"
cursor='not-allowed'
variant="link"
alignSelf="stretch"
aria-label={t('workflows.builder.publishedWorkflowsLocked')}

@@ -26,7 +26,7 @@ import {
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@@ -40,7 +40,7 @@ export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const deployWorkflow = useFeatureStatus('deployWorkflow');

return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
@@ -60,8 +60,8 @@ export const WorkflowLibrarySideNav = () => {
</Flex>
</Collapse>
)}
{allowPublishWorkflows && (
<WorkflowLibraryViewButton view="published">{t('workflows.published')}</WorkflowLibraryViewButton>
{deployWorkflow && (
<WorkflowLibraryViewButton view="published">{t('workflows.publishedWorkflows')}</WorkflowLibraryViewButton>
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">

@@ -1,8 +1,7 @@
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

@@ -11,7 +10,7 @@ import WorkflowJSONTab from './WorkflowJSONTab';

const WorkflowFieldsLinearViewPanel = () => {
const { t } = useTranslation();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
return (
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
<TabList>
@@ -19,7 +18,7 @@ const WorkflowFieldsLinearViewPanel = () => {
<Tab>{t('common.details')}</Tab>
<Tab>JSON</Tab>
<Spacer />
{allowPublishWorkflows && <StartPublishFlowButton />}
{deployWorkflowIsEnabled && <StartPublishFlowButton />}
</TabList>

<TabPanels h="full" pt={2}>

@@ -0,0 +1,9 @@
import { useMemo } from 'react';

import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';

export const useInputFieldTemplateTitleSafe = (nodeId: string, fieldName: string): string => {
const template = useNodeTemplateOrThrow(nodeId);
const title = useMemo(() => template.inputs[fieldName]?.title ?? '', [fieldName, template.inputs]);
return title;
};

@@ -0,0 +1,22 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';

/**
 * Gets the user-defined description of an input field for a given node.
 *
 * If the node doesn't exist or is not an invocation node, an error is thrown.
 *
 * @param nodeId The ID of the node
 * @param fieldName The name of the field
 */
export const useInputFieldUserDescriptionOrThrow = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).description),
[fieldName, nodeId]
);

const description = useAppSelector(selector);
return description;
};

@@ -470,8 +470,31 @@ export const nodesSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges } = action.payload;

state.nodes = nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node }));
state.edges = edges;
const changes: NodeChange<AnyNode>[] = [];
for (const node of nodes) {
if (node.type === 'notes') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
} else if (node.type === 'invocation') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
}
}
state.nodes = applyNodeChanges<AnyNode>(changes, []);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ type: 'add', item: edge })),
[]
);
});
},
});

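The reducer above stops assigning state.nodes and state.edges directly and instead routes loaded nodes and edges through reactflow's change pipeline. A standalone sketch of the same pattern, using reactflow's real applyNodeChanges/applyEdgeChanges helpers with simplified node and edge types:

import { applyEdgeChanges, applyNodeChanges } from 'reactflow';
import type { Edge, EdgeChange, Node, NodeChange } from 'reactflow';

// Loading via 'add' changes lets reactflow run its own bookkeeping for each
// node and edge instead of having the arrays overwritten wholesale.
const loadGraph = (nodes: Node[], edges: Edge[]) => {
  const nodeChanges: NodeChange[] = nodes.map((node) => ({ type: 'add', item: node }));
  const edgeChanges: EdgeChange[] = edges.map((edge) => ({ type: 'add', item: edge }));
  return {
    nodes: applyNodeChanges(nodeChanges, []),
    edges: applyEdgeChanges(edgeChanges, []),
  };
};
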
@@ -79,4 +79,4 @@ export const isInvocationOutputSchemaObject = (

export const isInvocationFieldSchema = (
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
): obj is InvocationFieldSchema => 'field_kind' in obj;
): obj is InvocationFieldSchema => !('$ref' in obj);

@@ -148,11 +148,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
}
}

// Stash invalid edges here to be deleted later
const edgesToDelete = new Set<string>();

for (const edge of edges) {
edges.forEach((edge, i) => {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
const sourceNode = nodes.find(({ id }) => id === edge.source);
const targetNode = nodes.find(({ id }) => id === edge.target);
@@ -219,7 +215,8 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}

if (issues.length) {
edgesToDelete.add(edge.id);
// This edge has some issues. Remove it.
delete edges[i];
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
warnings.push({
@@ -228,10 +225,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
data: edge,
});
}
}

// Remove invalid edges
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
});

// Migrated exposed fields to form elements if they exist and the form does not
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!

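The two sides differ in how invalid edges are dropped: one stashes offending ids in a Set and filters once after the loop, the other deletes elements in place while iterating. A simplified sketch of the Set-based variant, with the edge shape and validator reduced to the essentials:

type MinimalEdge = { id: string; source: string; target: string };

// Collect ids first, filter once at the end: this avoids mutating the array
// while it is still being iterated.
const pruneInvalidEdges = (edges: MinimalEdge[], isValid: (edge: MinimalEdge) => boolean): MinimalEdge[] => {
  const edgesToDelete = new Set<string>();
  for (const edge of edges) {
    if (!isValid(edge)) {
      edgesToDelete.add(edge.id);
    }
  }
  return edges.filter(({ id }) => !edgesToDelete.has(id));
};
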
@@ -1,4 +1,3 @@
import { createAction } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import {
$outputNodeId,
@@ -17,13 +16,10 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';

const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');

export const useEnqueueWorkflows = () => {
const { getState, dispatch } = useAppStore();
const enqueue = useCallback(
async (prepend: boolean, isApiValidationRun: boolean) => {
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const workflow = state.workflow;
@@ -134,13 +130,9 @@ export const useEnqueueWorkflows = () => {
} as const;
});

assert(workflow.id, 'Workflow without ID cannot be used for API validation run');

batchConfig.validation_run_data = {
workflow_id: workflow.id,
input_fields: api_input_fields,
output_fields: api_output_fields,
};
batchConfig.is_api_validation_run = true;
batchConfig.api_input_fields = api_input_fields;
batchConfig.api_output_fields = api_output_fields;

// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;

@@ -29,6 +29,7 @@ import type { NodesState, Templates } from 'features/nodes/store/types';
import { getInvocationNodeErrors } from 'features/nodes/store/util/fieldValidators';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
@@ -148,6 +149,7 @@ export const useReadinessWatcher = () => {
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const isInPublishFlow = useStore($isInPublishFlow);
const isPublished = useAppSelector(selectWorkflowIsPublished);

useEffect(() => {
debouncedUpdateReasons(
@@ -187,6 +189,7 @@ export const useReadinessWatcher = () => {
upscale,
workflowSettings,
isInPublishFlow,
isPublished,
]);
};

@@ -21,7 +21,6 @@ const initialConfigState: AppConfig = {
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
@@ -221,5 +220,4 @@ export const selectMetadataFetchDebounce = createConfigSelector((config) => conf

export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectAllowPublishWorkflows = createConfigSelector((config) => config.allowPublishWorkflows);
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);

@@ -14,6 +14,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

type LoadWorkflowOptions = {
asCopy?: boolean;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
@@ -64,11 +65,12 @@ const useLoadImmediate = () => {
if (!dialogState) {
return;
}
const { type, data, onSuccess, onError, onCompleted } = dialogState;
const { type, data, onSuccess, onError, onCompleted, asCopy } = dialogState;
const options = {
onSuccess,
onError,
onCompleted,
asCopy,
};
if (type === 'object') {
await loadWorkflowFromObject(data, options);

@@ -29,7 +29,7 @@ export const useLoadWorkflowFromFile = () => {
const { onSuccess, onError, onCompleted } = options;
try {
const unvalidatedWorkflow = JSON.parse(rawJSON as string);
const validatedWorkflow = await validatedAndLoadWorkflow(unvalidatedWorkflow, 'file');
const validatedWorkflow = await validatedAndLoadWorkflow(unvalidatedWorkflow);

if (!validatedWorkflow) {
reader.abort();

@@ -41,7 +41,7 @@ export const useLoadWorkflowFromImage = () => {

assert(unvalidatedWorkflow !== null, 'No workflow or graph provided');

const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow, 'image');
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);

if (!validatedWorkflow) {
onError?.();

@@ -24,13 +24,14 @@ export const useLoadWorkflowFromLibrary = () => {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
asCopy?: boolean;
} = {}
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const res = await getWorkflow(workflowId).unwrap();

const validatedWorkflow = await validateAndLoadWorkflow(res.workflow, 'library');
const validatedWorkflow = await validateAndLoadWorkflow(res.workflow);

if (!validatedWorkflow) {
onError?.();

@@ -21,7 +21,7 @@ export const useLoadWorkflowFromObject = () => {
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow, 'object');
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);

if (!validatedWorkflow) {
onError?.();

@@ -43,10 +43,7 @@ export const useValidateAndLoadWorkflow = () => {
 *
 * This function catches all errors. It toasts and logs on success and error.
 */
async (
unvalidatedWorkflow: unknown,
origin: 'file' | 'image' | 'object' | 'library'
): Promise<WorkflowV3 | null> => {
async (unvalidatedWorkflow: unknown): Promise<WorkflowV3 | null> => {
try {
const templates = $templates.get();
const { workflow, warnings } = await validateWorkflow({
@@ -57,11 +54,8 @@ export const useValidateAndLoadWorkflow = () => {
checkModelAccess,
});

if (origin !== 'library') {
// Workflow IDs should always map directly to the workflow in the library. If the workflow is loaded from
// some other source, and has an ID, we should remove it to ensure the app does not treat it as a library workflow.
// For example, when saving a workflow, we might accidentally attempt to save instead of save-as.
delete workflow.id;
if (workflow.is_published) {
//TODO: How to handle this?
}

$nodeExecutionStates.set({});

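The removed branch encodes a useful invariant: a workflow's id marks it as a library workflow, so the id is stripped whenever the workflow is loaded from any other source. A reduced sketch of that rule (the workflow type here is trimmed to the relevant fields):

type MinimalWorkflow = { id?: string; is_published?: boolean };
type LoadOrigin = 'file' | 'image' | 'object' | 'library';

// Keeping a stale id could turn a "save as" into an accidental library overwrite.
const normalizeLoadedWorkflow = (workflow: MinimalWorkflow, origin: LoadOrigin): MinimalWorkflow => {
  if (origin !== 'library') {
    delete workflow.id;
  }
  return workflow;
};
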
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
__version__ = "5.10.0dev3"
__version__ = "5.9.1"

14 pins.json
@@ -1,14 +0,0 @@
|
||||
{
|
||||
"python": "3.12",
|
||||
"torchIndexUrl": {
|
||||
"win32": {
|
||||
"cuda": "https://download.pytorch.org/whl/cu126"
|
||||
},
|
||||
"linux": {
|
||||
"cpu": "https://download.pytorch.org/whl/cpu",
|
||||
"rocm": "https://download.pytorch.org/whl/rocm6.2.4",
|
||||
"cuda": "https://download.pytorch.org/whl/cu126"
|
||||
},
|
||||
"darwin": {}
|
||||
}
|
||||
}
|
||||
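The deleted pins.json mapped platform and GPU driver to the torch wheel index used at install time. A TypeScript sketch of how an installer might consume such a file; the resolveTorchIndexUrl helper is hypothetical, while the data mirrors the file above:

type Platform = 'win32' | 'linux' | 'darwin';

const torchIndexUrl: Record<Platform, Record<string, string>> = {
  win32: { cuda: 'https://download.pytorch.org/whl/cu126' },
  linux: {
    cpu: 'https://download.pytorch.org/whl/cpu',
    rocm: 'https://download.pytorch.org/whl/rocm6.2.4',
    cuda: 'https://download.pytorch.org/whl/cu126',
  },
  darwin: {}, // macOS installs from the default index
};

// Returns undefined when the default PyPI index should be used.
const resolveTorchIndexUrl = (platform: Platform, device: string): string | undefined =>
  torchIndexUrl[platform][device];
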
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.10, <3.13"
requires-python = ">=3.10, <3.12"
readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"]
dynamic = ["version"]
@@ -33,46 +33,69 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate",
"bitsandbytes; sys_platform!='darwin'",
"accelerate==1.0.1",
"bitsandbytes==0.45.0; sys_platform!='darwin'",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"diffusers[torch]",
"gguf",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"controlnet-aux==0.0.7",
"diffusers[torch]==0.31.0",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"numpy<2.0.0",
"onnx==1.16.1",
"onnxruntime==1.19.2",
"opencv-python==4.9.0.80",
"safetensors",
"spandrel",
"torch~=2.6.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchsde", # diffusers needs this for SDE solvers, but it is not an explicit dep of diffusers
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchmetrics",
"torchsde",
"torchvision",
"transformers",
"transformers==4.46.3",

# Core application dependencies, pinned for reproducible builds.
"fastapi-events",
"fastapi",
"huggingface-hub",
"pydantic-settings",
"pydantic",
"python-socketio",
"uvicorn[standard]",
"fastapi-events==0.11.1",
"fastapi==0.111.0",
"huggingface-hub==0.26.1",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",
"uvicorn[standard]==0.28.0",

# Auxiliary dependencies, pinned only if necessary.
"albumentations",
"blake3",
"click",
"datasets",
"Deprecated",
"dnspython",
"dynamicprompts",
"einops",
"facexlib",
# Exclude 3.9.1 which has a problem on windows, see https://github.com/matplotlib/matplotlib/issues/28551
"matplotlib!=3.9.1",
"npyscreen",
"omegaconf",
"picklescan",
"pillow",
"prompt-toolkit",
"pympler",
"pypatchmatch",
"pyperclip",
"pyreadline3",
"python-multipart",
"requests",
"rich~=13.3",
"scikit-image",
"semver~=3.0.1",
"test-tube",
"windows-curses; sys_platform=='win32'",
"humanize==4.12.1",
]

[project.optional-dependencies]
@@ -104,8 +127,7 @@ dependencies = [
"pytest-datadir",
"requests_testadapter",
"httpx",
"polyfactory==2.19.0",
"humanize==4.12.1",
"polyfactory==2.19.0"
]

[project.scripts]
@@ -185,9 +207,9 @@ exclude = [
".venv*",
"*.ipynb",
"invokeai/backend/image_util/mediapipe_face/", # External code
"invokeai/backend/image_util/mlsd/", # External code
"invokeai/backend/image_util/normal_bae/", # External code
"invokeai/backend/image_util/pidi/", # External code
"invokeai/backend/image_util/mlsd/", # External code
"invokeai/backend/image_util/normal_bae/", # External code
"invokeai/backend/image_util/pidi/", # External code
]

[tool.ruff.lint]

@@ -21,18 +21,16 @@ def count_files(path: Path):

@pytest.fixture
def obj_serializer(tmp_path: Path):
return ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
return ObjectSerializerDisk[MockDataclass](tmp_path)


@pytest.fixture
def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass]), max_cache_size=2
)
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)


def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._output_dir == tmp_path


@@ -72,7 +70,7 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa


def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
assert obj_serializer._base_output_dir == tmp_path
assert obj_serializer._output_dir != tmp_path
@@ -80,21 +78,21 @@ def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):


def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
del obj_serializer
assert not tempdir_path.exists()


def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
assert not tempdir_path.exists()


def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
@@ -104,19 +102,19 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
assert not tempdir.exists()


def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=False)
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
assert tempdir.exists()


def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer_1.save(obj_1)
obj_1_loaded = obj_serializer_1.load(obj_1_name)
@@ -125,19 +123,19 @@ def test_obj_serializer_disk_different_types(tmp_path: Path):
assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_")

obj_serializer_2 = ObjectSerializerDisk[int](tmp_path, safe_globals=[int])
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
obj_2_name = obj_serializer_2.save(9001)
assert obj_serializer_2._obj_class_name == "int"
assert obj_serializer_2.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_")

obj_serializer_3 = ObjectSerializerDisk[str](tmp_path, safe_globals=[str])
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
obj_3_name = obj_serializer_3.save("foo")
assert obj_serializer_3._obj_class_name == "str"
assert obj_serializer_3.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_")

obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path, safe_globals=[torch.Tensor])
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer_4.load(obj_4_name)
assert obj_serializer_4._obj_class_name == "Tensor"