Compare commits


83 Commits

Author SHA1 Message Date
psychedelicious
3e200a2ba2 chore: bump version to v5.10.0dev3 2025-04-04 16:48:12 +10:00
psychedelicious
4610b55a5d chore: update uv.lock 2025-04-04 16:46:15 +10:00
psychedelicious
b3b3dbd92d build: remove pin on spandrel dependency 2025-04-04 16:41:10 +10:00
psychedelicious
6c36b0508b build: add comment about torchsde to pyproject 2025-04-04 16:40:50 +10:00
psychedelicious
2756c539e0 build: remove pin on gguf dependency
This allows gguf to pull in sentencepiece on its own. Version 0.10.0 didn't list sentencepiece as a dependency, but recent releases do, so we can remove sentencepiece as an explicit dep.
2025-04-04 16:40:36 +10:00
psychedelicious
a34383d460 build: remove unused clip_anytorch dependency 2025-04-04 16:39:20 +10:00
psychedelicious
77f22497d2 build: remove unused pytorch-lightning dependency 2025-04-04 16:39:20 +10:00
psychedelicious
5967d4e1da build: remove unused pyreadline3 dependency 2025-04-04 16:39:20 +10:00
psychedelicious
1253ad5053 build: remove unused pyperclip dependency 2025-04-04 16:39:20 +10:00
psychedelicious
5aa08ab09b build: remove unused pympler dependency 2025-04-04 16:39:19 +10:00
psychedelicious
6ce527768b build: remove unused scikit-image dependency 2025-04-04 16:39:19 +10:00
psychedelicious
fe88012236 build: remove unused npyscreen dependency 2025-04-04 16:39:19 +10:00
psychedelicious
8609b98217 build: remove unused torchmetrics dependency 2025-04-04 16:13:45 +10:00
psychedelicious
19f0bf828c build: remove unused datasets dependency 2025-04-04 16:12:13 +10:00
psychedelicious
26cbeccfdf build: remove unused click dependency 2025-04-04 16:11:38 +10:00
psychedelicious
b5be81b97b build: remove unused omegaconf dependency 2025-04-04 16:09:53 +10:00
psychedelicious
f14d07968b build: remove unused facexlib dependency 2025-04-04 16:09:36 +10:00
psychedelicious
525a89900a build: remove unused timm dependency 2025-04-04 16:08:31 +10:00
psychedelicious
d8df31a8ac chore(ui): typegen 2025-04-04 16:03:29 +10:00
psychedelicious
380a41be34 chore: update uv.lock 2025-04-04 16:03:29 +10:00
psychedelicious
e990afbccb build: remove unused matplotlib dep 2025-04-04 16:03:29 +10:00
psychedelicious
c591478d24 tidy(nodes): remove matplotlib dependency
It was only used for a single color conversion function. Replaced with cv2 code, tested functionality to confirm it works the same.
2025-04-04 16:03:29 +10:00
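For context, a minimal sketch of the kind of swap this commit describes, assuming the conversion in question was RGB↔HSV (the actual function isn't shown here and may differ):

```python
# Hypothetical sketch: replace a matplotlib.colors conversion with OpenCV.
# The exact conversion used by the node is an assumption (RGB -> HSV shown).
# Note: for uint8 input, cv2 scales hue to 0-179, unlike matplotlib's 0-1 floats,
# so a real swap needs a scaling step if downstream code expects the old range.
import cv2
import numpy as np

def rgb_to_hsv(rgb: np.ndarray) -> np.ndarray:
    """rgb: HxWx3 uint8 array. Returns HSV via cv2 instead of matplotlib.colors."""
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
```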
psychedelicious
30def6a9bd build: move humanize to test deps 2025-04-04 16:03:29 +10:00
psychedelicious
6cf88a601d build: remove unused albumentations dependency
This is not used
2025-04-04 16:03:29 +10:00
psychedelicious
5e14545c32 tidy: delete unused file 2025-04-04 16:03:29 +10:00
psychedelicious
eefbcd2485 build: remove controlnet_aux dependency, remove pin for timm 2025-04-04 16:03:29 +10:00
psychedelicious
13cc44a22c tidy(nodes): rename controlnet_image_processors.py -> controlnet.py 2025-04-04 16:03:29 +10:00
psychedelicious
2cca339a5c tidy(nodes): remove unused old dw openpose detector class 2025-04-04 16:03:29 +10:00
psychedelicious
0a7cf6c0ec tidy(nodes): remove deprecated controlnet "processor" nodes 2025-04-04 16:03:29 +10:00
psychedelicious
06abc1d40a build: upgrade python to 3.12 in pins 2025-04-04 16:03:29 +10:00
psychedelicious
2cde86b7b8 build: update uv.lock 2025-04-04 16:03:28 +10:00
psychedelicious
0a49463c79 fix(backend): remove mps_fixes
The fixes in this module monkeypatched `torch` to resolve some issues with FP16 on macOS. These issues have long since been resolved.

Included in the now-removed fixes is `CustomSlicedAttentionProcessor`, which was intended to reduce memory requirements on MPS by overriding `diffusers`' own `SlicedAttentionProcessor`.

Unfortunately, `attention_type: sliced` produces hot garbage with the fixes and black images without them, so the class now appears to be moot anyway.

Regardless, SDPA is supported on MPS and very efficient, so sliced attention is largely obsolete.
2025-04-04 16:03:28 +10:00
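To illustrate the last point, SDPA is available directly through PyTorch; a minimal sketch, assuming a macOS machine where the MPS backend is available:

```python
# Minimal sketch: run PyTorch's built-in SDPA on the MPS device.
# Assumes macOS with an MPS-capable GPU; shapes and dtype are arbitrary.
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    q = k = v = torch.randn(1, 8, 64, 64, device="mps", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, k, v)  # no attention slicing needed
    print(out.shape)
```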
psychedelicious
f3402b6ce7 chore: bump version to v5.10.0dev2
Doing a dev build so I can test the launcher.
2025-04-04 16:03:28 +10:00
psychedelicious
5d3fb822c5 build: downgrade python to 3.11 in pins 2025-04-04 16:03:28 +10:00
psychedelicious
9e70d8eb6e build: restore prev setuptools config to fix wheel build 2025-04-04 16:03:28 +10:00
psychedelicious
402758d502 ci: use py3.12 to build installer 2025-04-04 16:03:28 +10:00
psychedelicious
b97cc51f23 experiment: add pins.json to repo
The launcher will query this file to get the pins needed for installation
2025-04-04 16:03:28 +10:00
psychedelicious
f6f33b5999 chore: bump version to v5.10.0dev1
Doing a dev build so I can test the launcher.
2025-04-04 16:03:28 +10:00
psychedelicious
cd873f1fe5 chore: update uv.lock for latest pydantic
Ran `uv lock --upgrade-package pydantic`
2025-04-04 16:03:28 +10:00
psychedelicious
5f3d398074 fix(ui): handle updated schema structure during invocation parsing
In https://github.com/pydantic/pydantic/pull/10029, pydantic made an improvement to its generated JSON schemas (OpenAPI schemas). The previous and new generated schemas both meet the schema spec.

When we parse the OpenAPI schema to generate node templates, we use a typeguard to narrow schema components from generic OpenAPI schema objects to node field schema objects. The narrower node field schema objects contain extra data.

For example, they contain a `field_kind` attribute that indicates whether the field is an input field or an output field. These extra attributes are not part of the OpenAPI spec (though the spec does allow for extra data).

This typeguard relied on a pydantic implementation detail, which was changed in the linked pydantic PR and released with v2.9.0. With the change, our typeguard rejects input field schema objects, causing parsing to fail with errors/warnings like `Unhandled input property` in the JS console.

In the UI, this causes many fields - mostly model fields - to not show up in the workflow editor.

The fix for this is very simple - instead of relying on an implementation detail for the typeguard, we can check if the incoming schema object has any of our invoke-specific extra attributes. Specifically, we now look for the presence of the `field_kind` attribute on the incoming schema object. If it is present, we know we are dealing with an invocation input field and can parse it appropriately.
2025-04-04 16:03:28 +10:00
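A conceptual sketch of the new check (the real fix lives in the TypeScript template-parsing code; the schema shape here is an assumption):

```python
# Conceptual sketch only: the actual typeguard is TypeScript in the UI.
# A parsed OpenAPI schema component is treated as an invocation field when it
# carries the invoke-specific `field_kind` extra attribute, instead of relying
# on a pydantic implementation detail of the generated schema.
from typing import Any

def is_invocation_field_schema(schema_obj: Any) -> bool:
    return isinstance(schema_obj, dict) and "field_kind" in schema_obj
```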
psychedelicious
e6b366ff61 chore: typegen 2025-04-04 16:03:28 +10:00
psychedelicious
bcd50ed688 chore: remove pydantic pin 2025-04-04 16:03:27 +10:00
psychedelicious
a5966c3197 chore(ui): typegen 2025-04-04 16:03:27 +10:00
psychedelicious
f28b054872 tests: update tests/test_object_serializer_disk.py 2025-04-04 16:03:27 +10:00
psychedelicious
31681f4ad7 fix(app): add trusted classes to torch safe globals to prevent errors when loading them
In `ObjectSerializerDisk`, we use `torch.load` to load serialized objects from disk. With torch 2.6.0, torch defaults to `weights_only=True`. As a result, torch will raise when attempting to deserialize anything with an unrecognized class.

For example, our `ConditioningFieldData` class is untrusted. When we load conditioning from disk, we will get a runtime error.

Torch provides a method to add trusted classes to an allowlist. This change adds an arg to `ObjectSerializerDisk` to add a list of safe globals to the allowlist and uses it for both `ObjectSerializerDisk` instances.

Note: My first attempt inferred the class from the generic type arg that `ObjectSerializerDisk` accepts, and added that to the allowlist. Unfortunately, this doesn't work.

For example, `ConditioningFieldData` has a `conditionings` attribute that may be one of several other untrusted classes representing model-specific conditioning data. So, even if we allowlist `ConditioningFieldData`, loading will fail when torch deserializes the `conditionings` attribute.
2025-04-04 16:03:27 +10:00
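A minimal sketch of the allowlisting approach using torch's safe-globals API; the class and file names below are placeholders, not the actual Invoke types or paths:

```python
# Sketch of allowlisting a trusted class for torch.load with weights_only=True.
# `ConditioningStandIn` and "conditioning.pt" are placeholders for illustration.
import torch

class ConditioningStandIn:
    pass

torch.serialization.add_safe_globals([ConditioningStandIn])

# With the class allowlisted, torch 2.6's default weights_only=True load no
# longer raises when it encounters ConditioningStandIn in the pickled data.
obj = torch.load("conditioning.pt", weights_only=True)
```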
Eugene Brodsky
aaf042de48 resolve conflict between timm version needed by LLaVA and controlnet-aux 2025-04-04 16:03:27 +10:00
Eugene Brodsky
c28e685409 reintroduce GPU_DRIVER build arg in CI container build, as it has apparently been removed 2025-04-04 16:03:27 +10:00
Eugene Brodsky
d6ac822a1f remove obsolete dependencies that were used by the CLI 2025-04-04 16:03:27 +10:00
Eugene Brodsky
f0a4d7ac7f modify docs for python 3.12 2025-04-04 16:03:27 +10:00
Eugene Brodsky
04b0e658df update nodes schema / typegen 2025-04-04 16:03:27 +10:00
Eugene Brodsky
68845f4d85 update uv.lock 2025-04-04 16:03:27 +10:00
Eugene Brodsky
6df5614b54 refactor Dockerfile; get rid of multi-stage build; upgrade to python 3.12 2025-04-04 16:03:27 +10:00
Eugene Brodsky
0bd6f0245b use uv.lock to pin dependencies 2025-04-04 16:03:26 +10:00
Eugene Brodsky
6c9165046e upgrade pytorch and unpin some of the strict dependencies to facilitate upgrading co-dependencies.
We will use uv.lock to ensure reproducibility.
2025-04-04 16:03:26 +10:00
Chantell
2b5da91beb Update manual.md
Removed a duplicated word in the package specifier instruction in step 6.
2025-04-04 16:52:04 +11:00
psychedelicious
74bede14be feat(ui): put all validation run data into a single object 2025-04-04 11:38:04 +11:00
psychedelicious
04ea3c491a chore(ui): typegen 2025-04-04 11:38:04 +11:00
psychedelicious
38e7b23d18 feat(api): put all validation run data into a single object 2025-04-04 11:38:04 +11:00
psychedelicious
c052846e05 feat(ui): ensure workflow id is passed when doing validation run 2025-04-04 11:38:04 +11:00
psychedelicious
af3a31dfec chore(ui): typegen 2025-04-04 11:38:04 +11:00
psychedelicious
571710fab6 feat(app): add optional published_workflow_id to enqueue payloads and queue item 2025-04-04 11:38:04 +11:00
psychedelicious
a175a5c252 feat(ui): add safeguard against accidentally loading non-library workflow as library workflow 2025-04-04 11:38:04 +11:00
psychedelicious
8b3c36c6fa refactor(ui): better UX for choosing output nodes 2025-04-04 11:38:04 +11:00
psychedelicious
b9ffacd4bf fix(ui): disable publish button when not ready to enqueue (i.e. invalid graph) 2025-04-04 11:38:04 +11:00
psychedelicious
ae45fc8a74 gh: update codeowners
- Add @psychedelicious as codeowner for docs
- Remove inactive contributors
2025-04-03 18:34:39 -04:00
psychedelicious
85db9c65e5 fix(ui): add missing tkey 2025-04-03 12:42:28 +11:00
psychedelicious
ddddaef7ca refactor(ui): use dedicated allowPublishWorkflows instead of disabledFeatures 2025-04-03 12:42:28 +11:00
psychedelicious
e4678201cb feat(ui): add conditionally-enabled workflow publishing ui
This is a squash of a lot of scattered commits that became very difficult to clean up and make individually. Sorry.

Besides the new UI, there are a number of notable changes:
- Publishing logic is disabled in OSS by default. To enable it, provide a `disabledFeatures` prop _without_ "publishWorkflow".
- Enqueuing a workflow is no longer handled in a redux listener; it was hard to track the state of the enqueue logic in the listener, so it now lives in a hook. I did not migrate the canvas and upscaling tabs - their enqueue logic is still in the listener.
- When queueing a validation run, the new `useEnqueueWorkflows()` hook will update the payload with the required data for the run.
- Some logic is added to the socket event listeners to handle workflow publish runs completing.
- The workflow library side nav has a new "published" view. It is hidden when the "publishWorkflow" feature is disabled.
- I've added `Safe` and `OrThrow` versions of some workflows hooks. These hooks typically retrieve some data from redux. For example, a node. The `Safe` hooks return the node or null if it cannot be found, while the `OrThrow` hooks return the node or raise if it cannot be found. The `OrThrow` hooks should be used within one of the gate components. These components use the `Safe` hooks and render a fallback if e.g. the node isn't found. This change is required for some of the publish flow UI.
- Add support for locking the workflow editor. When locked, you can pan and zoom but that's it. Currently, it is only locked during publish flow and if a published workflow is opened.
2025-04-03 12:42:28 +11:00
psychedelicious
d66fdfde71 chore(ui): typegen 2025-04-03 12:42:28 +11:00
psychedelicious
08ee08557b feat(app): add noop api validation run stuff to routes and methods 2025-04-03 12:42:28 +11:00
psychedelicious
496f1262c6 feat(app): truncate warnings for invalid model config in db
This message is logged _every_ time we retrieve a list of models if there is an invalid model. Previously it logged the _whole_ row, which can be a lot of data. Truncate the row to 64 characters to reduce log pollution.
2025-04-03 12:42:28 +11:00
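A small sketch of the truncation described here; the message format and names are assumptions, not the actual service code:

```python
# Hypothetical sketch: truncate the offending row before logging it.
MAX_ROW_CHARS = 64

def format_invalid_config_warning(row: str) -> str:
    truncated = row if len(row) <= MAX_ROW_CHARS else row[:MAX_ROW_CHARS] + "..."
    return f"Ignoring invalid model config in db: {truncated}"
```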
psychedelicious
188d52e4a5 chore(ui): bump tsafe to latest 2025-04-03 12:42:28 +11:00
Riku
db03c196a1 translationBot(ui): update translation (German)
Currently translated at 66.8% (1230 of 1840 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
Riccardo Giovanetti
6bc36b697d translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1818 of 1840 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1816 of 1840 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1816 of 1839 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
Linos
b7d71d3028 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1840 of 1840 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1838 of 1838 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
psychedelicious
fa1ebd9d2f fix(ui): do not switch between images when focused on a tab element
Arrow keys should only navigate between tabs, not gallery images.
2025-04-03 07:40:10 +11:00
psychedelicious
eed5d02069 fix(ui): handling for invalid edges when loading workflows
Previously, reactflow appears to have handled an edge case when using its `applyChanges` utility. If a change was provided without an item, it would skip that change. For example, an "add edge" change that somehow passed `null` as the edge, instead of a valid edge.

In our workflow loading and validation logic, invalid edges were removed from the array using `delete edges[i]`. This left "holes" in the array of edges. We then asked `reactflow` to add these edges to state. When it encountered one of the "holes", it skipped over it.

In a recent release (unsure which, somewhere between the latest v11 and ~v12.4) this seems to have changed. It no longer skips over the "holes" and instead trusts the data. This can cause a couple of issues:
- An error when loading the workflow if `reactflow` attempts to do anything with the nonexistent edge.
- If somehow the workflow makes it into state with "holes" in the array of edges, all sorts of other stuff breaks when our code does anything with the nonexistent edge.

Two-part fix:
- Update the invalid edge handling to not use `delete edges[i]`. Instead, as we check each edge, we add invalid ones to a set. Then, after all the checks are finished, filter out the invalid edges. The resultant edges array has no holes.
- Simplify the logic around setting nodes and edges in redux. Previously we were using `reactflow`'s `applyChanges` utils, but this does literally nothing except take extra CPU cycles. We can simply set the loaded nodes and edges directly in redux. Perhaps we were using `applyChanges` because it addressed the "holes" issue? Not sure. But we don't need it now.

Closes #7868
2025-04-03 07:37:49 +11:00
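A language-agnostic sketch of the first part of the fix (the real code is TypeScript in the workflow loading logic; the edge fields used here are assumptions): mark invalid edges while validating, then filter once at the end so the resulting array has no holes.

```python
# Sketch of the "collect then filter" pattern; `source`/`target` keys are assumed.
def validate_edges(edges: list[dict], node_ids: set[str]) -> list[dict]:
    invalid: set[int] = set()
    for i, edge in enumerate(edges):
        if edge.get("source") not in node_ids or edge.get("target") not in node_ids:
            invalid.add(i)  # mark the edge, don't delete it in place
    # Filtering after all checks leaves no holes in the returned list.
    return [edge for i, edge in enumerate(edges) if i not in invalid]
```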
psychedelicious
3650d91045 chore(ui): bump @xyflow/react to latest 2025-04-03 07:37:49 +11:00
Eugene Brodsky
6c7d08cacb Change timm and controlnet-aux pins to fix LLaVA model support (#7846)
## Summary

`timm` below 1.0.0 prevents LLaVA models from working (broken in
transformers), but `controlnet-aux` pins `timm` to an earlier version
because otherwise it was breaking the ZoeDepth controlnet.

We don't use ZoeDepth (replaced by depthAnything), and downgrading
controlnet-aux seems to be acceptable.

more context here:

https://github.com/huggingface/controlnet_aux/issues/106
https://github.com/huggingface/controlnet_aux/pull/101


Note that this results in some warnings on startup, stemming from
controlnet-aux:

![image](https://github.com/user-attachments/assets/fa908837-6154-42a2-a93b-eb5e363f5783)

We can probably silence the warnings as a separate enhancement.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-04-01 21:16:40 -04:00
Eugene Brodsky
bb1c40f222 Merge branch 'main' into pin-timm-for-llava 2025-04-01 21:10:30 -04:00
Eugene Brodsky
d26b7a1a12 Merge branch 'main' into pin-timm-for-llava 2025-03-31 11:37:29 -04:00
Eugene Brodsky
c9992914d6 Merge branch 'main' into pin-timm-for-llava 2025-03-28 09:20:30 -04:00
Eugene Brodsky
3f12a43e75 remove pin for controlnet-aux and pin timm to a version that works with llava
timm < 1.0.0 prevents LLaVA models from working (broken in transformers), but controlnet-aux pinned it to an earlier version because otherwise it was breaking the ZoeDepth controlnet.

We don't use ZoeDepth (replaced by depthAnything), and downgrading controlnet-aux seems to be acceptable.

more context here:

https://github.com/huggingface/controlnet_aux/issues/106
https://github.com/huggingface/controlnet_aux/pull/101
2025-03-26 16:58:18 -04:00
63 changed files with 4190 additions and 3219 deletions

View File

@@ -1,9 +1,11 @@
*
!invokeai
!pyproject.toml
!uv.lock
!docker/docker-entrypoint.sh
!LICENSE
**/dist
**/node_modules
**/__pycache__
**/*.egg-info
**/*.egg-info

8
.github/CODEOWNERS vendored
View File

@@ -2,11 +2,11 @@
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
@@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
# front ends
/invokeai/frontend/CLI @lstein @hipsterusername

View File

@@ -97,6 +97,8 @@ jobs:
context: .
file: docker/Dockerfile
platforms: ${{ env.PLATFORMS }}
build-args: |
GPU_DRIVER=${{ matrix.gpu-driver }}
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'
cache: pip
cache-dependency-path: pyproject.toml

2
.nvmrc
View File

@@ -1 +1 @@
v22.12.0
v22.14.0

View File

@@ -1,77 +1,6 @@
# syntax=docker/dockerfile:1.4
## Builder stage
FROM library/ubuntu:24.04 AS builder
ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt-get install -y \
build-essential \
git
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
ARG GPU_DRIVER=cuda
# unused but available
ARG BUILDPLATFORM
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
mkdir -p ${INVOKEAI_SRC} && \
chmod -R a+w /opt && \
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
USER ubuntu
# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
uv python install ${PYTHON_VERSION}
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync --no-install-project
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync
#### Build the Web UI ------------------------------------
#### Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
@@ -85,69 +14,89 @@ RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
#### Runtime stage ---------------------------------------
## Backend ---------------------------------------
FROM library/ubuntu:24.04 AS runtime
FROM library/ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt update && apt install -y --no-install-recommends \
ca-certificates \
git \
gosu \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev
RUN apt update && apt install -y --no-install-recommends \
git \
curl \
vim \
tmux \
ncdu \
iotop \
bzip2 \
gosu \
magic-wormhole \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev &&\
apt-get clean && apt-get autoclean
ENV \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
VIRTUAL_ENV=/opt/venv \
INVOKEAI_SRC=/opt/invokeai \
PYTHON_VERSION=3.12 \
UV_PYTHON=3.12 \
UV_COMPILE_BYTECODE=1 \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
UV_INDEX="https://download.pytorch.org/whl/cu124" \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
PATH="/opt/venv/bin:$PATH" \
CONTAINER_UID=${CONTAINER_UID:-1000} \
CONTAINER_GID=${CONTAINER_GID:-1000}
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
ENV INVOKEAI_PORT=9090
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda
# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install ${PYTHON_VERSION} && \
# chmod --recursive a+rX /root/.local/share/uv/python
chmod 711 /root
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv sync --frozen
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.

View File

@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
3. Create a virtual environment in that directory:
```sh
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
```
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
4. Activate the virtual environment:
@@ -64,7 +64,7 @@ The following commands vary depending on the version of Invoke being installed a
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).
6. Determine the package package specifier to use when installing. This is a performance optimization.
6. Determine the package specifier to use when installing. This is a performance optimization.
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
8. Install the `invokeai` package. Substitute the package specifier and version.
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --
=== "Windows"
- Install python 3.11 with [an official installer].
- Install python with [an official installer].
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
- You may need to install [Microsoft Visual C++ Redistributable].
=== "macOS"
- Install python 3.11 with [an official installer].
- Install python with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
=== "Linux"
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
- Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
## Drivers

View File

@@ -37,7 +37,13 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -101,10 +107,25 @@ class ApiDependencies:
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = ObjectSerializerForwardCache(
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
ObjectSerializerDisk[torch.Tensor](
output_folder / "tensors",
safe_globals=[torch.Tensor],
ephemeral=True,
),
max_cache_size=0,
)
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
ObjectSerializerDisk[ConditioningFieldData](
output_folder / "conditioning",
safe_globals=[
ConditioningFieldData,
BasicConditioningInfo,
SDXLConditioningInfo,
FLUXConditioningInfo,
SD3ConditioningInfo,
],
ephemeral=True,
),
)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")

View File

@@ -1,13 +1,10 @@
import json
from typing import Any, Optional
from typing import Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.fields import BoardField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -26,7 +23,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.compose_pydantic_model import compose_model_from_fields
from invokeai.app.services.shared.pagination import CursorPaginatedResults
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -39,15 +35,10 @@ class SessionQueueAndProcessorStatus(BaseModel):
processor: SessionProcessorStatus
class SimpleModelIdentifer(BaseModel):
id: str = Field(description="The model id")
model_field_overrides = {ModelIdentifierField: (SimpleModelIdentifer, Field(description="The model identifier"))}
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
class ValidationRunData(BaseModel):
workflow_id: str = Field(description="The id of the workflow being published.")
input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
@session_queue_router.post(
@@ -61,52 +52,13 @@ async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
is_api_validation_run: bool = Body(
default=False,
description="Whether or not this is a validation run.",
),
api_input_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as input to the API"
),
api_output_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as output from the API"
validation_run_data: Optional[ValidationRunData] = Body(
default=None,
description="The validation run data to use for this batch. This is only used if this is a validation run.",
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
if is_api_validation_run:
session_count = batch.get_session_count()
assert session_count == 1, "API validation run only supports single session batches"
if api_input_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_input_fields,
composed_model_class_name="APIInputModel",
model_field_overrides=model_field_overrides,
model_field_filter=model_field_filter,
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Input Model")
print(json.dumps(json_schema))
if api_output_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_output_fields,
composed_model_class_name="APIOutputModel",
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Output Model")
print(json.dumps(json_schema))
print("graph")
print(batch.graph.model_dump_json())
if batch.workflow is not None:
print("workflow")
print(batch.workflow.model_dump_json())
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)

View File

@@ -1,5 +1,4 @@
import io
import random
import traceback
from typing import Optional
@@ -25,37 +24,6 @@ from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import
IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
ids = {
"6614752a-0420-4d81-98fc-e110069d4f38": random.choice([True, False]),
"default_5e8b008d-c697-45d0-8883-085a954c6ace": random.choice([True, False]),
"4b2b297a-0d47-4f43-8113-ebbf3f403089": random.choice([True, False]),
"d0ce602a-049e-4368-97ae-977b49eed042": random.choice([True, False]),
"f170a187-fd74-40b8-ba9c-00de173ea4b9": random.choice([True, False]),
"default_f96e794f-eb3e-4d01-a960-9b4e43402bcf": random.choice([True, False]),
"default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88": random.choice([True, False]),
"default_dec5a2e9-f59c-40d9-8869-a056751d79b8": random.choice([True, False]),
"default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd": random.choice([True, False]),
"default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f": random.choice([True, False]),
"default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1": random.choice([True, False]),
"default_7dde3e36-d78f-4152-9eea-00ef9c8124ed": random.choice([True, False]),
"default_444fe292-896b-44fd-bfc6-c0b5d220fffc": random.choice([True, False]),
"default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2": random.choice([True, False]),
"acae7e87-070b-4999-9074-c5b593c86618": random.choice([True, False]),
"3008fc77-1521-49c7-ba95-94c5a4508d1d": random.choice([True, False]),
"default_686bb1d0-d086-4c70-9fa3-2f600b922023": random.choice([True, False]),
"36905c46-e768-4dc3-8ecd-e55fe69bf03c": random.choice([True, False]),
"7c3e4951-183b-40ef-a890-28eef4d50097": random.choice([True, False]),
"7a053b2f-64e4-4152-80e9-296006e77131": random.choice([True, False]),
"27d4f1be-4156-46e9-8d22-d0508cd72d4f": random.choice([True, False]),
"e881dc06-70d2-438f-b007-6f3e0c3c0e78": random.choice([True, False]),
"265d2244-a1d7-495c-a2eb-88217f5eae37": random.choice([True, False]),
"caebcbc7-2bf0-41c4-b553-106b585fddda": random.choice([True, False]),
"a7998705-474e-417d-bd37-a2a9480beedf": random.choice([True, False]),
"554d94b5-94b3-4d8e-8aed-51ebfc9deea5": random.choice([True, False]),
"e6898540-c1bc-408b-b944-c1e242cddbcd": random.choice([True, False]),
"363b0960-ab2c-4902-8df3-f592d6194bb3": random.choice([True, False]),
}
@workflows_router.get(
"/i/{workflow_id}",
@@ -71,8 +39,6 @@ async def get_workflow(
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
workflow.is_published = ids.get(workflow_id, False)
workflow.workflow.is_published = ids.get(workflow_id, False)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@@ -144,7 +110,7 @@ async def list_workflows(
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflow_record_list_items = ApiDependencies.invoker.services.workflow_records.get_many(
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
direction=direction,
page=page,
@@ -155,21 +121,19 @@ async def list_workflows(
has_been_opened=has_been_opened,
is_published=is_published,
)
for item in workflow_record_list_items.items:
data = item.model_dump()
data["is_published"] = ids.get(item.workflow_id, False)
for workflow in workflows.items:
workflows_with_thumbnails.append(
WorkflowRecordListItemWithThumbnailDTO(
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(item.workflow_id),
**data,
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
**workflow.model_dump(),
)
)
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
items=workflows_with_thumbnails,
total=workflow_record_list_items.total,
page=workflow_record_list_items.page,
pages=workflow_record_list_items.pages,
per_page=workflow_record_list_items.per_page,
total=workflows.total,
page=workflows.page,
pages=workflows.pages,
per_page=workflows.per_page,
)

View File

@@ -0,0 +1,128 @@
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
from typing import List, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
),
)
@invocation(
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.0.1",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
"""Resize an image using a heuristic method. Preserves edge maps."""
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

View File

@@ -1,716 +0,0 @@
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
import cv2
import numpy as np
from controlnet_aux import (
ContentShuffleDetector,
LeresDetector,
MediapipeFaceDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
PidiNetDetector,
SamDetector,
ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
),
)
# This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Base class for invocations that preprocess images for ControlNet"""
image: ImageField = InputField(description="The image to process")
def run_processor(self, image: Image.Image) -> Image.Image:
# superclass just passes through image without processing
return image
def load_image(self, context: InvocationContext) -> Image.Image:
# allows override for any special formatting specific to the preprocessor
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.images.save(image=processed_image)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,
width=image_dto.width,
# height=processed_image.height,
height=image_dto.height,
# mode=processed_image.mode,
)
@invocation(
"canny_image_processor",
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.3.3",
classification=Classification.Deprecated,
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
)
high_threshold: int = InputField(
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
)
def load_image(self, context: InvocationContext) -> Image.Image:
# Keep alpha channel for Canny processing to detect edges of transparent areas
return context.images.get_pil(self.image.image_name, "RGBA")
def run_processor(self, image: Image.Image) -> Image.Image:
processed_image = get_canny_edges(
image,
self.low_threshold,
self.high_threshold,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
@invocation(
"hed_image_processor",
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
# safe not supported in controlnet_aux v0.0.3
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
hed_processor = HEDProcessor()
processed_image = hed_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
@invocation(
"lineart_image_processor",
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image:
lineart_processor = LineartProcessor()
processed_image = lineart_processor.run(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
)
return processed_image
@invocation(
"lineart_anime_image_processor",
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor()
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
@invocation(
"midas_depth_image_processor",
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
@invocation(
"normalbae_image_processor",
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
return processed_image
@invocation(
"mlsd_image_processor",
title="MLSD Processor",
tags=["controlnet", "mlsd"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d,
)
return processed_image
@invocation(
"pidi_image_processor",
title="PIDI Processor",
tags=["controlnet", "pidi"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble,
)
return processed_image
@invocation(
"content_shuffle_image_processor",
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f,
)
return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@invocation(
"zoe_depth_image_processor",
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@invocation(
"mediapipe_face_processor",
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
max_faces=self.max_faces,
min_confidence=self.min_confidence,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
)
return processed_image
@invocation(
"leres_image_processor",
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
boost: bool = InputField(default=False, description="Whether to use boost mode")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
@invocation(
"tile_image_processor",
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(
self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
H, W, C = np_img.shape
H = int(float(H) / float(down_sampling_rate))
W = int(float(W) / float(down_sampling_rate))
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
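# Worked example: with down_sampling_rate=2.0, a 1024x1024 input becomes 512x512 via
# INTER_AREA (well suited to downscaling); any rate below 1.1 is treated as a no-op and
# the image passes through unchanged.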
def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
down_sampling_rate=self.down_sampling_rate,
)
processed_image = Image.fromarray(processed_np_image)
return processed_image
@invocation(
"segment_anything_processor",
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.2.4",
classification=Classification.Deprecated,
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
)
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
h, w = anns[0]["segmentation"].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann["segmentation"]
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)
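# Note on the palette choice: coloring annotations by their sorted index into a fixed
# palette means the same segmentation always yields the same colors, unlike the base
# class's randomized colors. The ADE20k palette has on the order of 150 entries, so the
# modulo above simply wraps annotation 150 (if present) back to the color of annotation 0.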
@invocation(
"color_map_image_processor",
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.2.3",
classification=Classification.Deprecated,
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
width_tile_size = min(self.color_map_tile_size, width)
height_tile_size = min(self.color_map_tile_size, height)
color_map = cv2.resize(
np_image,
(width // width_tile_size, height // height_tile_size),
interpolation=cv2.INTER_CUBIC,
)
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map
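# Worked example: for a 512x512 input with color_map_tile_size=64, the image is first
# shrunk to 8x8 (width // 64, height // 64) with INTER_CUBIC to average colors, then
# scaled back to 512x512 with INTER_NEAREST, producing flat 64-pixel blocks of color.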
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
@invocation(
"depth_anything_image_processor",
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.3",
classification=Classification.Deprecated,
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small_v2", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def load_depth_anything(model_path: Path):
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
# Resizing to user target specified size
new_height = int(image.size[1] * (self.resolution / image.size[0]))
depth_map = depth_map.resize((self.resolution, new_height))
return depth_map
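# Note: `resolution` sets the output width and the height is scaled to preserve aspect
# ratio, e.g. a 1024x768 input with resolution=512 yields a 512x384 depth map.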
@invocation(
"dw_openpose_image_processor",
title="DW Openpose Image Processor",
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
version="1.1.1",
classification=Classification.Deprecated,
)
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Generates an openpose pose from an image using DWPose"""
draw_body: bool = InputField(default=True)
draw_face: bool = InputField(default=False)
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,
draw_hands=self.draw_hands,
draw_body=self.draw_body,
resolution=self.image_resolution,
)
return processed_image
@invocation(
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.0.1",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
"""Resize an image using a heuristic method. Preserves edge maps."""
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

View File

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.controlnet import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,

View File

@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
@invocation(
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
loaded_session_det = context.models.load_local_model(
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:
assert isinstance(session_det, ort.InferenceSession)
assert isinstance(session_pose, ort.InferenceSession)
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
detected_image = detector.run(
image,
draw_face=self.draw_face,

View File

@@ -14,7 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField, ControlNetInvocation
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
from invokeai.app.invocations.fields import (
FieldDescriptions,

View File

@@ -9,7 +9,7 @@ from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.controlnet import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,

View File

@@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
:param output_dir: The folder where the serialized objects will be stored
:param safe_globals: A list of types to be added to the safe globals for torch serialization
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
"""
def __init__(self, output_dir: Path, ephemeral: bool = False):
def __init__(
self,
output_dir: Path,
safe_globals: list[type],
ephemeral: bool = False,
) -> None:
super().__init__()
self._ephemeral = ephemeral
self._base_output_dir = output_dir
@@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
self.__obj_class_name: Optional[str] = None
torch.serialization.add_safe_globals(safe_globals) if safe_globals else None
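# Illustrative, hypothetical usage (names assumed, not taken from dependencies.py): every
# custom type the serializer may deserialize must be whitelisted, because torch's
# weights-only unpickler rejects globals that were not registered via
# torch.serialization.add_safe_globals().
#
#   conditioning_serializer = ObjectSerializerDisk[ConditioningFieldData](
#       output_dir=output_dir / "conditioning",
#       safe_globals=[ConditioningFieldData, BasicConditioningInfo, SDXLConditioningInfo],
#       ephemeral=True,
#   )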
def load(self, name: str) -> T:
file_path = self._get_path(name)
try:

View File

@@ -33,12 +33,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(
self,
queue_id: str,
batch: Batch,
prepend: bool,
) -> Coroutine[Any, Any, EnqueueBatchResult]:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution."""
pass

View File

@@ -157,28 +157,6 @@ class Batch(BaseModel):
v.validate_self()
return v
def get_session_count(self) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
If the session count has already been calculated, return the cached value.
"""
if not self.data:
return self.runs
data = []
for batch_datum_list in self.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * self.runs
model_config = ConfigDict(
json_schema_extra={
"required": [
@@ -269,6 +247,10 @@ class SessionQueueItemWithoutGraph(BaseModel):
default=False,
description="Whether this queue item is an API validation run.",
)
published_workflow_id: Optional[str] = Field(
default=None,
description="The ID of the published workflow associated with this queue item",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
@@ -574,6 +556,28 @@ def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs
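# Worked example: take a batch with runs=2 and two batch datum lists, the first holding a
# single datum with 3 items and the second holding two datums with 4 items each (zipped
# under the strict length check). The per-list counts are 3 and 4, their product is 12,
# so calc_session_count returns 12 * 2 = 24.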
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)

View File

@@ -28,6 +28,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -117,8 +118,7 @@ class SqliteSessionQueue(SessionQueueBase):
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = batch.get_session_count()
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,

View File

@@ -1,204 +0,0 @@
from copy import deepcopy
from typing import Any, Callable, TypeAlias, get_args
from pydantic import BaseModel, ConfigDict, create_model
from pydantic.fields import FieldInfo
from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier
from invokeai.app.services.shared.graph import Graph
DictOfFieldsMetadata: TypeAlias = dict[str, tuple[type[Any], FieldInfo]]
class ComposedFieldMetadata(BaseModel):
node_id: str
field_name: str
field_type_class_name: str
def dedupe_field_name(field_metadata: DictOfFieldsMetadata, field_name: str) -> str:
"""Given a field name, return a name that is not already in the field metadata.
If the field name is not in the field metadata, return the field name.
If the field name is in the field metadata, generate a new name by appending an underscore and integer to the field name, starting with 2.
"""
if field_name not in field_metadata:
return field_name
i = 2
while True:
new_field_name = f"{field_name}_{i}"
if new_field_name not in field_metadata:
return new_field_name
i += 1
def compose_model_from_fields(
g: Graph,
field_identifiers: list[FieldIdentifier],
composed_model_class_name: str = "ComposedModel",
model_field_overrides: dict[type[Any], tuple[type[Any], FieldInfo]] | None = None,
model_field_filter: Callable[[type[Any]], bool] | None = None,
) -> type[BaseModel]:
"""Given a graph and a list of field identifiers, create a new pydantic model composed of the fields of the nodes in the graph.
The resultant model can be used to validate a JSON payload that contains the fields of the nodes in the graph, or generate an
OpenAPI schema for the model.
Args:
g: The graph containing the nodes whose fields will be composed into the new model.
field_identifiers: A list of FieldIdentifier instances, each representing a field on a node in the graph.
composed_model_class_name: The name of the composed model class. Defaults to "ComposedModel".
model_field_overrides: A dictionary mapping type annotations to tuples of (new_type_annotation, new_field_info).
This can be used to override the type annotation and field info of a field in the composed model. For example,
if `ModelIdentifierField` should be replaced by a string, the dictionary would look like this:
```python
{ModelIdentifierField: (str, Field(description="The model id."))}
```
model_field_filter: A function that takes a type annotation and returns True if the field should be included in the composed model.
If None, all fields will be included. For example, to omit `BoardField` fields, the filter would look like this:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField}
```
Optional fields - or any other complex field types like unions - must be explicitly included in the filter. For example,
to omit `BoardField` _and_ `Optional[BoardField]`:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
```
Note that the filter is applied to the type annotation of the field, not the field itself.
Example usage:
```python
# Create some nodes.
add_node = AddInvocation()
sub_node = SubtractInvocation()
color_node = ColorInvocation()
# Create a graph with the nodes.
g = Graph(
nodes={
add_node.id: add_node,
sub_node.id: sub_node,
color_node.id: color_node,
}
)
# Select the fields to compose.
fields_to_compose = [
FieldIdentifier(node_id=add_node.id, field_name="a"),
FieldIdentifier(node_id=sub_node.id, field_name="a"), # this will be deduped to "a_2"
FieldIdentifier(node_id=add_node.id, field_name="b"),
FieldIdentifier(node_id=color_node.id, field_name="color"),
]
# Compose the model from the fields.
composed_model = compose_model_from_fields(g, fields_to_compose, composed_model_class_name="ComposedModel")
# Generate the OpenAPI schema for the model.
json_schema = composed_model.model_json_schema(mode="validation")
```
"""
# Temp storage for the composed fields. Pydantic needs a type annotation and instance of FieldInfo to create a model.
field_metadata: DictOfFieldsMetadata = {}
model_field_overrides = model_field_overrides or {}
# The list of required fields. This is used to ensure the composed model's fields retain their required state.
required: list[str] = []
for field_identifier in field_identifiers:
node_id = field_identifier.node_id
field_name = field_identifier.field_name
# Pull the node instance from the graph so we can introspect it.
node_instance = g.nodes[node_id]
if field_identifier.kind == "input":
# Get the class of the node. This will be a BaseInvocation subclass, e.g. AddInvocation, DenoiseLatentsInvocation, etc.
pydantic_model = type(node_instance)
else:
# Otherwise use the type of the node's output class. This will be a BaseInvocationOutput subclass, e.g. IntegerOutput, ImageOutput, etc.
pydantic_model = type(node_instance).get_output_annotation()
# Get the FieldInfo instance for the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^^^ The return value of this Field call is the FieldInfo instance (Field is a function).
og_field_info = pydantic_model.model_fields[field_name]
# Get the type annotation of the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^ this is the type annotation
og_field_type = og_field_info.annotation
# Apparently pydantic allows fields without type annotations. We don't support that.
assert og_field_type is not None, (
f"{field_identifier.kind.capitalize()} field {field_name} on node {node_id} has no type annotation."
)
# Now that we have the type annotation, we can apply the filter to see if we should include the field in the composed model.
if model_field_filter and not model_field_filter(og_field_type):
continue
# Ok, we want this type of field. Retrieve any overrides for the field type. This is a dictionary mapping
# type annotations to tuples of (override_type_annotation, override_field_info).
(override_field_type, override_field_info) = model_field_overrides.get(og_field_type, (None, None))
# The override tuple's first element is the new type annotation, if it exists.
composed_field_type = override_field_type if override_field_type is not None else og_field_type
# Create a deep copy of the FieldInfo instance (or override it if it exists) so we can modify it without
# affecting the original. This is important because we are going to modify the FieldInfo instance and
# don't want to affect the original model's schema.
composed_field_info = deepcopy(override_field_info if override_field_info is not None else og_field_info)
json_schema_extra = og_field_info.json_schema_extra if isinstance(og_field_info.json_schema_extra, dict) else {}
# The field's original required state is stored in the json_schema_extra dict. For more information about why,
# see the definition of `InputField` in invokeai/app/invocations/fields.py.
#
# Add the field to the required list if it is required, which we will use when creating the composed model.
if json_schema_extra.get("orig_required", False):
required.append(field_name)
# Invocation fields have some extra metadata, used by the UI to render the field in the frontend. This data is
# included in the OpenAPI schema for each field. For example, we add a "ui_order" field, which the UI uses to
# sort fields when rendering them.
#
# The composed model's OpenAPI schema should not have this information. It should only have a standard OpenAPI
# schema for the field. We need to strip out the UI-specific metadata from the FieldInfo instance before adding
# it to the composed model.
#
# We will replace this metadata with some custom metadata:
# - node_id: The id of the node that this field belongs to.
# - field_name: The name of the field on the node.
# - original_data_type: The original data type of the field.
field_type_class = get_args(og_field_type)[0] if hasattr(og_field_type, "__args__") else og_field_type
field_type_class_name = field_type_class.__name__
composed_field_metadata = ComposedFieldMetadata(
node_id=node_id,
field_name=field_name,
field_type_class_name=field_type_class_name,
)
composed_field_info.json_schema_extra = {
"composed_field_extra": composed_field_metadata.model_dump(),
}
# Override the name, title and description if overrides are provided. Dedupe the field name if necessary.
final_field_name = dedupe_field_name(field_metadata, field_name)
# Store the field metadata.
field_metadata.update({final_field_name: (composed_field_type, composed_field_info)})
# Splat in the composed fields to create the new model. There are type errors here because create_model's kwargs are not typed,
# and for some reason pydantic's ConfigDict doesn't like lists in `json_schema_extra`. Anyways, the inputs here are correct.
return create_model(
composed_model_class_name,
**field_metadata,
__config__=ConfigDict(json_schema_extra={"required": required}),
)

View File

@@ -65,9 +65,6 @@ def apply_monkeypatches() -> None:
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
def register_mime_types() -> None:
"""Register additional mime types for windows."""

View File

@@ -5,62 +5,14 @@ import huggingface_hub
import numpy as np
import onnxruntime as ort
import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
from invokeai.backend.image_util.util import np_to_pil
from invokeai.backend.util.devices import TorchDevice
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
resolution: int = 512,
) -> Image.Image:
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
assert isinstance(hands, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image: Image.Image = resize_image(
canvas,
resolution,
)
dwpose_image = Image.fromarray(dwpose_image)
return dwpose_image
class DWOpenposeDetector:
"""
@@ -68,62 +20,6 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
def __call__(
self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
resolution: int = 512,
) -> Image.Image:
np_image = np.array(image)
H, W, C = np_image.shape
with torch.no_grad():
candidate, subset = self.pose_estimation(np_image)
nums, keys, locs = candidate.shape
candidate[..., 0] /= float(W)
candidate[..., 1] /= float(H)
body = candidate[:, :18].copy()
body = body.reshape(nums * 18, locs)
score = subset[:, :18]
for i in range(len(score)):
for j in range(len(score[i])):
if score[i][j] > 0.3:
score[i][j] = int(18 * i + j)
else:
score[i][j] = -1
un_visible = subset < 0.3
candidate[un_visible] = -1
# foot = candidate[:, 18:24]
faces = candidate[:, 24:92]
hands = candidate[:, 92:113]
hands = np.vstack([hands, candidate[:, 113:]])
bodies = {"candidate": body, "subset": score}
pose = {"bodies": bodies, "hands": hands, "faces": faces}
return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
)
class DWOpenposeDetector2:
"""
Code from the original implementation of the DW Openpose Detector.
Credits: https://github.com/IDEA-Research/DWPose
This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
and managed by the model manager.
"""
hf_repo_id = "yzd-v/DWPose"
hf_filename_onnx_det = "yolox_l.onnx"
hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
@@ -213,7 +109,7 @@ class DWOpenposeDetector2:
bodies = {"candidate": body, "subset": score}
pose = {"bodies": bodies, "hands": hands, "faces": faces}
return DWOpenposeDetector2.draw_pose(
return DWOpenposeDetector.draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
)

View File

@@ -3,7 +3,6 @@
import math
import cv2
import matplotlib
import numpy as np
import numpy.typing as npt
@@ -127,11 +126,13 @@ def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
x2 = int(x2 * W)
y2 = int(y2 * H)
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
hsv_color = np.array([[[ie / float(len(edges)) * 180, 255, 255]]], dtype=np.uint8)
rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0, 0]
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
rgb_color.tolist(),
thickness=2,
)
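# Note: this replaces matplotlib.colors.hsv_to_rgb with an equivalent OpenCV conversion.
# OpenCV's 8-bit HSV stores hue in the range 0-179, so ie / len(edges) is scaled by 180
# rather than left in matplotlib's 0-1 range, and saturation/value use 255 instead of 1.0;
# the resulting RGB triple is the same color.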

View File

@@ -1,44 +0,0 @@
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
# Modified pathing to suit Invoke
from pathlib import Path
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
from invokeai.backend.util.devices import TorchDevice
config = get_config()
class Wholebody:
def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
def __call__(self, oriImg):
det_result = inference_detector(self.session_det, oriImg)
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
# compute neck joint
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
# neck score when visualizing pred
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
keypoints_info = new_keypoints_info
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
return keypoints, scores

View File

@@ -69,6 +69,9 @@ class SD3ConditioningInfo:
@dataclass
class ConditioningFieldData:
# If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
# invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be
# unable to deserialize the object and will raise an error.
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]

View File

@@ -1,245 +0,0 @@
import math
import diffusers
import torch
if torch.backends.mps.is_available():
torch.empty = torch.zeros
_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
if weight is not None:
weight = weight.float()
if bias is not None:
bias = bias.float()
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
torch.nn.functional.layer_norm = new_layer_norm
_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result
torch.Tensor.permute = new_torch_tensor_permute
_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
end = end.float()
if isinstance(weight, torch.Tensor):
weight = weight.float()
if out is not None:
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
else:
out_fp32 = None
result = _torch_lerp(input, end, weight, out=out_fp32)
if out is not None:
out.copy_(out_fp32.half())
del out_fp32
return result.half()
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(
input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate
# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
class ChunkedSlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
assert isinstance(slice_size, int)
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
self.slice_size = slice_size
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
if self.slice_size != 1 or attn.upcast_attention:
return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
chunk_tmp_tensor = torch.empty(
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
self.get_attention_scores_chunked(
attn,
query_slice,
key_slice,
attn_mask_slice,
hidden_states[start_idx:end_idx],
value[start_idx:end_idx],
chunk_tmp_tensor,
)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
# batch size = 1
assert query.shape[0] == 1
assert key.shape[0] == 1
assert value.shape[0] == 1
assert hidden_states.shape[0] == 1
# dtype = query.dtype
if attn.upcast_attention:
query = query.float()
key = key.float()
# out_item_size = query.dtype.itemsize
# if attn.upcast_attention:
# out_item_size = torch.float32.itemsize
out_item_size = query.element_size()
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2**29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
chunk_step = max(1, int(query.shape[1] / chunks_count))
key = key.transpose(-1, -2)
def _get_chunk_view(tensor, start, length):
if start + length > tensor.shape[1]:
length = tensor.shape[1] - start
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:, start : start + length]
for chunk_pos in range(0, query.shape[1], chunk_step):
if attention_mask is not None:
torch.baddbmm(
_get_chunk_view(attention_mask, chunk_pos, chunk_step),
_get_chunk_view(query, chunk_pos, chunk_step),
key,
beta=1,
alpha=attn.scale,
out=chunk,
)
else:
torch.baddbmm(
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
_get_chunk_view(query, chunk_pos, chunk_step),
key,
beta=0,
alpha=attn.scale,
out=chunk,
)
chunk = chunk.softmax(dim=-1)
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
# del chunk
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor

View File

@@ -62,7 +62,7 @@
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.5.1",
"@xyflow/react": "^12.5.3",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
@@ -162,5 +162,6 @@
},
"engines": {
"pnpm": "8"
}
},
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}

View File

@@ -36,8 +36,8 @@ dependencies:
specifier: ^1.3.0
version: 1.3.0
'@xyflow/react':
specifier: ^12.5.1
version: 12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
specifier: ^12.5.3
version: 12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
async-mutex:
specifier: ^0.5.0
version: 0.5.0
@@ -3951,8 +3951,8 @@ packages:
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
dev: false
/@xyflow/react@12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-jMKQVqGwCz0x6pUyvxTIuCMbyehfua7CfEEWDj29zQSHigQpCy0/5d8aOmZrqK4cwur/pVHLQomT6Rm10gXfHg==}
/@xyflow/react@12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-saovy/aQRoW8qQoIqMFUtmC3F6oEV7n6+J1pVbhSG45NI/hOFvK0qozsIPKqX5Va6lGQnkl/o53NHLja3NiweQ==}
peerDependencies:
react: '>=17'
react-dom: '>=17'
@@ -9123,8 +9123,8 @@ packages:
react: 18.3.1
dev: false
/use-sync-external-store@1.4.0(react@18.3.1):
resolution: {integrity: sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==}
/use-sync-external-store@1.5.0(react@18.3.1):
resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
dependencies:
@@ -9592,5 +9592,5 @@ packages:
dependencies:
'@types/react': 18.3.11
react: 18.3.1
use-sync-external-store: 1.4.0(react@18.3.1)
use-sync-external-store: 1.5.0(react@18.3.1)
dev: false

View File

@@ -116,7 +116,10 @@
"combinatorial": "Kombinatorisch",
"saveChanges": "Änderungen speichern",
"error_withCount_one": "{{count}} Fehler",
"error_withCount_other": "{{count}} Fehler"
"error_withCount_other": "{{count}} Fehler",
"value": "Wert",
"label": "Label",
"systemInformation": "Systeminformationen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -695,7 +698,10 @@
"guidance": "Führung",
"coherenceMode": "Modus",
"recallMetadata": "Metadaten abrufen",
"gaussianBlur": "Gaußsche Unschärfe"
"gaussianBlur": "Gaußsche Unschärfe",
"sendToUpscale": "An Hochskalieren senden",
"useCpuNoise": "CPU-Rauschen verwenden",
"sendToCanvas": "An Leinwand senden"
},
"settings": {
"displayInProgress": "Zwischenbilder anzeigen",
@@ -1328,7 +1334,8 @@
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
"loadingTemplates": "Lade {{name}}",
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten",
"showEdgeLabelsHelp": "Beschriftungen an Kanten anzeigen, um die verknüpften Knoten zu kennzeichnen"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -1706,6 +1706,7 @@
"noRecentWorkflows": "No Recent Workflows",
"private": "Private",
"shared": "Shared",
"published": "Published",
"browseWorkflows": "Browse Workflows",
"deselectAll": "Deselect All",
"recommended": "Recommended For You",
@@ -1813,7 +1814,9 @@
"publishedWorkflowIsLocked": "Published workflow is locked",
"publishingValidationRun": "Publishing Validation Run",
"publishingValidationRunInProgress": "Publishing validation run in progress.",
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow."
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow.",
"selectingOutputNode": "Selecting output node",
"selectingOutputNodeDesc": "Click a node to select it as the workflow's output node."
}
},
"controlLayers": {

View File

@@ -115,7 +115,8 @@
"error_withCount_many": "{{count}} errori",
"error_withCount_other": "{{count}} errori",
"value": "Valore",
"label": "Etichetta"
"label": "Etichetta",
"systemInformation": "Informazioni di sistema"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -715,7 +716,8 @@
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
"collectionEmpty": "raccolta vuota"
"collectionEmpty": "raccolta vuota",
"batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -2365,8 +2367,9 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
"Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
"FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela.",
"LLaVA OneVision VLLM: supporto beta nei flussi di lavoro."
]
},
"system": {

View File

@@ -237,7 +237,10 @@
"row": "Hàng",
"board": "Bảng",
"saveChanges": "Lưu Thay Đổi",
"error_withCount_other": "{{count}} lỗi"
"error_withCount_other": "{{count}} lỗi",
"value": "Giá Trị",
"label": "Nhãn Tên",
"systemInformation": "Thông Tin Hệ Thống"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -2300,7 +2303,10 @@
"minimum": "Tối Thiểu",
"maximum": "Tối Đa",
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
"containerColumnLayout": "Hộp Chứa (bố cục cột)",
"resetOptions": "Tải Lại Lựa Chọn",
"addOption": "Thêm Lựa Chọn",
"dropdown": "Danh Sách Thả Xuống"
},
"yourWorkflows": "Workflow Của Bạn",
"browseWorkflows": "Khám Phá Workflow",
@@ -2316,7 +2322,8 @@
"view": "Xem",
"deselectAll": "Huỷ Chọn Tất Cả",
"noRecentWorkflows": "Không Có Workflows Gần Đây",
"recommended": "Có Thể Bạn Sẽ Cần"
"recommended": "Có Thể Bạn Sẽ Cần",
"emptyStringPlaceholder": "<xâu ký tự trống>"
},
"upscaling": {
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
@@ -2352,8 +2359,9 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
"Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
"FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas.",
"LLaVA OneVision VLLM: Hỗ trợ phiên bản Beta trong Workflow."
]
},
"upsell": {

View File

@@ -28,8 +28,7 @@ export type AppFeature =
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'deployWorkflow';
| 'cancelAndClearAll';
/**
* A disable-able Stable Diffusion feature
*/
@@ -75,6 +74,7 @@ export type AppConfig = {
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -49,7 +49,11 @@ export const useGalleryHotkeys = () => {
useRegisteredHotkeys({
id: 'galleryNavLeft',
category: 'gallery',
callback: () => {
callback: (e) => {
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
return;
}
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
goPrev('arrow');
return;
@@ -71,7 +75,11 @@ export const useGalleryHotkeys = () => {
useRegisteredHotkeys({
id: 'galleryNavRight',
category: 'gallery',
callback: () => {
callback: (e) => {
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
return;
}
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
goNext('arrow');
return;

View File

@@ -3,7 +3,11 @@ import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
import { $isInPublishFlow, useIsValidationRunInProgress } from 'features/nodes/components/sidePanel/workflow/publish';
import {
$isInPublishFlow,
$isSelectingOutputNode,
useIsValidationRunInProgress,
} from 'features/nodes/components/sidePanel/workflow/publish';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
@@ -14,6 +18,7 @@ export const TopLeftPanel = memo(() => {
const isInPublishFlow = useStore($isInPublishFlow);
const isPublished = useAppSelector(selectWorkflowIsPublished);
const isValidationRunInProgress = useIsValidationRunInProgress();
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const { t } = useTranslation();
return (
@@ -34,11 +39,16 @@ export const TopLeftPanel = memo(() => {
{t('workflows.builder.publishingValidationRunInProgress')}
</AlertDescription>
)}
{isInPublishFlow && !isValidationRunInProgress && (
{isInPublishFlow && !isValidationRunInProgress && !isSelectingOutputNode && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.workflowLockedDuringPublishing')}
</AlertDescription>
)}
{isInPublishFlow && !isValidationRunInProgress && isSelectingOutputNode && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.selectingOutputNodeDesc')}
</AlertDescription>
)}
{isPublished && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.workflowLockedPublished')}

View File

@@ -67,7 +67,7 @@ type NodeFieldDndData = {
fieldName: string;
fieldTemplate: FieldInputTemplate;
};
export const buildNodeFieldDndData = (
const buildNodeFieldDndData = (
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate

View File

@@ -1,3 +1,4 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import {
Button,
ButtonGroup,
@@ -38,12 +39,12 @@ import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiLightningFill, PiSignOutBold, PiXBold } from 'react-icons/pi';
import { PiArrowLineRightBold, PiLightningFill, PiXBold } from 'react-icons/pi';
import { serializeError } from 'serialize-error';
import { assert } from 'tsafe';
@@ -53,7 +54,6 @@ export const PublishWorkflowPanelContent = memo(() => {
return (
<Flex flexDir="column" gap={2} h="full">
<ButtonGroup isAttached={false} size="sm" variant="ghost">
<SelectOutputNodeButton />
<Spacer />
<CancelPublishButton />
<PublishWorkflowButton />
@@ -68,38 +68,41 @@ export const PublishWorkflowPanelContent = memo(() => {
</Flex>
);
});
PublishWorkflowPanelContent.displayName = 'DeployWorkflowPanelContent';
PublishWorkflowPanelContent.displayName = 'PublishWorkflowPanelContent';
const OutputFields = memo(() => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);
if (!outputNodeId) {
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Flex alignItems="center">
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
<Spacer />
<SelectOutputNodeButton variant="link" size="sm" />
</Flex>
<Divider />
{!outputNodeId && (
<Text fontWeight="semibold" color="error.300">
{t('workflows.builder.noOutputNodeSelected')}
</Text>
</Flex>
);
}
return <OutputFieldsContent outputNodeId={outputNodeId} />;
)}
{outputNodeId && <OutputFieldsContent outputNodeId={outputNodeId} />}
</Flex>
);
});
OutputFields.displayName = 'OutputFields';
const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
const { t } = useTranslation();
const outputFieldNames = useOutputFieldNames(outputNodeId);
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
<Divider />
<>
{outputFieldNames.map((fieldName) => (
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
))}
</Flex>
</>
);
});
OutputFieldsContent.displayName = 'OutputFieldsContent';
@@ -152,7 +155,7 @@ const UnpublishableInputFields = memo(() => {
});
UnpublishableInputFields.displayName = 'UnpublishableInputFields';
const SelectOutputNodeButton = memo(() => {
const SelectOutputNodeButton = memo((props: ButtonProps) => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
@@ -161,8 +164,18 @@ const SelectOutputNodeButton = memo(() => {
$isSelectingOutputNode.set(true);
}, []);
return (
<Button leftIcon={<PiSignOutBold />} isDisabled={isSelectingOutputNode} onClick={onClick}>
{outputNodeId ? t('workflows.builder.changeOutputNode') : t('workflows.builder.selectOutputNode')}
<Button
leftIcon={<PiArrowLineRightBold />}
isDisabled={isSelectingOutputNode}
tooltip={isSelectingOutputNode ? t('workflows.builder.selectingOutputNodeDesc') : undefined}
onClick={onClick}
{...props}
>
{isSelectingOutputNode
? t('workflows.builder.selectingOutputNode')
: outputNodeId
? t('workflows.builder.changeOutputNode')
: t('workflows.builder.selectOutputNode')}
</Button>
);
});
@@ -192,6 +205,7 @@ const PublishWorkflowButton = memo(() => {
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const inputs = usePublishInputs();
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const projectUrl = useStore($projectUrl);
@@ -240,9 +254,11 @@ const PublishWorkflowButton = memo(() => {
<Button
leftIcon={<PiLightningFill />}
isDisabled={
!isReadyToDoValidationRun ||
!allowPublishWorkflows ||
!isReadyToEnqueue ||
!isWorkflowSaved ||
hasBatchOrGeneratorNodes ||
!isReadyToDoValidationRun ||
!(outputNodeId !== null && !isSelectingOutputNode)
}
onClick={onClick}
@@ -307,7 +323,7 @@ NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';
export const StartPublishFlowButton = memo(() => {
const { t } = useTranslation();
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
@@ -331,7 +347,7 @@ export const StartPublishFlowButton = memo(() => {
leftIcon={<PiLightningFill />}
variant="ghost"
size="sm"
isDisabled={!deployWorkflowIsEnabled || !isWorkflowSaved || hasBatchOrGeneratorNodes}
isDisabled={!allowPublishWorkflows || !isReadyToEnqueue || !isWorkflowSaved || hasBatchOrGeneratorNodes}
>
{t('workflows.builder.publish')}
</Button>

View File

@@ -10,7 +10,7 @@ export const LockedWorkflowIcon = memo(() => {
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
<IconButton
size="sm"
cursor='not-allowed'
cursor="not-allowed"
variant="link"
alignSelf="stretch"
aria-label={t('workflows.builder.publishedWorkflowsLocked')}

View File

@@ -26,7 +26,7 @@ import {
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@@ -40,7 +40,7 @@ export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
const deployWorkflow = useFeatureStatus('deployWorkflow');
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
@@ -60,8 +60,8 @@ export const WorkflowLibrarySideNav = () => {
</Flex>
</Collapse>
)}
{deployWorkflow && (
<WorkflowLibraryViewButton view="published">{t('workflows.publishedWorkflows')}</WorkflowLibraryViewButton>
{allowPublishWorkflows && (
<WorkflowLibraryViewButton view="published">{t('workflows.published')}</WorkflowLibraryViewButton>
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">

View File

@@ -1,7 +1,8 @@
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -10,7 +11,7 @@ import WorkflowJSONTab from './WorkflowJSONTab';
const WorkflowFieldsLinearViewPanel = () => {
const { t } = useTranslation();
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
return (
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
<TabList>
@@ -18,7 +19,7 @@ const WorkflowFieldsLinearViewPanel = () => {
<Tab>{t('common.details')}</Tab>
<Tab>JSON</Tab>
<Spacer />
{deployWorkflowIsEnabled && <StartPublishFlowButton />}
{allowPublishWorkflows && <StartPublishFlowButton />}
</TabList>
<TabPanels h="full" pt={2}>

View File

@@ -1,9 +0,0 @@
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useInputFieldTemplateTitleSafe = (nodeId: string, fieldName: string): string => {
const template = useNodeTemplateOrThrow(nodeId);
const title = useMemo(() => template.inputs[fieldName]?.title ?? '', [fieldName, template.inputs]);
return title;
};

View File

@@ -1,22 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
/**
* Gets the user-defined description of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an error is thrown.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldUserDescriptionOrThrow = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).description),
[fieldName, nodeId]
);
const description = useAppSelector(selector);
return description;
};

View File

@@ -470,31 +470,8 @@ export const nodesSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges } = action.payload;
const changes: NodeChange<AnyNode>[] = [];
for (const node of nodes) {
if (node.type === 'notes') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
} else if (node.type === 'invocation') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
}
}
state.nodes = applyNodeChanges<AnyNode>(changes, []);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ type: 'add', item: edge })),
[]
);
state.nodes = nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node }));
state.edges = edges;
});
},
});
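The rewritten reducer drops the NodeChange bookkeeping: applying only 'add' changes to an empty array yields the items themselves, so mapping the loaded nodes directly is equivalent. Spread order is what preserves the old behaviour; a minimal illustration with made-up keys (not the real node shape):

// Later spreads win, so values on the loaded node override the shared defaults.
const SHARED = { draggable: true, selected: false };
const loaded = { id: 'n1', selected: true };
console.log({ ...SHARED, ...loaded }); // { draggable: true, selected: true, id: 'n1' }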

View File

@@ -79,4 +79,4 @@ export const isInvocationOutputSchemaObject = (
export const isInvocationFieldSchema = (
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
): obj is InvocationFieldSchema => !('$ref' in obj);
): obj is InvocationFieldSchema => 'field_kind' in obj;
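The old guard treated every non-$ref schema object as an invocation field schema; the new one requires the discriminating field_kind key. A simplified illustration with plain object literals (not the real OpenAPI types):

// Simplified shapes for illustration only.
const plainSchema = { type: 'integer' };
const fieldSchema = { type: 'integer', field_kind: 'input' };
const refSchema = { $ref: '#/components/schemas/Foo' };

const oldGuard = (obj: object) => !('$ref' in obj);
const newGuard = (obj: object) => 'field_kind' in obj;

console.log(oldGuard(plainSchema), newGuard(plainSchema)); // true false - the old guard over-matched here
console.log(oldGuard(fieldSchema), newGuard(fieldSchema)); // true true
console.log(oldGuard(refSchema), newGuard(refSchema)); // false false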

View File

@@ -148,7 +148,11 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
}
}
edges.forEach((edge, i) => {
// Stash invalid edges here to be deleted later
const edgesToDelete = new Set<string>();
for (const edge of edges) {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
const sourceNode = nodes.find(({ id }) => id === edge.source);
const targetNode = nodes.find(({ id }) => id === edge.target);
@@ -215,8 +219,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
if (issues.length) {
// This edge has some issues. Remove it.
delete edges[i];
edgesToDelete.add(edge.id);
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
warnings.push({
@@ -225,7 +228,10 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
data: edge,
});
}
});
}
// Remove invalid edges
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
// Migrated exposed fields to form elements if they exist and the form does not
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!
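Switching from delete edges[i] to a Set plus a single filter avoids leaving holes in the array: delete clears the slot but keeps the length, so undefined entries would otherwise flow to downstream consumers such as reactflow. A quick illustration:

// delete leaves a hole: the array keeps its length, but index 1 becomes empty.
const ids: Array<string | undefined> = ['a', 'b', 'c'];
delete ids[1];
console.log(ids.length, 1 in ids); // 3 false

// The pattern adopted above: collect the ids to drop, then filter once at the end.
const edgesToDelete = new Set<string>(['b']);
console.log(['a', 'b', 'c'].filter((id) => !edgesToDelete.has(id))); // ['a', 'c']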

View File

@@ -1,3 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import {
$outputNodeId,
@@ -16,10 +17,13 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import { assert } from 'tsafe';
const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
export const useEnqueueWorkflows = () => {
const { getState, dispatch } = useAppStore();
const enqueue = useCallback(
async (prepend: boolean, isApiValidationRun: boolean) => {
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const workflow = state.workflow;
@@ -130,9 +134,13 @@ export const useEnqueueWorkflows = () => {
} as const;
});
batchConfig.is_api_validation_run = true;
batchConfig.api_input_fields = api_input_fields;
batchConfig.api_output_fields = api_output_fields;
assert(workflow.id, 'Workflow without ID cannot be used for API validation run');
batchConfig.validation_run_data = {
workflow_id: workflow.id,
input_fields: api_input_fields,
output_fields: api_output_fields,
};
// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
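The three loose batch-config fields are consolidated into a single validation_run_data object keyed by the workflow id. The shape below is inferred from this hunk alone (the authoritative types live in services/api/types), with the field lists left untyped:

// Inferred for illustration; the real EnqueueBatchArg/Batch types may differ.
type ValidationRunData = {
  workflow_id: string;
  input_fields: unknown[];
  output_fields: unknown[];
};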

View File

@@ -29,7 +29,6 @@ import type { NodesState, Templates } from 'features/nodes/store/types';
import { getInvocationNodeErrors } from 'features/nodes/store/util/fieldValidators';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
@@ -149,7 +148,6 @@ export const useReadinessWatcher = () => {
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const isInPublishFlow = useStore($isInPublishFlow);
const isPublished = useAppSelector(selectWorkflowIsPublished);
useEffect(() => {
debouncedUpdateReasons(
@@ -189,7 +187,6 @@ export const useReadinessWatcher = () => {
upscale,
workflowSettings,
isInPublishFlow,
isPublished,
]);
};

View File

@@ -21,6 +21,7 @@ const initialConfigState: AppConfig = {
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
@@ -220,4 +221,5 @@ export const selectMetadataFetchDebounce = createConfigSelector((config) => conf
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectAllowPublishWorkflows = createConfigSelector((config) => config.allowPublishWorkflows);
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);

View File

@@ -14,7 +14,6 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type LoadWorkflowOptions = {
asCopy?: boolean;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
@@ -65,12 +64,11 @@ const useLoadImmediate = () => {
if (!dialogState) {
return;
}
const { type, data, onSuccess, onError, onCompleted, asCopy } = dialogState;
const { type, data, onSuccess, onError, onCompleted } = dialogState;
const options = {
onSuccess,
onError,
onCompleted,
asCopy,
};
if (type === 'object') {
await loadWorkflowFromObject(data, options);

View File

@@ -29,7 +29,7 @@ export const useLoadWorkflowFromFile = () => {
const { onSuccess, onError, onCompleted } = options;
try {
const unvalidatedWorkflow = JSON.parse(rawJSON as string);
const validatedWorkflow = await validatedAndLoadWorkflow(unvalidatedWorkflow);
const validatedWorkflow = await validatedAndLoadWorkflow(unvalidatedWorkflow, 'file');
if (!validatedWorkflow) {
reader.abort();

View File

@@ -41,7 +41,7 @@ export const useLoadWorkflowFromImage = () => {
assert(unvalidatedWorkflow !== null, 'No workflow or graph provided');
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow, 'image');
if (!validatedWorkflow) {
onError?.();

View File

@@ -24,14 +24,13 @@ export const useLoadWorkflowFromLibrary = () => {
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
asCopy?: boolean;
} = {}
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const res = await getWorkflow(workflowId).unwrap();
const validatedWorkflow = await validateAndLoadWorkflow(res.workflow);
const validatedWorkflow = await validateAndLoadWorkflow(res.workflow, 'library');
if (!validatedWorkflow) {
onError?.();

View File

@@ -21,7 +21,7 @@ export const useLoadWorkflowFromObject = () => {
) => {
const { onSuccess, onError, onCompleted } = options;
try {
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow);
const validatedWorkflow = await validateAndLoadWorkflow(unvalidatedWorkflow, 'object');
if (!validatedWorkflow) {
onError?.();

View File

@@ -43,7 +43,10 @@ export const useValidateAndLoadWorkflow = () => {
*
* This function catches all errors. It toasts and logs on success and error.
*/
async (unvalidatedWorkflow: unknown): Promise<WorkflowV3 | null> => {
async (
unvalidatedWorkflow: unknown,
origin: 'file' | 'image' | 'object' | 'library'
): Promise<WorkflowV3 | null> => {
try {
const templates = $templates.get();
const { workflow, warnings } = await validateWorkflow({
@@ -54,8 +57,11 @@ export const useValidateAndLoadWorkflow = () => {
checkModelAccess,
});
if (workflow.is_published) {
//TODO: How to handle this?
if (origin !== 'library') {
// Workflow IDs should always map directly to the workflow in the library. If the workflow is loaded from
// some other source, and has an ID, we should remove it to ensure the app does not treat it as a library workflow.
// For example, when saving a workflow, we might accidentally attempt to save instead of save-as.
delete workflow.id;
}
$nodeExecutionStates.set({});

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
__version__ = "5.9.1"
__version__ = "5.10.0dev3"

pins.json (new file, 14 lines)
View File

@@ -0,0 +1,14 @@
{
"python": "3.12",
"torchIndexUrl": {
"win32": {
"cuda": "https://download.pytorch.org/whl/cu126"
},
"linux": {
"cpu": "https://download.pytorch.org/whl/cpu",
"rocm": "https://download.pytorch.org/whl/rocm6.2.4",
"cuda": "https://download.pytorch.org/whl/cu126"
},
"darwin": {}
}
}
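pins.json centralizes the Python version and the per-platform torch index URLs. A hypothetical consumer (not part of this change set) could resolve the index URL as below; the empty darwin entry falls through to undefined, meaning the default package index is used on macOS:

// Hypothetical reader; pins.json is real, but this consumer is illustrative only and
// assumes a local path plus a tsconfig/bundler that resolves JSON modules.
import pins from './pins.json';

type Platform = 'win32' | 'linux' | 'darwin';
type Device = 'cpu' | 'cuda' | 'rocm';

const getTorchIndexUrl = (platform: Platform, device: Device): string | undefined => {
  const table = pins.torchIndexUrl as Partial<Record<Platform, Partial<Record<Device, string>>>>;
  return table[platform]?.[device];
};

console.log(getTorchIndexUrl('linux', 'rocm')); // https://download.pytorch.org/whl/rocm6.2.4
console.log(getTorchIndexUrl('darwin', 'cuda')); // undefined -> default index on macOS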

View File

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.10, <3.12"
requires-python = ">=3.10, <3.13"
readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"]
dynamic = ["version"]
@@ -33,69 +33,46 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==1.0.1",
"bitsandbytes==0.45.0; sys_platform!='darwin'",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"accelerate",
"bitsandbytes; sys_platform!='darwin'",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.31.0",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"diffusers[torch]",
"gguf",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
"numpy<2.0.0",
"onnx==1.16.1",
"onnxruntime==1.19.2",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchmetrics",
"torchsde",
"safetensors",
"spandrel",
"torch~=2.6.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchsde", # diffusers needs this for SDE solvers, but it is not an explicit dep of diffusers
"torchvision",
"transformers==4.46.3",
"transformers",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.1",
"fastapi==0.111.0",
"huggingface-hub==0.26.1",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",
"uvicorn[standard]==0.28.0",
"fastapi-events",
"fastapi",
"huggingface-hub",
"pydantic-settings",
"pydantic",
"python-socketio",
"uvicorn[standard]",
# Auxiliary dependencies, pinned only if necessary.
"albumentations",
"blake3",
"click",
"datasets",
"Deprecated",
"dnspython",
"dynamicprompts",
"einops",
"facexlib",
# Exclude 3.9.1 which has a problem on windows, see https://github.com/matplotlib/matplotlib/issues/28551
"matplotlib!=3.9.1",
"npyscreen",
"omegaconf",
"picklescan",
"pillow",
"prompt-toolkit",
"pympler",
"pypatchmatch",
"pyperclip",
"pyreadline3",
"python-multipart",
"requests",
"rich~=13.3",
"scikit-image",
"semver~=3.0.1",
"test-tube",
"windows-curses; sys_platform=='win32'",
"humanize==4.12.1",
]
[project.optional-dependencies]
@@ -127,7 +104,8 @@ dependencies = [
"pytest-datadir",
"requests_testadapter",
"httpx",
"polyfactory==2.19.0"
"polyfactory==2.19.0",
"humanize==4.12.1",
]
[project.scripts]
@@ -207,9 +185,9 @@ exclude = [
".venv*",
"*.ipynb",
"invokeai/backend/image_util/mediapipe_face/", # External code
"invokeai/backend/image_util/mlsd/", # External code
"invokeai/backend/image_util/normal_bae/", # External code
"invokeai/backend/image_util/pidi/", # External code
"invokeai/backend/image_util/mlsd/", # External code
"invokeai/backend/image_util/normal_bae/", # External code
"invokeai/backend/image_util/pidi/", # External code
]
[tool.ruff.lint]

View File

@@ -21,16 +21,18 @@ def count_files(path: Path):
@pytest.fixture
def obj_serializer(tmp_path: Path):
return ObjectSerializerDisk[MockDataclass](tmp_path)
return ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
@pytest.fixture
def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
return ObjectSerializerForwardCache(
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass]), max_cache_size=2
)
def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
assert obj_serializer._output_dir == tmp_path
@@ -70,7 +72,7 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
assert obj_serializer._base_output_dir == tmp_path
assert obj_serializer._output_dir != tmp_path
@@ -78,21 +80,21 @@ def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
tempdir_path = obj_serializer._output_dir
del obj_serializer
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
tempdir_path = obj_serializer._output_dir
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
@@ -102,19 +104,19 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
assert not tempdir.exists()
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=False)
assert tempdir.exists()
def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer_1.save(obj_1)
obj_1_loaded = obj_serializer_1.load(obj_1_name)
@@ -123,19 +125,19 @@ def test_obj_serializer_disk_different_types(tmp_path: Path):
assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_")
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path, safe_globals=[int])
obj_2_name = obj_serializer_2.save(9001)
assert obj_serializer_2._obj_class_name == "int"
assert obj_serializer_2.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_")
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path, safe_globals=[str])
obj_3_name = obj_serializer_3.save("foo")
assert obj_serializer_3._obj_class_name == "str"
assert obj_serializer_3.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_")
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path, safe_globals=[torch.Tensor])
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer_4.load(obj_4_name)
assert obj_serializer_4._obj_class_name == "Tensor"

uv.lock (generated, 3634 lines)

File diff suppressed because it is too large