Compare commits

59 Commits

Author SHA1 Message Date
psychedelicious
f34d6099f5 build: fix path in build script 2025-04-04 17:05:00 +10:00
psychedelicious
ef9d832b6a ci: fix name of build wheel workflow 2025-04-04 17:04:27 +10:00
psychedelicious
6c87ea58b0 chore: bump version to v5.10.0dev4 2025-04-04 17:02:13 +10:00
psychedelicious
0e569364ac ci: update workflows to use revised build scripts 2025-04-04 17:00:09 +10:00
psychedelicious
bb6e22606b build: remove installer & convert installer build script to only build the wheel 2025-04-04 16:59:55 +10:00
psychedelicious
3e200a2ba2 chore: bump version to v5.10.0dev3 2025-04-04 16:48:12 +10:00
psychedelicious
4610b55a5d chore: update uv.lock 2025-04-04 16:46:15 +10:00
psychedelicious
b3b3dbd92d build: remove pin on spandrel dependency 2025-04-04 16:41:10 +10:00
psychedelicious
6c36b0508b build: add comment about torchsde to pyproject 2025-04-04 16:40:50 +10:00
psychedelicious
2756c539e0 build: remove pin on gguf dependency
This allows it to pull in sentencepiece on its own. In 0.10.0, it didn't have this package listed as a dependency, but in recent releases it does. So we are able to remove sentencepiece as an explicit dep.
2025-04-04 16:40:36 +10:00
psychedelicious
a34383d460 build: remove unused clip_anytorch dependency 2025-04-04 16:39:20 +10:00
psychedelicious
77f22497d2 build: remove unused pytorch-lightning dependency 2025-04-04 16:39:20 +10:00
psychedelicious
5967d4e1da build: remove unused pyreadline3 dependency 2025-04-04 16:39:20 +10:00
psychedelicious
1253ad5053 build: remove unused pyperclip dependency 2025-04-04 16:39:20 +10:00
psychedelicious
5aa08ab09b build: remove unused pympler dependency 2025-04-04 16:39:19 +10:00
psychedelicious
6ce527768b build: remove unused scikit-image dependency 2025-04-04 16:39:19 +10:00
psychedelicious
fe88012236 build: remove unused npyscreen dependency 2025-04-04 16:39:19 +10:00
psychedelicious
8609b98217 build: remove unused torchmetrics dependency 2025-04-04 16:13:45 +10:00
psychedelicious
19f0bf828c build: remove unused datasets dependency 2025-04-04 16:12:13 +10:00
psychedelicious
26cbeccfdf build: remove unused click dependency 2025-04-04 16:11:38 +10:00
psychedelicious
b5be81b97b build: remove unused omegaconf dependency 2025-04-04 16:09:53 +10:00
psychedelicious
f14d07968b build: remove unused facexlib dependency 2025-04-04 16:09:36 +10:00
psychedelicious
525a89900a build: remove unused timm dependency 2025-04-04 16:08:31 +10:00
psychedelicious
d8df31a8ac chore(ui): typegen 2025-04-04 16:03:29 +10:00
psychedelicious
380a41be34 chore: update uv.lock 2025-04-04 16:03:29 +10:00
psychedelicious
e990afbccb build: remove unused matplotlib dep 2025-04-04 16:03:29 +10:00
psychedelicious
c591478d24 tidy(nodes): remove matplotlib dependency
It was only used for a single color conversion function. Replaced with cv2 code, tested functionality to confirm it works the same.
2025-04-04 16:03:29 +10:00
psychedelicious
30def6a9bd build: move humanize to test deps 2025-04-04 16:03:29 +10:00
psychedelicious
6cf88a601d build: remove unused albumentations dependency
This is not used
2025-04-04 16:03:29 +10:00
psychedelicious
5e14545c32 tidy: delete unused file 2025-04-04 16:03:29 +10:00
psychedelicious
eefbcd2485 build: remove controlnet_aux dependency, remove pin for timm 2025-04-04 16:03:29 +10:00
psychedelicious
13cc44a22c tidy(nodes): rename controlnet_image_processors.py -> controlnet.py 2025-04-04 16:03:29 +10:00
psychedelicious
2cca339a5c tidy(nodes): remove unused old dw openpose detector class 2025-04-04 16:03:29 +10:00
psychedelicious
0a7cf6c0ec tidy(nodes): remove deprecated controlnet "processor" nodes 2025-04-04 16:03:29 +10:00
psychedelicious
06abc1d40a build: upgrade python to 3.12 in pins 2025-04-04 16:03:29 +10:00
psychedelicious
2cde86b7b8 build: update uv.lock 2025-04-04 16:03:28 +10:00
psychedelicious
0a49463c79 fix(backend): remove mps_fixes
The fixes in this module monkeypatched `torch` to resolve some issues with FP16 on macOS. These issues have long since been resolved.

Included in the now-removed fixes is `CustomSlicedAttentionProcessor`, which is intended to reduce memory requirements for MPS. This overrides `diffusers`' own `SlicedAttentionProcessor`.

Unfortunately, `attention_type: sliced` produces hot garbage with the fixes and black images without the fixes. So this class appears to now be a moot point.

Regardless, SDPA is supported on MPS and very efficient, so sliced attention is largely obsolete.
2025-04-04 16:03:28 +10:00
psychedelicious
f3402b6ce7 chore: bump version to v5.10.0dev2
Doing a dev build so I can test the launcher.
2025-04-04 16:03:28 +10:00
psychedelicious
5d3fb822c5 build: downgrade python to 3.11 in pins 2025-04-04 16:03:28 +10:00
psychedelicious
9e70d8eb6e build: restore prev setuptools config to fix wheel build 2025-04-04 16:03:28 +10:00
psychedelicious
402758d502 ci: use py3.12 to build installer 2025-04-04 16:03:28 +10:00
psychedelicious
b97cc51f23 experiment: add pins.json to repo
The launcher will query this file to get the pins needed for installation
2025-04-04 16:03:28 +10:00
psychedelicious
f6f33b5999 chore: bump version to v5.10.0dev1
Doing a dev build so I can test the launcher.
2025-04-04 16:03:28 +10:00
psychedelicious
cd873f1fe5 chore: update uv.lock for latest pydantic
Ran `uv lock --upgrade-package pydantic`
2025-04-04 16:03:28 +10:00
psychedelicious
5f3d398074 fix(ui): handle updated schema structure during invocation parsing
In https://github.com/pydantic/pydantic/pull/10029, pydantic made an improvement to its generated JSON schemas (OpenAPI schemas). The previous and new generated schemas both meet the schema spec.

When we parse the OpenAPI schema to generate node templates, we use a typeguard to narrow schema components from generic OpenAPI schema objects to node field schema objects. The narrower node field schema objects contain extra data.

For example, they contain a `field_kind` attribute that indicates whether the field is an input field or output field. These extra attributes are not part of the OpenAPI spec (but the spec does allow for this extra data).

This typeguard relied on a pydantic implementation detail. This was changed in the linked pydantic PR, which released with v2.9.0. With the change, our typeguard rejects input field schema objects, causing parsing to fail with errors/warnings like `Unhandled input property` in the JS console.

In the UI, this causes many fields - mostly model fields - to not show up in the workflow editor.

The fix for this is very simple - instead of relying on an implementation detail for the typeguard, we can check if the incoming schema object has any of our invoke-specific extra attributes. Specifically, we now look for the presence of the `field_kind` attribute on the incoming schema object. If it is present, we know we are dealing with an invocation input field and can parse it appropriately.
2025-04-04 16:03:28 +10:00
psychedelicious
e6b366ff61 chore: typegen 2025-04-04 16:03:28 +10:00
psychedelicious
bcd50ed688 chore: remove pydantic pin 2025-04-04 16:03:27 +10:00
psychedelicious
a5966c3197 chore(ui): typegen 2025-04-04 16:03:27 +10:00
psychedelicious
f28b054872 tests: update tests/test_object_serializer_disk.py 2025-04-04 16:03:27 +10:00
psychedelicious
31681f4ad7 fix(app): add trusted classes to torch safe globals to prevent errors when loading them
In `ObjectSerializerDisk`, we use `torch.load` to load serialized objects from disk. With torch 2.6.0, torch defaults to `weights_only=True`. As a result, torch will raise when attempting to deserialize anything with an unrecognized class.

For example, our `ConditioningFieldData` class is untrusted. When we load conditioning from disk, we will get a runtime error.

Torch provides a method to add trusted classes to an allowlist. This change adds an arg to `ObjectSerializerDisk` to add a list of safe globals to the allowlist and uses it for both `ObjectSerializerDisk` instances.

Note: My first attempt inferred the class from the generic type arg that `ObjectSerializerDisk` accepts, and added that to the allowlist. Unfortunately, this doesn't work.

For example, `ConditioningFieldData` has a `conditionings` attribute that may be one of several other untrusted classes representing model-specific conditioning data. So, even if we allowlist `ConditioningFieldData`, loading will fail when torch deserializes the `conditionings` attribute.
2025-04-04 16:03:27 +10:00
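The allowlist behaviour described in commit 31681f4ad7 can be illustrated without torch. The sketch below uses the standard library's `pickle` and a restricted `find_class`, which is an analogous mechanism and not torch's implementation. In torch itself, the allowlist is populated with `torch.serialization.add_safe_globals`. The two classes here are simplified stand-ins, not the real InvokeAI classes, and `AllowlistUnpickler` and `load` are hypothetical names.

```python
import io
import pickle
from dataclasses import dataclass, field


@dataclass
class ModelSpecificConditioning:  # hypothetical stand-in for model-specific data
    embeds: list


@dataclass
class ConditioningFieldData:  # simplified stand-in, not the real class
    conditionings: list = field(default_factory=list)


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that only resolves classes present in an explicit allowlist."""

    def __init__(self, file, safe_globals):
        super().__init__(file)
        self._allowed = {(c.__module__, c.__qualname__): c for c in safe_globals}

    def find_class(self, module, name):
        # Called for every class reference in the stream, including nested ones.
        try:
            return self._allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"untrusted class: {module}.{name}") from None


def load(data: bytes, safe_globals):
    return AllowlistUnpickler(io.BytesIO(data), safe_globals).load()


payload = pickle.dumps(
    ConditioningFieldData(conditionings=[ModelSpecificConditioning([0.1, 0.2])])
)

# Allowlisting both the outer class and the nested class succeeds.
restored = load(payload, [ConditioningFieldData, ModelSpecificConditioning])

# Allowlisting only the outer class fails once the nested object is reached.
try:
    load(payload, [ConditioningFieldData])
    nested_rejected = False
except pickle.UnpicklingError:
    nested_rejected = True
```

This mirrors the note in the commit message: inferring only the outer class from the generic type argument is not enough, so the list of safe globals is passed in explicitly.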
Eugene Brodsky
aaf042de48 resolve conflict between timm version needed by LLaVA and controlnet-aux 2025-04-04 16:03:27 +10:00
Eugene Brodsky
c28e685409 reintroduce GPU_DRIVER build arg in CI container build, as it has apparently been removed 2025-04-04 16:03:27 +10:00
Eugene Brodsky
d6ac822a1f remove obsoleted dependencies that were used by the CLI 2025-04-04 16:03:27 +10:00
Eugene Brodsky
f0a4d7ac7f modify docs for python 3.12 2025-04-04 16:03:27 +10:00
Eugene Brodsky
04b0e658df update nodes schema / typegen 2025-04-04 16:03:27 +10:00
Eugene Brodsky
68845f4d85 update uv.lock 2025-04-04 16:03:27 +10:00
Eugene Brodsky
6df5614b54 refactor Dockerfile; get rid of multi-stage build; upgrade to python 3.12 2025-04-04 16:03:27 +10:00
Eugene Brodsky
0bd6f0245b use uv.lock to pin dependencies 2025-04-04 16:03:26 +10:00
Eugene Brodsky
6c9165046e upgrade pytorch and unpin some of the strict dependency pins to facilitate upgrading co-dependencies.
we will use uv.lock to ensure reproducibility
2025-04-04 16:03:26 +10:00
1124 changed files with 28118 additions and 60067 deletions

29
.github/CODEOWNERS vendored
View File

@@ -1,31 +1,32 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
# nodes
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
/scripts/ @ebr @lstein @psychedelicious @hipsterusername
/installer/ @lstein @ebr @psychedelicious @hipsterusername
/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
/invokeai/configs @lstein @psychedelicious @hipsterusername
/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
/pyproject.toml @lstein @blessedcoolant @hipsterusername
/docker/ @lstein @blessedcoolant @hipsterusername @ebr
/scripts/ @ebr @lstein @hipsterusername
/installer/ @lstein @ebr @hipsterusername
/invokeai/assets @lstein @ebr @hipsterusername
/invokeai/configs @lstein @hipsterusername
/invokeai/version @lstein @blessedcoolant @hipsterusername
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
# generation, model management, postprocessing
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
# front ends
/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/CLI @lstein @hipsterusername
/invokeai/frontend/install @lstein @ebr @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername

View File

@@ -21,20 +21,6 @@ body:
- label: I have searched the existing issues
required: true
- type: dropdown
id: install_method
attributes:
label: Install method
description: How did you install Invoke?
multiple: false
options:
- "Invoke's Launcher"
- 'Stability Matrix'
- 'Pinokio'
- 'Manual'
validations:
required: true
- type: markdown
attributes:
value: __Describe your environment__
@@ -90,8 +76,8 @@ body:
attributes:
label: Version number
description: |
The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. v6.0.2
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. 3.6.1
validations:
required: true
@@ -99,17 +85,17 @@ body:
id: browser-version
attributes:
label: Browser
description: Your web browser and version, if you do not use the Launcher's provided GUI.
description: Your web browser and version.
placeholder: ex. Firefox 123.0b3
validations:
required: false
required: true
- type: textarea
id: python-deps
attributes:
label: System Information
label: Python dependencies
description: |
Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
validations:
required: false

View File

@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 20
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
version: 8.15.6
run_install: false
- name: get pnpm store directory

View File

@@ -67,10 +67,6 @@ jobs:
version: '0.6.10'
enable-cache: true
- name: check pypi classifiers
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: uv run --no-project scripts/check_classifiers.py ./pyproject.toml
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: uv tool run ruff@0.11.2 check --output-format=github .

View File

@@ -1,68 +0,0 @@
# Check the `uv` lockfile for consistency with `pyproject.toml`.
#
# If this check fails, you should run `uv lock` to update the lockfile.
name: 'uv lock checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
jobs:
uv-lock-checks:
env:
# uv requires a venv by default - but for this, we can simply use the system python
UV_SYSTEM_PYTHON: 1
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
with:
files_yaml: |
uvlock-pyprojecttoml:
- 'pyproject.toml'
- 'uv.lock'
- name: setup uv
if: ${{ steps.changed-files.outputs.uvlock-pyprojecttoml_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
- name: check lockfile
if: ${{ steps.changed-files.outputs.uvlock-pyprojecttoml_any_changed == 'true' || inputs.always_run == true }}
run: uv lock --locked # this will exit with 1 if the lockfile is not consistent with pyproject.toml
shell: bash

4
.gitignore vendored
View File

@@ -180,7 +180,6 @@ cython_debug/
# Scratch folder
.scratch/
.vscode/
.zed/
# source installer files
installer/*zip
@@ -189,6 +188,3 @@ installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*
.claude/

View File

@@ -4,29 +4,21 @@ repos:
hooks:
- id: black
name: black
stages: [pre-commit]
stages: [commit]
language: system
entry: black
types: [python]
- id: flake8
name: flake8
stages: [pre-commit]
stages: [commit]
language: system
entry: flake8
types: [python]
- id: isort
name: isort
stages: [pre-commit]
stages: [commit]
language: system
entry: isort
types: [python]
- id: uvlock
name: uv lock
stages: [pre-commit]
language: system
entry: uv lock
files: ^pyproject\.toml$
pass_filenames: false
types: [python]

View File

@@ -5,7 +5,8 @@
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@10.x && corepack enable
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
@@ -98,15 +99,4 @@ CMD ["invokeai-web"]
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
# this should not increase image size because we've already installed dependencies
# in a previous layer
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv pip install -e .
COPY invokeai ${INVOKEAI_SRC}/invokeai

View File

@@ -60,11 +60,16 @@ Next, these jobs run and must pass. They are the same jobs that are run for ever
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
- **`typegen-checks`**: ensures the frontend and backend types are synced
#### `build-wheel` Job
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `./scripts/build_wheel.sh` and uploads `dist.zip`, which contains the wheel and unarchived build.
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
You don't need to download or test these artifacts.
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the legacy install scripts
You don't need to download either of these files.
> The legacy install scripts are no longer used, but we haven't updated the workflow to skip building them.
#### Sanity Check & Smoke Test
@@ -74,7 +79,7 @@ It's possible to test the python package before it gets published to PyPI. We've
But, if you want to be extra-super careful, here's how to test it:
- Download the `dist.zip` build artifact from the `build-wheel` job
- Download the `dist.zip` build artifact from the `build-installer` job
- Unzip it and find the wheel file
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
- Test the app

View File

@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
folder for more examples:
```py
from .cool_node import ResizeInvocation
from .cool_node import CoolInvocation
```
## Creating A New Invocation
@@ -69,10 +69,7 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
from invokeai.invocation_api import (
BaseInvocation,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -106,12 +103,8 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -135,12 +128,8 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`
```python
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -174,13 +163,8 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -207,14 +191,8 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput
@invocation('resize')
@@ -239,15 +217,9 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.
```python
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
@invocation("resize")
class ResizeInvocation(BaseInvocation):

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
7. Install the frontend dev toolchain, paying attention to versions:
7. Install the frontend dev toolchain:
- [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
- [`nodejs`](https://nodejs.org/) (v20+)
- [`pnpm`](https://pnpm.io/installation) (tested on v10)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
8. Do a production build of the frontend:

View File

@@ -0,0 +1,121 @@
# Legacy Scripts
!!! warning "Legacy Scripts"
We recommend using the Invoke Launcher to install and update Invoke. It's a desktop application for Windows, macOS and Linux. It takes care of a lot of nitty gritty details for you.
Follow the [quick start guide](./quick_start.md) to get started.
!!! tip "Use the installer to update"
Using the installer for updates will not erase any of your data (images, models, boards, etc). It only updates the core libraries used to run Invoke.
Simply use the same path you installed to originally to update your existing installation.
Both release and pre-release versions can be installed using the installer. It also supports install through a wheel if needed.
Be sure to review the [installation requirements] and ensure your system has everything it needs to install Invoke.
## Getting the Latest Installer
Download the `InvokeAI-installer-vX.Y.Z.zip` file from the [latest release] page. It is at the bottom of the page, under **Assets**.
After unzipping the installer, you should have a `InvokeAI-Installer` folder with some files inside, including `install.bat` and `install.sh`.
## Running the Installer
!!! tip
Windows users should first double-click the `WinLongPathsEnabled.reg` file to prevent a failed installation due to long file paths.
Double-click the install script:
=== "Windows"
```sh
install.bat
```
=== "Linux/macOS"
```sh
install.sh
```
!!! info "Running the Installer from the commandline"
You can also run the install script from cmd/powershell (Windows) or terminal (Linux/macOS).
!!! warning "Untrusted Publisher (Windows)"
You may get a popup saying the file comes from an `Untrusted Publisher`. Click `More Info` and `Run Anyway` to get past this.
The installation process is simple, with a few prompts:
- Select the version to install. Unless you have a specific reason to install a specific version, select the default (the latest version).
- Select location for the install. Be sure you have enough space in this folder for the base application, as described in the [installation requirements].
- Select a GPU device.
!!! info "Slow Installation"
The installer needs to download several GB of data and install it all. It may appear to get stuck at 99.9% when installing `pytorch` or during a step labeled "Installing collected packages".
If it is stuck for over 10 minutes, something has probably gone wrong and you should close the window and restart.
## Running the Application
Find the install location you selected earlier. Double-click the launcher script to run the app:
=== "Windows"
```sh
invoke.bat
```
=== "Linux/macOS"
```sh
invoke.sh
```
Choose the first option to run the UI. After a series of startup messages, you'll see something like this:
```sh
Uvicorn running on http://127.0.0.1:9090 (Press CTRL+C to quit)
```
Copy the URL into your browser and you should see the UI.
## Improved Outpainting with PatchMatch
PatchMatch is an extra add-on that can improve outpainting. Windows users are in luck - it works out of the box.
On macOS and Linux, a few extra steps are needed to set it up. See the [PatchMatch installation guide](./patchmatch.md).
## First-time Setup
You will need to [install some models] before you can generate.
Check the [configuration docs] for details on configuring the application.
## Updating
Updating is exactly the same as installing - download the latest installer, choose the latest version, enter your existing installation path, and the app will update. None of your data (images, models, boards, etc) will be erased.
!!! info "Dependency Resolution Issues"
We've found that pip's dependency resolution can cause issues when upgrading packages. One very common problem was pip "downgrading" torch from CUDA to CPU, but things broke in other novel ways.
The installer doesn't have this kind of problem, so we use it for updating as well.
## Installation Issues
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
[installation requirements]: ./requirements.md
[FAQ]: ../faq.md
[install some models]: ./models.md
[configuration docs]: ../configuration.md
[latest release]: https://github.com/invoke-ai/InvokeAI/releases/latest
[create an issue]: https://github.com/invoke-ai/InvokeAI/issues
[discord]: https://discord.gg/ZmtBAhwWhy

View File

@@ -71,21 +71,7 @@ The following commands vary depending on the version of Invoke being installed a
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
=== "Invoke v5.12 and later"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.10.0 to v5.11.0"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.0.0 to v5.9.1"
=== "Invoke v5 or later"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
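The index selection rules in the two complete tabs above (Invoke v5.12 and later, and v5.10.0 to v5.11.0) can be summarised as a small lookup. This is only a sketch: the function name, the `os_name` and `gpu` labels, and the version tuple format are assumptions for illustration, and the older tab is left out because its list is cut off in this hunk.

```python
from typing import Optional

PYTORCH_WHL = "https://download.pytorch.org/whl"


def torch_index_url(version: tuple, os_name: str, gpu: str) -> Optional[str]:
    """Return the PyPI index URL for Invoke v5.10 and later, or None for "no index".

    os_name is one of "windows", "linux", "macos"; gpu is one of
    "nvidia", "amd", "none". These labels are made up for this sketch.
    """
    if version < (5, 10):
        raise ValueError("only the v5.10+ tabs are covered by this sketch")

    # v5.12 and later use CUDA 12.8 wheels; v5.10.0 to v5.11.0 use CUDA 12.6.
    cuda = "cu128" if version >= (5, 12) else "cu126"

    if gpu == "nvidia" and os_name in ("windows", "linux"):
        return f"{PYTORCH_WHL}/{cuda}"
    if os_name == "linux" and gpu == "none":
        return f"{PYTORCH_WHL}/cpu"
    if os_name == "linux" and gpu == "amd":
        return f"{PYTORCH_WHL}/rocm6.2.4"
    # In all other cases, do not use an index.
    return None
```

For example, `torch_index_url((5, 12, 0), "windows", "nvidia")` yields the `cu128` index, while a macOS machine gets `None`, meaning the install command is run without `--index`.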

View File

@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most recent launcher for your operating system:
Download the most launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
@@ -49,9 +49,9 @@ If you have an existing Invoke installation, you can select it and let the launc
!!! warning "Problem running the launcher on macOS"
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can manually flag the launcher as safe:
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can either use the [legacy scripts](./legacy_scripts.md) to install, or manually flag the launcher as safe:
- Open the **Invoke Community Edition.dmg** file.
- Open the **Invoke-Installer-mac-arm64.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
@@ -117,6 +117,7 @@ If you still have problems, ask for help on the Invoke [discord](https://discord
- You can install the Invoke application as a python package. See our [manual install](./manual.md) docs.
- You can run Invoke with docker. See our [docker install](./docker.md) docs.
- You can still use our legacy scripts to install and run Invoke. See the [legacy scripts](./legacy_scripts.md) docs.
## Need Help?

View File

@@ -13,7 +13,6 @@ If you'd prefer, you can also just download the whole node folder from the linke
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Anamorphic Tools](#anamorphic-tools)
+ [Adapters-Linked](#adapters-linked-nodes)
+ [Autostereogram](#autostereogram-nodes)
+ [Average Images](#average-images)
@@ -21,12 +20,9 @@ To use a community workflow, download the `.json` node graph file and load it in
+ [Close Color Mask](#close-color-mask)
+ [Clothing Mask](#clothing-mask)
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
+ [Curves](#curves)
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
+ [Enhance Detail](#enhance-detail)
+ [Film Grain](#film-grain)
+ [Flip Pose](#flip-pose)
+ [Flux Ideal Size](#flux-ideal-size)
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
+ [Grid to Gif](#grid-to-gif)
@@ -65,13 +61,6 @@ To use a community workflow, download the `.json` node graph file and load it in
- [Help](#help)
--------------------------------
### Anamorphic Tools
**Description:** A set of nodes to perform anamorphic modifications to images, like lens blur, streaks, spherical distortion, and vignetting.
**Node Link:** https://github.com/JPPhoto/anamorphic-tools
--------------------------------
### Adapters Linked Nodes
@@ -143,13 +132,6 @@ Node Link: https://github.com/VeyDlin/clahe-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
--------------------------------
### Curves
**Description:** Adjust an image's curve based on a user-defined string.
**Node Link:** https://github.com/JPPhoto/curves-node
--------------------------------
### Depth Map from Wavefront OBJ
@@ -180,20 +162,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
**Node Link:** https://github.com/JPPhoto/film-grain-node
--------------------------------
### Flip Pose
**Description:** This node will flip an openpose image horizontally, recoloring it to make sure that it isn't facing the wrong direction. Note that it does not work with openpose hands.
**Node Link:** https://github.com/JPPhoto/flip-pose-node
--------------------------------
### Flux Ideal Size
**Description:** This node returns an ideal size to use for the first stage of a Flux image generation pipeline. Generating at the right size helps limit duplication and odd subject placement.
**Node Link:** https://github.com/JPPhoto/flux-ideal-size
--------------------------------
### Generative Grammar-Based Prompt Nodes


@@ -23,10 +23,6 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
SqliteModelRelationshipRecordStorage,
)
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@@ -43,7 +39,6 @@ from invokeai.app.services.workflow_records.workflow_records_sqlite import Sqlit
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
CogView4ConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
SD3ConditioningInfo,
@@ -117,6 +112,7 @@ class ApiDependencies:
safe_globals=[torch.Tensor],
ephemeral=True,
),
max_cache_size=0,
)
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](
@@ -127,7 +123,6 @@ class ApiDependencies:
SDXLConditioningInfo,
FLUXConditioningInfo,
SD3ConditioningInfo,
CogView4ConditioningInfo,
],
ephemeral=True,
),
@@ -140,8 +135,6 @@ class ApiDependencies:
download_queue=download_queue_service,
events=events,
)
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
@@ -167,8 +160,6 @@ class ApiDependencies:
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
model_relationships=model_relationships,
model_relationship_records=model_relationship_records,
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,


@@ -1,7 +1,8 @@
import typing
from enum import Enum
from importlib.metadata import distributions
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from platform import python_version
from typing import Optional
import torch
@@ -43,6 +44,24 @@ class AppVersion(BaseModel):
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
class AppDependencyVersions(BaseModel):
"""App dependency versions response"""
accelerate: str = Field(description="accelerate version")
compel: str = Field(description="compel version")
cuda: Optional[str] = Field(description="CUDA version")
diffusers: str = Field(description="diffusers version")
numpy: str = Field(description="Numpy version")
opencv: str = Field(description="OpenCV version")
onnx: str = Field(description="ONNX version")
pillow: str = Field(description="Pillow (PIL) version")
python: str = Field(description="Python version")
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")
class AppConfig(BaseModel):
"""App Config Response"""
@@ -57,19 +76,27 @@ async def get_version() -> AppVersion:
return AppVersion(version=__version__)
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=dict[str, str])
async def get_app_deps() -> dict[str, str]:
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
try:
cuda = torch.version.cuda or "N/A"
except Exception:
cuda = "N/A"
deps["CUDA"] = cuda
sorted_deps = dict(sorted(deps.items(), key=lambda item: item[0].lower()))
return sorted_deps
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
cuda=torch.version.cuda,
diffusers=version("diffusers"),
numpy=version("numpy"),
opencv=version("opencv-python"),
onnx=version("onnx"),
pillow=version("pillow"),
python=python_version(),
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
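Review note: the `xformers` lookup in `get_app_deps` treats a missing package as `None` instead of an error, while the other lookups would raise if their package were absent. The following is a minimal standard-library sketch of that optional-lookup pattern. `optional_version` is a hypothetical helper name used for illustration; it is not defined anywhere in this diff.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def optional_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        # The distribution is not installed in this environment.
        return None
```

A caller could then write `xformers=optional_version("xformers")` and keep the required packages on plain `version(...)`, so that a missing required dependency still fails loudly.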


@@ -1,12 +1,21 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -14,26 +23,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
) -> AddImagesToBoardResult:
):
"""Creates a board_image"""
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board")
@@ -45,25 +45,14 @@ async def add_image_to_board(
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
) -> RemoveImagesFromBoardResult:
):
"""Removes an image from its board, if it had one"""
try:
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -83,25 +72,16 @@ async def add_images_to_board(
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
added_image_names: list[str] = []
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id,
image_name=image_name,
board_id=board_id, image_name=image_name
)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
added_image_names.append(image_name)
except Exception:
pass
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -120,20 +100,13 @@ async def remove_images_from_board(
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_images: set[str] = set()
affected_boards: set[str] = set()
removed_image_names: list[str] = []
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
removed_image_names.append(image_name)
except Exception:
pass
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")
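Review note: the batch handlers in this file share one pattern: attempt each image, silently skip any that raise, and report only the names that succeeded. Below is a minimal sketch of that pattern, independent of the Invoke services. `apply_best_effort` and its `action` callback are hypothetical names, introduced only for illustration.

```python
from typing import Callable, Iterable


def apply_best_effort(names: Iterable[str], action: Callable[[str], None]) -> list[str]:
    """Run `action` on each name and return the names for which it succeeded."""
    succeeded: list[str] = []
    for name in names:
        try:
            action(name)
            succeeded.append(name)
        except Exception:
            # As in the handlers above, a failure on one item is skipped silently.
            continue
    return succeeded
```

One consequence worth keeping in mind: because every per-item failure is swallowed, a request in which all items fail still produces a normal response with an empty list, so clients would need to compare the returned names against what they sent.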


@@ -146,7 +146,7 @@ async def list_boards(
response_model=list[str],
)
async def list_all_board_image_names(
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
board_id: str = Path(description="The id of the board"),
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
) -> list[str]:


@@ -1,34 +1,24 @@
import io
import json
import traceback
from typing import ClassVar, Optional
from typing import Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import (
DeleteImagesResult,
ImageDTO,
ImageUrlsDTO,
StarredImagesResult,
UnstarredImagesResult,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -37,19 +27,6 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
class ResizeToDimensions(BaseModel):
width: int = Field(..., gt=0)
height: int = Field(..., gt=0)
MAX_SIZE: ClassVar[int] = 4096 * 4096
@model_validator(mode="after")
def validate_total_output_size(self):
if self.width * self.height > self.MAX_SIZE:
raise ValueError(f"Max total output size for resizing is {self.MAX_SIZE} pixels")
return self
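Review note: `ResizeToDimensions` caps the total pixel count (width times height), not each side individually. The sketch below restates that rule with a standard-library dataclass in place of the pydantic model, so it runs without third-party packages. `ResizeBudget` and `MAX_PIXELS` are hypothetical names, and the positivity check stands in for the `gt=0` field constraints.

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass(frozen=True)
class ResizeBudget:
    """Stdlib stand-in for a width/height pair with a cap on total pixels."""

    width: int
    height: int

    # Same cap as MAX_SIZE in the diff: 4096 * 4096 pixels in total.
    MAX_PIXELS: ClassVar[int] = 4096 * 4096

    def __post_init__(self) -> None:
        # Equivalent of the gt=0 constraints on the pydantic fields.
        if self.width <= 0 or self.height <= 0:
            raise ValueError("width and height must be positive")
        # The limit applies to the product, so one side may exceed 4096.
        if self.width * self.height > self.MAX_PIXELS:
            raise ValueError(f"Max total output size for resizing is {self.MAX_PIXELS} pixels")
```

Under this rule `ResizeBudget(8192, 2048)` is accepted because its area equals the cap, while `ResizeBudget(4097, 4096)` is rejected.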
@images_router.post(
"/upload",
operation_id="upload_image",
@@ -69,11 +46,6 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
examples=['"[1024,1024]"'],
),
metadata: Optional[str] = Body(
default=None,
description="The metadata to associate with the image, must be a stringified JSON dict",
@@ -87,33 +59,13 @@ async def upload_image(
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
if crop_visible:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
if crop_visible:
try:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
raise HTTPException(status_code=500, detail="Failed to crop image")
if resize_to:
try:
dims = json.loads(resize_to)
resize_dims = ResizeToDimensions(**dims)
except Exception:
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
try:
# heuristic_resize_fast expects an RGB or RGBA image
pil_rgba = pil_image.convert("RGBA")
np_image = pil_to_np(pil_rgba)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
except Exception:
raise HTTPException(status_code=500, detail="Failed to resize image")
extracted_metadata = extract_metadata_from_image(
pil_image=pil_image,
invokeai_metadata_override=metadata,
@@ -160,30 +112,18 @@ async def create_image_upload_entry(
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> DeleteImagesResult:
) -> None:
"""Deletes an image"""
deleted_images: set[str] = set()
affected_boards: set[str] = set()
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
# TODO: Does this need any exception handling at all?
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
@@ -395,52 +335,23 @@ async def list_image_dtos(
return image_dtos
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesResult:
) -> DeleteImagesFromListResult:
try:
deleted_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
"""Deletes all images that are uncategorized"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id="none", categories=None, is_intermediate=None
)
try:
deleted_images: set[str] = set()
affected_boards: set[str] = set()
deleted_images: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
deleted_images.append(image_name)
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@@ -449,50 +360,36 @@ class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> StarredImagesResult:
) -> ImagesUpdatedFromListResult:
try:
starred_images: set[str] = set()
affected_boards: set[str] = set()
updated_image_names: list[str] = []
for image_name in image_names:
try:
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
except Exception:
pass
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> UnstarredImagesResult:
) -> ImagesUpdatedFromListResult:
try:
unstarred_images: set[str] = set()
affected_boards: set[str] = set()
updated_image_names: list[str] = []
for image_name in image_names:
try:
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
except Exception:
pass
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -563,61 +460,3 @@ async def get_bulk_download_item(
return response
except Exception:
raise HTTPException(status_code=404)
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")
@images_router.post(
"/images_by_names",
operation_id="get_images_by_names",
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
try:
image_service = ApiDependencies.invoker.services.images
# Fetch DTOs preserving the order of requested names
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
dto = image_service.get_dto(name)
image_dtos.append(dto)
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image DTOs")


@@ -41,7 +41,6 @@ from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelBundle,
StarterModelWithoutDependencies,
)
@@ -86,7 +85,6 @@ example_model_config = {
"config_path": "string",
"key": "string",
"hash": "string",
"file_size": 1,
"description": "string",
"source": "string",
"converted_at": 0,
@@ -292,7 +290,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -450,7 +448,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type",
examples=[{"name": "string", "description": "string"}],
example={"name": "string", "description": "string"},
),
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -800,7 +798,7 @@ async def convert_model(
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, StarterModelBundle]
starter_bundles: dict[str, list[StarterModel]]
def get_is_installed(
@@ -834,7 +832,7 @@ async def get_starter_models() -> StarterModelResponse:
model.dependencies = missing_deps
for bundle in starter_bundles.values():
for model in bundle.models:
for model in bundle:
model.is_installed = get_is_installed(model, installed_models)
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []
@@ -894,12 +892,6 @@ class HFTokenHelper:
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()
@classmethod
def reset_token(cls) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.logout()
return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
@@ -922,8 +914,3 @@ async def do_hf_login(
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
async def reset_hf_token() -> HFTokenStatus:
return HFTokenHelper.reset_token()


@@ -1,215 +0,0 @@
"""FastAPI route for model relationship records."""
from typing import List
from fastapi import APIRouter, Body, HTTPException, Path, status
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
model_relationships_router = APIRouter(prefix="/v1/model_relationships", tags=["model_relationships"])
# === Schemas ===
class ModelRelationshipCreateRequest(BaseModel):
model_key_1: str = Field(
...,
description="The key of the first model in the relationship",
examples=[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
)
model_key_2: str = Field(
...,
description="The key of the second model in the relationship",
examples=[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
"10f7807b-ab54-46a9-ab03-600e88c630a1",
"f6c1d267-cf87-4ee0-bee0-37e791eacab7",
],
)
class ModelRelationshipBatchRequest(BaseModel):
model_keys: List[str] = Field(
...,
description="List of model keys to fetch related models for",
examples=[
[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
],
[
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
],
],
)
# === Routes ===
@model_relationships_router.get(
"/i/{model_key}",
operation_id="get_related_models",
response_model=list[str],
responses={
200: {
"description": "A list of related model keys was retrieved successfully",
"content": {
"application/json": {
"example": [
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
404: {"description": "The specified model could not be found"},
422: {"description": "Validation error"},
},
)
async def get_related_models(
model_key: str = Path(..., description="The key of the model to get relationships for"),
) -> list[str]:
"""
Get a list of model keys related to a given model.
"""
try:
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully created"},
400: {"description": "Invalid model keys or self-referential relationship"},
409: {"description": "The relationship already exists"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Add Model Relationship",
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
) -> None:
"""
Add a relationship between two models.
Relationships are bidirectional and will be accessible from both models.
- Raises 400 if keys are invalid or identical.
- Raises 409 if the relationship already exists.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.delete(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully removed"},
400: {"description": "Invalid model keys or self-referential relationship"},
404: {"description": "The relationship does not exist"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Remove Model Relationship",
description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
)
async def remove_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
) -> None:
"""
Removes a bidirectional relationship between two model keys.
- Raises 400 if attempting to unlink a model from itself.
- Raises 404 if the relationship was not found.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/batch",
operation_id="get_related_models_batch",
response_model=List[str],
responses={
200: {
"description": "Related model keys retrieved successfully",
"content": {
"application/json": {
"example": [
"ca562b14-995e-4a42-90c1-9528f1a5921d",
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
"c382eaa3-0e28-4ab0-9446-408667699aeb",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Get Related Model Keys (Batch)",
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
)
async def get_related_models_batch(
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
) -> list[str]:
"""
Accepts multiple model keys and returns a flat list of all unique related keys.
Useful when working with multiple selections in the UI or cross-model comparisons.
"""
try:
all_related: set[str] = set()
for key in req.model_keys:
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
all_related.update(related)
return list(all_related)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
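A note on the handlers above: the self-relationship check raises its 400 inside the same `try` block that ends with `except Exception`. Since FastAPI's `HTTPException` derives from `Exception`, that final clause appears to catch the 400 and re-raise it as a 500. The sketch below reproduces only the control flow; the `HTTPException` class is a local stand-in, and `add_relationship_as_written`, `add_relationship_guarded`, and `status_of` are hypothetical names used for illustration, not part of the codebase.

```python
from typing import Callable, Optional


class HTTPException(Exception):
    """Minimal local stand-in for fastapi.HTTPException (assumption, for illustration)."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def add_relationship_as_written(key_1: str, key_2: str) -> None:
    """Mirrors the control flow of the handler in the diff."""
    try:
        if key_1 == key_2:
            raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
        # The service call would go here.
    except ValueError as e:
        raise HTTPException(status_code=409, detail=str(e))
    except Exception as e:
        # HTTPException is an Exception, so the 400 lands here and is re-wrapped.
        raise HTTPException(status_code=500, detail=str(e))


def add_relationship_guarded(key_1: str, key_2: str) -> None:
    """One possible fix: validate before entering the try block."""
    if key_1 == key_2:
        raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
    try:
        pass  # The service call would go here.
    except ValueError as e:
        raise HTTPException(status_code=409, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def status_of(fn: Callable[[str, str], None], key_1: str, key_2: str) -> Optional[int]:
    """Returns the status code of the raised HTTPException, or None if nothing was raised."""
    try:
        fn(key_1, key_2)
    except HTTPException as e:
        return e.status_code
    return None
```

An alternative with the same effect is to add an `except HTTPException: raise` clause ahead of the generic `except Exception` clause.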

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, HTTPException, Path, Query
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -14,15 +14,13 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -60,19 +58,17 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItem]},
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
},
)
async def list_queue_items(
@@ -81,42 +77,12 @@ async def list_queue_items(
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
responses={
200: {"model": list[SessionQueueItem]},
},
)
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
)
@session_queue_router.put(
@@ -128,10 +94,7 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
@@ -143,10 +106,7 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
@@ -158,25 +118,7 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
@session_queue_router.put(
"/{queue_id}/delete_all_except_current",
operation_id="delete_all_except_current",
responses={200: {"model": DeleteAllExceptCurrentResult}},
)
async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -189,12 +131,7 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
@@ -207,12 +144,9 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.put(
@@ -225,10 +159,7 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Retries the queue items with the given ids"""
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
@session_queue_router.put(
@@ -242,14 +173,11 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
@@ -263,10 +191,7 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
@@ -280,10 +205,7 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently executing queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
@@ -297,10 +219,7 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
@@ -314,12 +233,9 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
@@ -334,10 +250,7 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of a batch in the session queue"""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
@@ -353,27 +266,7 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
@session_queue_router.delete(
"/{queue_id}/i/{item_id}",
operation_id="delete_queue_item",
)
async def delete_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.put(
@@ -388,12 +281,8 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Cancels a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
@session_queue_router.get(
@@ -406,27 +295,6 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
@session_queue_router.delete(
"/{queue_id}/d/{destination}",
operation_id="delete_by_destination",
responses={200: {"model": DeleteByDestinationResult}},
)
async def delete_by_destination(
queue_id: str = Path(description="The queue id to query"),
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)

View File

@@ -22,7 +22,6 @@ from invokeai.app.api.routers import (
download_queue,
images,
model_manager,
model_relationships,
session_queue,
style_presets,
utilities,
@@ -126,7 +125,6 @@ app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
@@ -158,7 +156,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here

View File

@@ -5,8 +5,6 @@ from __future__ import annotations
import inspect
import re
import sys
import types
import typing
import warnings
from abc import ABC, abstractmethod
from enum import Enum
@@ -22,14 +20,12 @@ from typing import (
Literal,
Optional,
Type,
TypedDict,
TypeVar,
Union,
cast,
)
import semver
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
@@ -76,24 +72,13 @@ class Classification(str, Enum, metaclass=MetaEnum):
Special = "special"
class Bottleneck(str, Enum, metaclass=MetaEnum):
"""
The bottleneck of an invocation.
- `Network`: The invocation's execution is network-bound.
- `GPU`: The invocation's execution is GPU-bound.
"""
Network = "network"
GPU = "gpu"
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @invocation decorator logic. Do not use this directly.
"""
tags: Optional[list[str]] = Field(default=None, description="The node's tags")
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: str = Field(
@@ -108,11 +93,6 @@ class UIConfigBase(BaseModel):
)
class OriginalModelField(TypedDict):
annotation: Any
field_info: FieldInfo
class BaseInvocationOutput(BaseModel):
"""
Base class for all invocation outputs.
@@ -120,12 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
output_meta: Optional[dict[str, JsonValue]] = Field(
default=None,
description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -141,9 +115,6 @@ class BaseInvocationOutput(BaseModel):
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
@@ -177,7 +148,7 @@ class BaseInvocation(ABC, BaseModel):
return cls.model_fields["type"].default
@classmethod
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@@ -209,7 +180,7 @@ class BaseInvocation(ABC, BaseModel):
Internal invoke method, calls `invoke()` after some prep.
Handles optional fields that are required to call `invoke()` and invocation cache.
"""
for field_name, field in type(self).model_fields.items():
for field_name, field in self.model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict
continue
@@ -224,9 +195,9 @@ class BaseInvocation(ABC, BaseModel):
setattr(self, field_name, orig_default)
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
if input_ == Input.Connection:
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
elif input_ == Input.Any:
raise MissingInputException(type(self).model_fields["type"].default, field_name)
raise MissingInputException(self.model_fields["type"].default, field_name)
# skip node cache codepath if it's disabled
if services.configuration.node_cache_size == 0:
@@ -264,8 +235,6 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
bottleneck: ClassVar[Bottleneck]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
@@ -276,9 +245,6 @@ class BaseInvocation(ABC, BaseModel):
coerce_numbers_to_str=True,
)
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation decorator."""
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
@@ -290,26 +256,6 @@ class InvocationRegistry:
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
invocation_type = invocation.get_type()
node_pack = invocation.UIConfig.node_pack
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
if clobbered_invocation is not None:
# This should always be true - we just checked if the invocation type was in the set
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The invocation being clobbered is a core invocation
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
else:
# The invocation being clobbered is a custom invocation
logger.warning(
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
)
cls._invocation_classes.remove(clobbered_invocation)
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@@ -368,15 +314,6 @@ class InvocationRegistry:
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
output_type = output.get_type()
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
if clobbered_output is not None:
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
logger.warning(f'Overriding invocation output "{output_type}"')
cls._output_classes.remove(clobbered_output)
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@@ -385,11 +322,6 @@ class InvocationRegistry:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
"""Gets a map of all output types to their output classes."""
return {i.get_type(): i for i in cls.get_output_classes()}
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
@@ -415,11 +347,6 @@ class InvocationRegistry:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
@classmethod
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
"""Gets the output class for a given output type."""
return cls.get_outputs_map().get(output_type)
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
@@ -427,12 +354,11 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"use_cache",
"type",
"workflow",
"bottleneck",
}
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
class _Model(BaseModel):
@@ -499,53 +425,11 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None
class NoDefaultSentinel:
pass
def validate_field_default(
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
) -> None:
"""Validates the default value of a field against its pydantic field definition."""
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
# We store the original default value in the json_schema_extra dict, so we can validate it here.
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
if orig_default is NoDefaultSentinel:
return
# To validate the default value, we can create a temporary pydantic model with the field we are validating as its
# only field. Then validate the default value against this temporary model.
TempDefaultValidator = cast(BaseModel, create_model(cls_name, **{field_name: (annotation, field_info)}))
try:
TempDefaultValidator.model_validate({field_name: orig_default})
except Exception as e:
raise InvalidFieldError(
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
) from e
def is_optional(annotation: Any) -> bool:
"""
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
"""
origin = typing.get_origin(annotation)
# PEP 604 unions (int|None) have origin types.UnionType
is_union = origin is typing.Union or origin is types.UnionType
if not is_union:
return False
return any(arg is type(None) for arg in typing.get_args(annotation))
def invocation(
invocation_type: str,
title: Optional[str] = None,
@@ -554,7 +438,6 @@ def invocation(
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
bottleneck: Bottleneck = Bottleneck.GPU,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@@ -566,7 +449,6 @@ def invocation(
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to Classification.Stable. Use Beta or Prototype if the invocation is unstable.
:param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@@ -578,28 +460,27 @@ def invocation(
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The node being clobbered is a core node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
)
else:
# The node being clobbered is a custom node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
)
validate_fields(cls.model_fields, invocation_type)
fields: dict[str, tuple[Any, FieldInfo]] = {}
original_model_fields: dict[str, OriginalModelField] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
if field_info.default is None and not is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add OpenAPI schema extras
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
@@ -615,7 +496,7 @@ def invocation(
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
else:
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
@@ -623,8 +504,6 @@ def invocation(
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache
cls.bottleneck = bottleneck
# Add the invocation type to the model.
# You'd be tempted to just add the type field and rebuild the model, like this:
@@ -634,27 +513,11 @@ def invocation(
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
invocation_type_annotation = Literal[invocation_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
invocation_type_field_info = cast(
FieldInfo,
Field(title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
fields["type"] = (invocation_type_annotation, invocation_type_field_info)
# Invocation outputs must be registered using the @invocation_output decorator, but it is possible that the
# output is registered _after_ this invocation is registered. It depends on module import ordering.
#
# We can only confirm the output for an invocation is registered after all modules are imported. There's
# only really one good time to do that - during application startup, in `run_app.py`, after loading all
# custom nodes.
#
# We can still do some basic validation here - ensure the invoke method is defined and returns an instance
# of BaseInvocationOutput.
# Validate the `invoke()` method is implemented
if "invoke" in cls.__abstractmethods__:
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
@@ -676,13 +539,17 @@ def invocation(
)
docstring = cls.__doc__
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
new_class.__doc__ = docstring
new_class._original_model_fields = original_model_fields
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(invocation_type_annotation, invocation_type_field),
)
cls.__doc__ = docstring
InvocationRegistry.register_invocation(new_class)
InvocationRegistry.register_invocation(cls)
return new_class
return cls
return wrapper
@@ -705,41 +572,29 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
fields: dict[str, tuple[Any, FieldInfo]] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
if field_info.default is not PydanticUndefined and is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add the output type to the model.
output_type_annotation = Literal[output_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
output_type_field_info = cast(
FieldInfo,
Field(title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
fields["type"] = (output_type_annotation, output_type_field_info)
docstring = cls.__doc__
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
new_class.__doc__ = docstring
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(output_type_annotation, output_type_field),
)
cls.__doc__ = docstring
InvocationRegistry.register_output(new_class)
InvocationRegistry.register_output(cls)
return new_class
return cls
return wrapper


@@ -64,6 +64,7 @@ class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
)
@@ -119,6 +120,7 @@ class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
@@ -174,6 +176,7 @@ class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
@@ -227,6 +230,7 @@ class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)


@@ -1,158 +0,0 @@
import cv2
import numpy as np
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.invocation_api import Classification, ImageOutput
DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"
class BriaControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
model: ModelIdentifierField = Field(description="The ControlNet model to use")
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
conditioning_scale: float = Field(description="The weight given to the ControlNet")
@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
"""Bria ControlNet info"""
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@invocation(
"bria_controlnet",
title="ControlNet - Bria",
tags=["controlnet", "bria"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Collect Bria ControlNet info to pass to denoiser node."""
control_image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
)
control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")
def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
image_in = resize_img(context.images.get_pil(self.control_image.image_name))
if self.control_mode == "canny":
control_image = extract_canny(image_in)
elif self.control_mode == "depth":
control_image = extract_depth(image_in, context)
elif self.control_mode == "pose":
control_image = extract_openpose(image_in, context)
elif self.control_mode == "colorgrid":
control_image = tile(64, image_in)
elif self.control_mode == "recolor":
control_image = convert_to_grayscale(image_in)
elif self.control_mode == "tile":
control_image = tile(16, image_in)
control_image = resize_img(control_image)
image_dto = context.images.save(image=control_image)
image_output = ImageOutput.build(image_dto)
return BriaControlNetOutput(
preprocessed_images=image_output.image,
control=BriaControlNetField(
image=ImageField(image_name=image_dto.image_name),
model=self.control_model,
mode=self.control_mode,
conditioning_scale=self.control_weight,
),
)
RATIO_CONFIGS_1024 = {
0.6666666666666666: {"width": 832, "height": 1248},
0.7432432432432432: {"width": 880, "height": 1184},
0.8028169014084507: {"width": 912, "height": 1136},
1.0: {"width": 1024, "height": 1024},
1.2456140350877194: {"width": 1136, "height": 912},
1.3454545454545455: {"width": 1184, "height": 880},
1.4339622641509433: {"width": 1216, "height": 848},
1.5: {"width": 1248, "height": 832},
1.5490196078431373: {"width": 1264, "height": 816},
1.62: {"width": 1296, "height": 800},
1.7708333333333333: {"width": 1360, "height": 768},
}
def extract_depth(image: Image.Image, context: InvocationContext):
loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
return depth_map
def extract_openpose(image: Image.Image, context: InvocationContext):
body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
with body_model as body_model, hand_model as hand_model, face_model as face_model:
open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
processed_image_open_pose = open_pose_model(image, hand_and_face=True)
processed_image_open_pose = processed_image_open_pose.resize(image.size)
return processed_image_open_pose
def extract_canny(input_image):
image = np.array(input_image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
return canny_image
def convert_to_grayscale(image):
gray_image = image.convert("L").convert("RGB")
return gray_image
def tile(downscale_factor, input_image):
control_image = input_image.resize(
(input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
).resize(input_image.size, Image.Resampling.NEAREST)
return control_image
def resize_img(control_image):
image_ratio = control_image.width / control_image.height
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
to_height = RATIO_CONFIGS_1024[ratio]["height"]
to_width = RATIO_CONFIGS_1024[ratio]["width"]
resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
return resized_image
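
Note on `resize_img` in the removed file above: it snaps the input to whichever entry of `RATIO_CONFIGS_1024` has the closest aspect ratio. The selection step can be illustrated without PIL, as in the minimal sketch below. `nearest_bucket` and `RATIO_CONFIGS` are hypothetical names, and the table is only a four-entry subset of the original, so results for some inputs would differ from the full table.

```python
# Subset of the RATIO_CONFIGS_1024 table from the removed file (assumption: four entries only).
RATIO_CONFIGS = {
    0.6666666666666666: {"width": 832, "height": 1248},
    1.0: {"width": 1024, "height": 1024},
    1.5: {"width": 1248, "height": 832},
    1.7708333333333333: {"width": 1360, "height": 768},
}


def nearest_bucket(width: int, height: int) -> tuple[int, int]:
    """Return the (width, height) of the bucket whose aspect ratio is closest to the input's."""
    image_ratio = width / height
    # Same selection rule as resize_img: minimise the absolute ratio difference.
    ratio = min(RATIO_CONFIGS, key=lambda k: abs(k - image_ratio))
    bucket = RATIO_CONFIGS[ratio]
    return bucket["width"], bucket["height"]


print(nearest_bucket(1920, 1080))  # (1360, 768)
print(nearest_bucket(512, 512))    # (1024, 1024)
```

Inputs more extreme than the widest or tallest bucket are clamped to that bucket, so the resize distorts them rather than cropping.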


@@ -1,46 +0,0 @@
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
@invocation(
"bria_decoder",
title="Decoder - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
with context.models.load(self.vae.vae) as vae:
assert isinstance(vae, AutoencoderKL)
latents = latents / vae.config.scaling_factor
latents = latents.to(device=vae.device, dtype=vae.dtype)
decoded_output = vae.decode(latents)
image = decoded_output.sample
# Convert to numpy with proper gradient handling
image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
img = Image.fromarray(image)
image_dto = context.images.save(image=img)
return ImageOutput.build(image_dto)
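
Note on the `view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)` line in the removed decoder above: it unpacks 2x2 patches into a 4-channel 128x128 latent. The index arithmetic can be sketched in plain Python as below. `unpacked_position` is a hypothetical helper, and the packed input layout (4096 tokens with 16 features each) is an assumption; the source shows only the `view` shape, not the shape of the loaded tensor.

```python
def unpacked_position(
    token: int, feature: int, grid: int = 64, channels: int = 4, patch: int = 2
) -> tuple[int, int, int]:
    """Map (token, feature) in an assumed packed (grid*grid, channels*patch*patch) layout
    to (channel, row, col) in the unpacked (channels, grid*patch, grid*patch) layout."""
    assert 0 <= token < grid * grid and 0 <= feature < channels * patch * patch
    h, w = divmod(token, grid)               # view(1, 64, 64, ...): token -> (h, w)
    c, rem = divmod(feature, patch * patch)  # view(..., 4, 2, 2): feature -> (c, ph, pw)
    ph, pw = divmod(rem, patch)
    # permute(0, 3, 1, 4, 2, 5) orders the axes as (c, h, ph, w, pw);
    # reshape then merges (h, ph) into rows and (w, pw) into columns.
    return c, h * patch + ph, w * patch + pw


print(unpacked_position(65, 7))  # (1, 3, 3)
```

Token 65 sits at grid cell (1, 1) and feature 7 is channel 1, patch offset (1, 1), which lands at row 3, column 3 of channel 1.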


@@ -1,180 +0,0 @@
from typing import List, Tuple
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
@invocation(
"bria_denoise",
title="Denoise - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
num_steps: int = InputField(
default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
)
guidance_scale: float = InputField(
default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
latents: LatentsField = InputField(
description="Latents to denoise",
input=Input.Connection,
title="Latents",
)
latent_image_ids: LatentsField = InputField(
description="Latent Image IDs to denoise",
input=Input.Connection,
title="Latent Image IDs",
)
pos_embeds: LatentsField = InputField(
description="Positive Prompt Embeds",
input=Input.Connection,
title="Positive Prompt Embeds",
)
neg_embeds: LatentsField = InputField(
description="Negative Prompt Embeds",
input=Input.Connection,
title="Negative Prompt Embeds",
)
text_ids: LatentsField = InputField(
description="Text IDs",
input=Input.Connection,
title="Text IDs",
)
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
description="ControlNet",
input=Input.Connection,
title="ControlNet",
default=None,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
latents = context.tensors.load(self.latents.latents_name)
pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
text_ids = context.tensors.load(self.text_ids.latents_name)
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
device = None
dtype = None
with (
context.models.load(self.transformer.transformer) as transformer,
context.models.load(scheduler_identifier) as scheduler,
context.models.load(self.vae.vae) as vae,
context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(transformer, BriaTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
assert isinstance(vae, AutoencoderKL)
dtype = transformer.dtype
device = transformer.device
latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))
control_model, control_images, control_modes, control_scales = None, None, None, None
if self.control is not None:
control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
context=context,
vae=vae,
width=1024,
height=1024,
device=vae.device,
)
pipeline = BriaControlNetPipeline(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=t5_encoder,
tokenizer=t5_tokenizer,
controlnet=control_model,
)
pipeline.to(device=transformer.device, dtype=transformer.dtype)
latents = pipeline(
control_image=control_images,
control_mode=control_modes,
width=1024,
height=1024,
controlnet_conditioning_scale=control_scales,
num_inference_steps=self.num_steps,
max_sequence_length=128,
guidance_scale=self.guidance_scale,
latents=latents,
latent_image_ids=latent_image_ids,
text_ids=text_ids,
prompt_embeds=pos_embeds,
negative_prompt_embeds=neg_embeds,
output_type="latent",
)[0]
assert isinstance(latents, torch.Tensor)
saved_input_latents_tensor = context.tensors.save(latents)
latents_output = LatentsField(latents_name=saved_input_latents_tensor)
return BriaDenoiseInvocationOutput(latents=latents_output)
def _prepare_multi_control(
self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
control = self.control if isinstance(self.control, list) else [self.control]
control_images, control_models, control_modes, control_scales = [], [], [], []
for controlnet in control:
if controlnet is not None:
control_models.append(context.models.load(controlnet.model).model)
control_modes.append(BriaControlModes[controlnet.mode].value)
control_scales.append(controlnet.conditioning_scale)
try:
control_images.append(context.images.get_pil(controlnet.image.image_name))
except Exception:
raise FileNotFoundError(
f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
)
control_model = BriaMultiControlNetModel(control_models).to(device)
tensored_control_images, tensored_control_modes = prepare_control_images(
vae=vae,
control_images=control_images,
control_modes=control_modes,
width=width,
height=height,
device=device,
)
return control_model, tensored_control_images, tensored_control_modes, control_scales


@@ -1,76 +0,0 @@
import torch
from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import (
BaseInvocationOutput,
FieldDescriptions,
LatentsField,
)
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
latents: LatentsField = OutputField(description=FieldDescriptions.cond)
latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_latent_sampler",
title="Latent Sampler - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
seed: int = InputField(
default=42,
title="Seed",
description="The seed to use for the latent sampler",
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
with context.models.load(self.transformer.transformer) as transformer:
device = transformer.device
dtype = transformer.dtype
height, width = 1024, 1024
generator = torch.Generator(device=device).manual_seed(self.seed)
num_channels_latents = 4
latents, latent_image_ids = prepare_latents(
batch_size=1,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=dtype,
device=device,
generator=generator,
)
saved_latents_tensor = context.tensors.save(latents)
saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
latents_output = LatentsField(latents_name=saved_latents_tensor)
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
return BriaLatentSamplerInvocationOutput(
latents=latents_output,
latent_image_ids=latent_image_ids_output,
)


@@ -1,58 +0,0 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
ModelIdentifierField,
SubModelType,
T5EncoderField,
TransformerField,
VAEField,
)
from invokeai.invocation_api import (
BaseInvocation,
BaseInvocationOutput,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
"""Bria base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"bria_model_loader",
title="Main Model - Bria",
tags=["model", "bria"],
version="1.0.0",
classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
"""Loads a bria base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description="Bria model (Transformer) to load",
ui_type=UIType.BriaMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
for key in [self.model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return BriaModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
vae=VAEField(vae=vae),
)


@@ -1,93 +0,0 @@
from typing import Optional
import torch
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.app.invocations.model import T5EncoderField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InputField,
LatentsField,
invocation,
invocation_output,
)
@invocation_output("bria_text_encoder_output")
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_text_encoder",
title="Prompt - Bria",
tags=["prompt", "conditioning", "bria"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaTextEncoderInvocation(BaseInvocation):
prompt: str = InputField(
title="Prompt",
description="The prompt to encode",
)
negative_prompt: Optional[str] = InputField(
title="Negative Prompt",
description="The negative prompt to encode",
default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
)
max_length: int = InputField(
default=128,
title="Max Length",
description="The maximum length of the prompt",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
with (
t5_encoder_info as text_encoder,
t5_tokenizer_info as tokenizer,
):
assert isinstance(tokenizer, T5TokenizerFast)
assert isinstance(text_encoder, T5EncoderModel)
(prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
prompt=self.prompt,
tokenizer=tokenizer,
text_encoder=text_encoder,
negative_prompt=self.negative_prompt,
device=text_encoder.device,
num_images_per_prompt=1,
max_sequence_length=self.max_length,
lora_scale=1.0,
)
saved_pos_tensor = context.tensors.save(prompt_embeds)
saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
saved_text_ids_tensor = context.tensors.save(text_ids)
pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
return BriaTextEncoderInvocationOutput(
pos_embeds=pos_embeds_output,
neg_embeds=neg_embeds_output,
text_ids=text_ids_output,
)


@@ -1,363 +0,0 @@
from typing import Callable, Optional
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_cogview4 import CogView4Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
CogView4ConditioningField,
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"cogview4_denoise",
title="Denoise - CogView4",
tags=["image", "cogview4"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run the denoising process with a CogView4 model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.cogview4_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: CogView4ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: CogView4ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=32, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=32, description="Height of the generated image.")
steps: int = InputField(default=25, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
cogview4_conditioning = cond_data.conditionings[0]
assert isinstance(cogview4_conditioning, CogView4ConditioningInfo)
cogview4_conditioning = cogview4_conditioning.to(dtype=dtype, device=device)
return cogview4_conditioning.glm_embeds
def _get_noise(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
batch_size,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: _description_
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _convert_timesteps_to_sigmas(self, image_seq_len: int, timesteps: torch.Tensor) -> list[float]:
# The logic to prepare the timestep / sigma schedule is based on:
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
# The default FlowMatchEulerDiscreteScheduler configs are based on:
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
# This implementation differs slightly from the original for the sake of simplicity (differs in terminal value
# handling, not quantizing timesteps to integers, etc.).
def calculate_timestep_shift(
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
mu = calculate_timestep_shift(image_seq_len)
sigmas = time_shift_linear(mu, 1.0, timesteps)
return sigmas.tolist()
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
assert isinstance(transformer_info.model, CogView4Transformer2DModel)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
dtype=inference_dtype,
device=device,
)
# Prepare misc. conditioning variables.
# TODO(ryand): We could expose these as params (like with SDXL). But, we should experiment to see if they are
# useful first.
original_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)
# Prepare the timestep / sigma schedule.
patch_size = transformer_info.model.config.patch_size # type: ignore
assert isinstance(patch_size, int)
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
# We add an extra step to the end to account for the final timestep of 0.0.
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
sigmas = self._convert_timesteps_to_sigmas(image_seq_len, torch.tensor(timesteps))
total_steps = len(timesteps) - 1
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(total_steps)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels # type: ignore
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
batch_size=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
s_0 = sigmas[0]
latents = s_0 * noise + (1.0 - s_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return latents
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: RectifiedFlowInpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = RectifiedFlowInpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (_, transformer):
assert isinstance(transformer, CogView4Transformer2DModel)
# Denoising loop
for step_idx in tqdm(range(total_steps)):
t_curr = timesteps[step_idx]
sigma_curr = sigmas[step_idx]
sigma_prev = sigmas[step_idx + 1]
# Expand the timestep to match the latent model input.
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
# TODO(ryand): Support both sequential and batched CFG inference.
noise_pred_cond = transformer(
hidden_states=latents,
encoder_hidden_states=pos_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond = transformer(
hidden_states=latents,
encoder_hidden_states=neg_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
# TODO(ryand): Is casting to float32 necessary for precision/stability? I copied this from SD3.
latents = latents.to(dtype=torch.float32)
latents = latents + (sigma_prev - sigma_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.CogView4)
return step_callback
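
The sigma schedule and the per-step update in the file above can be sketched without torch. This is a minimal illustration rather than the project's implementation: it assumes `LATENT_SCALE_FACTOR` is 8 and `patch_size` is 2 (neither value is shown in this diff), it omits the `clip_timestep_schedule_fractional` clipping, and `build_sigmas` and `euler_cfg_step` are hypothetical helper names. Scalars stand in for tensors.

```python
def calculate_timestep_shift(
    image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
) -> float:
    # Same formula as the nested helper in _convert_timesteps_to_sigmas.
    m = (image_seq_len / base_seq_len) ** 0.5
    return m * max_shift + base_shift


def time_shift_linear(mu: float, sigma: float, t: float) -> float:
    # torch evaluates 1 / 0 as inf, so t == 0 maps to 0 there.
    # Plain Python floats need an explicit guard instead.
    if t == 0.0:
        return 0.0
    return mu / (mu + (1.0 / t - 1.0) ** sigma)


def build_sigmas(height: int, width: int, steps: int, latent_scale_factor: int = 8, patch_size: int = 2) -> list[float]:
    # latent_scale_factor and patch_size are assumed values, not taken from the diff.
    image_seq_len = ((height // latent_scale_factor) * (width // latent_scale_factor)) // (patch_size**2)
    mu = calculate_timestep_shift(image_seq_len)
    # steps + 1 evenly spaced timesteps from 1.0 down to 0.0.
    timesteps = [1.0 - i / steps for i in range(steps + 1)]
    return [time_shift_linear(mu, 1.0, t) for t in timesteps]


def euler_cfg_step(
    latent: float, cond: float, uncond: float, cfg_scale: float, sigma_curr: float, sigma_prev: float
) -> float:
    # Classifier-free guidance followed by one Euler step along the sigma schedule.
    noise_pred = uncond + cfg_scale * (cond - uncond)
    return latent + (sigma_prev - sigma_curr) * noise_pred


sigmas = build_sigmas(1024, 1024, steps=4)
stepped = euler_cfg_step(1.0, cond=2.0, uncond=1.0, cfg_scale=3.0, sigma_curr=1.0, sigma_prev=0.5)
```

Under these assumptions a 1024x1024 image gives `image_seq_len = 4096` and `mu = 3.25`, so the schedule starts at 1.0, ends at 0.0, and spends more of its steps at high noise levels than the unshifted linear timesteps do.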

@@ -1,69 +0,0 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
# should refactor to avoid this duplication.
@invocation(
"cogview4_i2l",
title="Image to Latents - CogView4",
tags=["image", "latents", "vae", "i2l", "cogview4"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""
image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
latents = vae.config.scaling_factor * latents
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

@@ -1,86 +0,0 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
# should refactor to avoid this duplication.
@invocation(
"cogview4_l2i",
title="Latents to Image - CogView4",
tags=["latents", "image", "vae", "l2i", "cogview4"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(TorchDevice.choose_torch_device())
vae.disable_tiling()
tiling_context = nullcontext()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
img = vae.decode(latents, return_dict=False)[0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)
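
The arithmetic in `_estimate_working_memory` can be checked by hand. The sketch below repeats the same formula on plain integers; it assumes `LATENT_SCALE_FACTOR` is 8 (the constant is imported, not defined, in this diff) and a 2-byte element size such as fp16 or bf16. `estimate_working_memory` is a hypothetical free-function version of the method.

```python
LATENT_SCALE_FACTOR = 8  # Assumed value; the diff only imports this constant.


def estimate_working_memory(latent_h: int, latent_w: int, element_size: int, scaling_constant: int = 2200) -> int:
    """Estimate the VAE decode working memory in bytes from the latent height and width."""
    out_h = LATENT_SCALE_FACTOR * latent_h
    out_w = LATENT_SCALE_FACTOR * latent_w
    return int(out_h * out_w * element_size * scaling_constant)


# 128x128 latents decode to a 1024x1024 image under the assumed scale factor.
estimate = estimate_working_memory(128, 128, element_size=2)
print(f"{estimate} bytes = {estimate / 2**30:.2f} GiB")  # prints "4613734400 bytes = 4.30 GiB"
```

The estimate scales linearly with the output pixel count, so doubling both image dimensions quadruples the requested working memory.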

@@ -1,55 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
GlmEncoderField,
ModelIdentifierField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("cogview4_model_loader_output")
class CogView4ModelLoaderOutput(BaseInvocationOutput):
"""CogView4 base model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
glm_encoder: GlmEncoderField = OutputField(description=FieldDescriptions.glm_encoder, title="GLM Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"cogview4_model_loader",
title="Main Model - CogView4",
tags=["model", "cogview4"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class CogView4ModelLoaderInvocation(BaseInvocation):
"""Loads a CogView4 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.cogview4_model,
ui_type=UIType.CogView4MainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
glm_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
glm_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
return CogView4ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
glm_encoder=GlmEncoderField(tokenizer=glm_tokenizer, text_encoder=glm_encoder),
vae=VAEField(vae=vae),
)

@@ -1,92 +0,0 @@
import torch
from transformers import GlmModel, PreTrainedTokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.model import GlmEncoderField
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
CogView4ConditioningInfo,
ConditioningFieldData,
)
from invokeai.backend.util.devices import TorchDevice
# The CogView4 GLM Text Encoder max sequence length set based on the default in diffusers.
COGVIEW4_GLM_MAX_SEQ_LEN = 1024
@invocation(
"cogview4_text_encoder",
title="Prompt - CogView4",
tags=["prompt", "conditioning", "cogview4"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class CogView4TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a cogview4 image."""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
glm_encoder: GlmEncoderField = InputField(
title="GLM Encoder",
description=FieldDescriptions.glm_encoder,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
conditioning_name = context.conditioning.save(conditioning_data)
return CogView4ConditioningOutput.build(conditioning_name)
def _glm_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
prompt = [self.prompt]
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
with (
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
):
context.util.signal_progress("Running GLM text encoder")
assert isinstance(glm_text_encoder, GlmModel)
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
text_inputs = glm_tokenizer(
prompt,
padding="longest",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = glm_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = glm_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
current_length = text_input_ids.shape[1]
pad_length = (16 - (current_length % 16)) % 16
if pad_length > 0:
pad_ids = torch.full(
(text_input_ids.shape[0], pad_length),
fill_value=glm_tokenizer.pad_token_id,
dtype=text_input_ids.dtype,
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = glm_text_encoder(
text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
).hidden_states[-2]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
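
The padding step in `_glm_encode` rounds the token count up to the next multiple of 16 and prepends the pad tokens (note the `torch.cat([pad_ids, text_input_ids], dim=1)` order). A minimal sketch on plain lists, with `left_pad_to_multiple` as a hypothetical helper name and `0` as a stand-in pad token id:

```python
def left_pad_to_multiple(token_ids: list[int], pad_token_id: int, multiple: int = 16) -> list[int]:
    """Left-pad token_ids so that its length is a multiple of `multiple`."""
    # The outer modulo makes pad_length 0 when the length is already aligned.
    pad_length = (multiple - (len(token_ids) % multiple)) % multiple
    return [pad_token_id] * pad_length + token_ids


tokens = list(range(1, 22))  # 21 illustrative token ids
padded = left_pad_to_multiple(tokens, pad_token_id=0)
```

Here 21 tokens receive 11 leading pad tokens for a total of 32, while an input that is already 16 tokens long is returned unchanged.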

@@ -1,7 +1,7 @@
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
from compel import Compel, ReturnedEmbeddingsType, SplitLongTextMode
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -104,7 +104,6 @@ class CompelInvocation(BaseInvocation):
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -114,13 +113,6 @@ class CompelInvocation(BaseInvocation):
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
@@ -213,7 +205,6 @@ class SDXLPromptInvocationBase:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -229,10 +220,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info

@@ -274,12 +274,12 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
title="Enhance Image",
tags=["enhance", "image"],
category="image",
version="1.2.1",
version="1.2.0",
)
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""
image: ImageField = InputField(description="The image for which to apply processing")
image: ImageField = InputField(default=None, description="The image for which to apply processing")
invert: bool = InputField(default=False, description="Whether to invert the image colors")
color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")

@@ -22,11 +22,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
heuristic_resize_fast,
)
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
@@ -113,7 +109,7 @@ class ControlNetInvocation(BaseInvocation):
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.1.1",
version="1.0.1",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
@@ -126,7 +122,7 @@ class HeuristicResizeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize_fast(np_img, (self.width, self.height))
np_resized = heuristic_resize(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

@@ -1,14 +1,12 @@
from typing import Literal, Optional
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -44,13 +42,15 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.3.0",
version="1.2.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising."""
"""Creates mask for denoising model run."""
mask: ImageField = InputField(description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(default=16, ge=0, description="How far to expand the edges of the mask", ui_order=2)
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
@@ -81,110 +81,45 @@ class CreateGradientMaskInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
# Resize the mask_image. Makes the filter 64x faster and doesn't hurt quality in latent scale anyway
mask_image = mask_image.resize(
(
mask_image.width // LATENT_SCALE_FACTOR,
mask_image.height // LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.BILINEAR,
)
mask_np_orig = np.array(mask_image, dtype=np.float32)
self.edge_radius = self.edge_radius // LATENT_SCALE_FACTOR # scale the edge radius to match the mask size
if self.edge_radius > 0:
mask_np = 255 - mask_np_orig # invert so 0 is unmasked (higher values = higher denoise strength)
dilated_mask = mask_np.copy()
# Create kernel based on coherence mode
if self.coherence_mode == "Box Blur":
# Create a circular distance kernel that fades from center outward
kernel_size = self.edge_radius * 2 + 1
center = self.edge_radius
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
for i in range(kernel_size):
for j in range(kernel_size):
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
if dist <= self.edge_radius:
kernel[i, j] = 1.0 - (dist / self.edge_radius)
else: # Gaussian Blur or Staged
# Create a Gaussian kernel
kernel_size = self.edge_radius * 2 + 1
kernel = cv2.getGaussianKernel(
kernel_size, self.edge_radius / 2.5
) # 2.5 is a magic number (standard deviation capturing)
kernel = kernel * kernel.T # Make 2D gaussian kernel
kernel = kernel / np.max(kernel) # Normalize center to 1.0
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
# Ensure values outside radius are 0
center = self.edge_radius
for i in range(kernel_size):
for j in range(kernel_size):
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
if dist > self.edge_radius:
kernel[i, j] = 0
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# 2D max filter
mask_tensor = torch.tensor(mask_np)
kernel_tensor = torch.tensor(kernel)
dilated_mask = 255 - self.max_filter2D_torch(mask_tensor, kernel_tensor).cpu()
dilated_mask = dilated_mask.numpy()
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
threshold = (1 - self.minimum_denoise) * 255
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever expanded mask is darker than the original mask but original was above threshold, set it to the threshold
# makes any expansion areas drop to threshold. Raising the minimum across the image happens outside of this if
threshold_mask = (dilated_mask < mask_np_orig) & (mask_np_orig > threshold)
dilated_mask = np.where(threshold_mask, threshold, mask_np_orig)
# wherever expanded mask is less than 255 but greater than threshold, drop it to threshold (minimum denoise)
threshold_mask = (dilated_mask > threshold) & (dilated_mask < 255)
dilated_mask = np.where(threshold_mask, threshold, dilated_mask)
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
dilated_mask = mask_np_orig.copy()
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
# convert to tensor
dilated_mask = np.clip(dilated_mask, 0, 255).astype(np.uint8)
mask_tensor = torch.tensor(dilated_mask, device=torch.device("cpu"))
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# binary mask for compositing
expanded_mask = np.where((dilated_mask < 255), 0, 255)
expanded_mask_image = Image.fromarray(expanded_mask.astype(np.uint8), mode="L")
expanded_mask_image = expanded_mask_image.resize(
(
mask_image.width * LATENT_SCALE_FACTOR,
mask_image.height * LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.NEAREST,
)
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
# restore the original mask size
dilated_mask = Image.fromarray(dilated_mask.astype(np.uint8))
dilated_mask = dilated_mask.resize(
(
mask_image.width * LATENT_SCALE_FACTOR,
mask_image.height * LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.NEAREST,
)
# stack the mask as a tensor, repeating 4 times on dimension 1
dilated_mask_tensor = image_resized_to_grid_as_tensor(dilated_mask, normalize=False)
mask_name = context.tensors.save(tensor=dilated_mask_tensor.unsqueeze(0))
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = dilated_mask_tensor
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -202,29 +137,3 @@ class CreateGradientMaskInvocation(BaseInvocation):
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
def max_filter2D_torch(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
"""
This morphological operation is much faster in torch than numpy or opencv
For reasonable kernel sizes, the overhead of copying the data to the GPU is not worth it.
"""
h, w = kernel.shape
pad_h, pad_w = h // 2, w // 2
padded = torch.nn.functional.pad(image, (pad_w, pad_w, pad_h, pad_h), mode="constant", value=0)
result = torch.zeros_like(image)
# This looks like it's inside out, but it does the same thing and is more efficient
for i in range(h):
for j in range(w):
weight = kernel[i, j]
if weight <= 0:
continue
# Extract the region from padded tensor
region = padded[i : i + image.shape[0], j : j + image.shape[1]]
# Apply weight and update max
result = torch.maximum(result, region * weight)
return result
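
The weighted max filter in `max_filter2D_torch` can be traced on a tiny input. The sketch below follows the same loop structure (iterate over kernel cells, skip non-positive weights, keep a running maximum of the shifted, weighted image) using nested lists in place of tensors. `max_filter2d` is a hypothetical pure-Python stand-in, written for readability rather than speed.

```python
def max_filter2d(image: list[list[float]], kernel: list[list[float]]) -> list[list[float]]:
    """Weighted grey-scale dilation: each output pixel is the max of neighbour * kernel weight."""
    h, w = len(kernel), len(kernel[0])
    pad_h, pad_w = h // 2, w // 2
    rows, cols = len(image), len(image[0])

    # Zero-pad the image so every kernel offset stays in bounds.
    padded = [[0.0] * (cols + 2 * pad_w) for _ in range(rows + 2 * pad_h)]
    for y in range(rows):
        for x in range(cols):
            padded[y + pad_h][x + pad_w] = float(image[y][x])

    result = [[0.0] * cols for _ in range(rows)]
    for i in range(h):
        for j in range(w):
            weight = kernel[i][j]
            if weight <= 0:
                continue  # Cells outside the radius contribute nothing.
            # Shift the whole image by (i, j), weight it, and update the running max.
            for y in range(rows):
                for x in range(cols):
                    result[y][x] = max(result[y][x], padded[y + i][x + j] * weight)
    return result


image = [[0, 0, 0], [0, 10, 0], [0, 0, 0]]
kernel = [[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]]
dilated = max_filter2d(image, kernel)
```

A single bright pixel spreads to its neighbours at half strength, which is how a kernel that fades with distance turns a hard mask edge into a gradient.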

@@ -608,7 +608,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
method=single_ip_adapter.method,
)
)

@@ -40,10 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
BriaControlNetModel = "BriaControlNetModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
@@ -63,10 +60,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion
# region Misc Field Types
@@ -144,7 +137,6 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
glm_encoder = "GLM (THUDM) tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
@@ -159,7 +151,6 @@ class FieldDescriptions:
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
cogview4_model = "CogView4 model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -217,7 +208,6 @@ class FieldDescriptions:
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
class ImageField(BaseModel):
@@ -294,24 +284,12 @@ class FluxFillConditioningField(BaseModel):
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class FluxKontextConditioningField(BaseModel):
"""A conditioning field for FLUX Kontext (reference image)."""
image: ImageField = Field(description="The Kontext reference image.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class CogView4ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@@ -411,8 +389,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
"""
input: Input
orig_required: bool
field_kind: FieldKind
orig_required: bool = True
default: Optional[Any] = None
orig_default: Optional[Any] = None
ui_hidden: bool = False
@@ -447,7 +425,7 @@ class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warning(
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
@@ -509,7 +487,7 @@ def InputField(
input: Input = Input.Any,
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: Optional[bool] = None,
ui_hidden: bool = False,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
) -> Any:
@@ -545,20 +523,15 @@ def InputField(
json_schema_extra_ = InputFieldJSONSchemaExtra(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_choice_labels=ui_choice_labels,
field_kind=FieldKind.Input,
orig_required=True,
)
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
json_schema_extra_.ui_hidden = ui_hidden
if ui_order is not None:
json_schema_extra_.ui_order = ui_order
if ui_choice_labels is not None:
json_schema_extra_.ui_choice_labels = ui_choice_labels
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
`invoke()` function.
@@ -588,7 +561,7 @@ def InputField(
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warning('"default_factory" is not supported, calling it now to set "default"')
logger.warn('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish to pass to the pydantic `Field()` function
field_args = {
@@ -630,7 +603,7 @@ def InputField(
return Field(
**provided_args,
json_schema_extra=json_schema_extra_.model_dump(exclude_unset=True),
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
)

@@ -16,12 +16,13 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxKontextConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
@@ -32,8 +33,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -52,7 +53,6 @@ from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantTyp
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.0.0",
version="3.3.0",
)
class FluxDenoiseInvocation(BaseInvocation):
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
@@ -145,20 +145,11 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.vae,
input=Input.Connection,
)
# This node accepts images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
# upstream nodes.
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -304,10 +295,10 @@ class FluxDenoiseInvocation(BaseInvocation):
assert packed_h * packed_w == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: RectifiedFlowInpaintExtension | None = None
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = RectifiedFlowInpaintExtension(
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
@@ -385,27 +376,6 @@ class FluxDenoiseInvocation(BaseInvocation):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
x = denoise(
model=transformer,
img=x,
@@ -421,8 +391,6 @@ class FluxDenoiseInvocation(BaseInvocation):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
x = unpack(x.float(), self.height, self.width)
@@ -897,10 +865,7 @@ class FluxDenoiseInvocation(BaseInvocation):
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
latents = state.latents.float()
state.latents = unpack(latents, self.height, self.width).squeeze()
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback

@@ -1,40 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxKontextConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Kontext invocation."""
kontext_cond: FluxKontextConditioningField = OutputField(
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
)
@invocation(
"flux_kontext",
title="Kontext Conditioning - FLUX",
tags=["conditioning", "kontext", "flux"],
category="conditioning",
version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
"""Prepares a reference image for FLUX Kontext conditioning."""
image: ImageField = InputField(description="The Kontext reference image.")
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
"""Packages the provided image into a Kontext conditioning field."""
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))

@@ -1,9 +1,7 @@
import math
from typing import Literal, Optional
from typing import Optional
import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -41,15 +39,12 @@ class FluxReduxOutput(BaseInvocationOutput):
)
DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "nearest-exact"]
@invocation(
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="2.1.0",
version="2.0.0",
classification=Classification.Beta,
)
class FluxReduxInvocation(BaseInvocation):
@@ -66,64 +61,23 @@ class FluxReduxInvocation(BaseInvocation):
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
)
downsampling_factor: int = InputField(
ge=1,
le=9,
default=1,
description="Redux Downsampling Factor (1-9)",
)
downsampling_function: DOWNSAMPLING_FUNCTIONS = InputField(
default="area",
description="Redux Downsampling Function",
)
weight: float = InputField(
ge=0,
le=1,
default=1.0,
description="Redux weight (0.0-1.0)",
)
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
encoded_x = self._siglip_encode(context, image)
redux_conditioning = self._flux_redux_encode(context, encoded_x)
if self.downsampling_factor > 1 or self.weight != 1.0:
redux_conditioning = self._downsample_weight(context, redux_conditioning)
tensor_name = context.tensors.save(redux_conditioning)
return FluxReduxOutput(
redux_cond=FluxReduxConditioningField(conditioning=TensorField(tensor_name=tensor_name), mask=self.mask)
)
@torch.no_grad()
def _downsample_weight(self, context: InvocationContext, redux_conditioning: torch.Tensor) -> torch.Tensor:
# Downsampling derived from https://github.com/kaibioinfo/ComfyUI_AdvancedRefluxControl
(b, t, h) = redux_conditioning.shape
m = int(math.sqrt(t))
if self.downsampling_factor > 1:
redux_conditioning = redux_conditioning.view(b, m, m, h)
redux_conditioning = torch.nn.functional.interpolate(
redux_conditioning.transpose(1, -1),
size=(m // self.downsampling_factor, m // self.downsampling_factor),
mode=self.downsampling_function,
)
redux_conditioning = redux_conditioning.transpose(1, -1).reshape(b, -1, h)
if self.weight != 1.0:
redux_conditioning = redux_conditioning * self.weight * self.weight
return redux_conditioning
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
assert isinstance(model, SiglipVisionModel)
model_abs_path = context.models.get_absolute_path(siglip_model_config)
processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, SiglipImageProcessor)
siglip_pipeline = SigLipPipeline(processor, model)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
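
The `_downsample_weight` helper in this hunk reshapes the Redux conditioning from `(b, t, h)` into a square `m x m` token grid, shrinks each side by `downsampling_factor`, and multiplies by `weight` twice. A minimal sketch of just that arithmetic follows, without torch. The function names and the 729-token input are hypothetical and chosen only for illustration; `math.isqrt` is swapped in for the diff's `int(math.sqrt(t))`.

```python
import math


def downsampled_shape(tokens: int, downsampling_factor: int) -> tuple[int, int]:
    """Return (grid_side, token_count) after shrinking a square token grid.

    Mirrors the shape math in the hunk above: m = sqrt(t), new side = m // factor.
    """
    m = math.isqrt(tokens)  # side length of the square token grid
    if m * m != tokens:
        raise ValueError("token count must be a perfect square")
    side = m // downsampling_factor
    return side, side * side


def effective_scale(weight: float) -> float:
    """The hunk multiplies by `weight` twice, so the applied scale is weight squared."""
    return weight * weight


# Hypothetical input: a 27 x 27 grid (729 tokens) downsampled by a factor of 3.
print(downsampled_shape(729, 3))  # (9, 81)
print(effective_scale(0.5))       # 0.25
```

Because the scale is squared, a `weight` of 0.5 attenuates the conditioning to a quarter of its magnitude rather than half.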

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple, Union
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,9 +111,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
if context.config.get().log_tokenization:
self._log_t5_tokenization(context, t5_tokenizer)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
@@ -154,9 +151,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
if context.config.get().log_tokenization:
self._log_clip_tokenization(context, clip_tokenizer)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
@@ -176,88 +170,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _log_t5_tokenization(
self,
context: InvocationContext,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
) -> None:
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
# Tokenize the prompt using the same parameters as the model's text encoder.
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=self.t5_max_seq_len,
truncation=True,
add_special_tokens=True, # This is important for T5 to add the EOS token.
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The T5 tokenizer uses a space-like character '▁' (U+2581) to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("\u2581", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for token in tokens:
if token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
used_tokens += 1
elif token == tokenizer.pad_token:
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
continue
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
def _log_clip_tokenization(
self,
context: InvocationContext,
tokenizer: CLIPTokenizer,
) -> None:
"""Logs the tokenization of a prompt for a CLIP-based model."""
max_length = tokenizer.model_max_length
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
attention_mask = tokenized_output.attention_mask[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The CLIP tokenizer uses '</w>' to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("</w>", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for i, token in enumerate(tokens):
if attention_mask[i] == 0:
# Do not log padding tokens.
continue
if token == tokenizer.bos_token:
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
elif token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
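
The two logging helpers in this hunk colorize tokens with ANSI SGR escape codes, cycling through six foreground colors and skipping padding. The sketch below reproduces the T5 variant of that loop on a plain list of strings, with no tokenizer involved. The function name `colorize_tokens` and the default `</s>` / `<pad>` token strings are assumptions for illustration, not taken from the diff.

```python
def colorize_tokens(tokens, eos_token="</s>", pad_token="<pad>"):
    """Return (colored_string, used_token_count) for a list of token strings.

    Padding tokens are skipped and not counted; EOS is always red; every other
    token cycles through ANSI foreground colors 31..36.
    """
    colored = ""
    used = 0
    for token in tokens:
        if token == pad_token:
            continue  # padding is neither printed nor counted
        if token == eos_token:
            colored += f"\x1b[0;31m{token}\x1b[0m"  # red for EOS
        else:
            color = (used % 6) + 1  # 1..6, so the SGR codes are 31..36
            colored += f"\x1b[0;3{color}m{token}\x1b[0m"
        used += 1
    return colored, used


text, count = colorize_tokens(["a", " cat", "</s>", "<pad>", "<pad>"])
print(count)  # 3
print(text)
```

Note that the first regular token also gets color code 31, so it renders in the same red as EOS.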

@@ -21,14 +21,14 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
tags=["latents", "math", "ideal_size"],
version="1.0.6",
version="1.0.5",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(description=FieldDescriptions.unet)
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "

@@ -975,13 +975,13 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.1",
version="1.0.0",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(description="The target image")
target_image: ImageField = InputField(default=None, description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
@@ -1218,15 +1218,12 @@ class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Add Image Noise",
tags=["image", "noise"],
category="image",
version="1.1.0",
version="1.0.1",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add noise to an image"""
image: ImageField = InputField(description="The image to add noise to")
mask: Optional[ImageField] = InputField(
default=None, description="Optional mask determining where to apply noise (black=noise, white=no noise)"
)
seed: int = InputField(
default=0,
ge=0,
@@ -1270,27 +1267,12 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
(image.width, image.height), Image.Resampling.NEAREST
)
# Create a noisy version of the input image
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
# Apply mask if provided
if self.mask is not None:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
# Paste back the alpha channel
noisy_image.putalpha(alpha)
if mask_image.size != image.size:
mask_image = mask_image.resize(image.size, Image.Resampling.LANCZOS)
result_image = image.copy()
mask_image = ImageOps.invert(mask_image)
result_image.paste(noisy_image, (0, 0), mask=mask_image)
else:
result_image = noisy_image
# Paste back the alpha channel from the original image
result_image.putalpha(alpha)
image_dto = context.images.save(image=result_image)
image_dto = context.images.save(image=noisy_image)
return ImageOutput.build(image_dto)

@@ -127,16 +127,13 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
return infilled
LAMA_MODEL_URL = "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt"
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
with self._context.models.load_remote_model(
source=LAMA_MODEL_URL,
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)

@@ -31,7 +31,6 @@ class IPAdapterField(BaseModel):
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
method: str = Field(default="full", description="Weight apply method")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@@ -95,7 +94,7 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = InputField(
method: Literal["full", "style", "composition"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
@@ -148,38 +147,6 @@ class IPAdapterInvocation(BaseInvocation):
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_precise":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1", "down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_strong":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.0", "up_blocks.1", "up_blocks.2", "down_blocks.0", "down_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = [
"up_blocks.0.attentions.1",
"up_blocks.1.attentions.1",
"up_blocks.2.attentions.1",
"up_blocks.0.attentions.2",
"up_blocks.1.attentions.2",
"up_blocks.2.attentions.2",
"up_blocks.0.attentions.0",
"up_blocks.1.attentions.0",
"up_blocks.2.attentions.0",
"down_blocks.0.attentions.0",
"down_blocks.0.attentions.1",
"down_blocks.0.attentions.2",
"down_blocks.1.attentions.0",
"down_blocks.1.attentions.1",
"down_blocks.1.attentions.2",
"down_blocks.2.attentions.0",
"down_blocks.2.attentions.2",
]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
@@ -195,7 +162,6 @@ class IPAdapterInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,
method=self.method,
),
)
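
The `if`/`elif` chain in this hunk maps an IP-Adapter `method` plus a base model type to a list of target blocks. One way to read it is as a lookup table, sketched below. Only the `style_precise` entries and the `full` case are copied from the hunk; the names `TARGET_BLOCKS` and `resolve_target_blocks` are hypothetical, and the sketch raises a single error for any unknown combination, whereas the original distinguishes unsupported base types from unknown methods.

```python
# Entries copied from the `style_precise` branch shown in the hunk above.
TARGET_BLOCKS = {
    ("style_precise", "sd-1"): ["up_blocks.1", "down_blocks.2", "mid_block"],
    ("style_precise", "sdxl"): ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
}


def resolve_target_blocks(method: str, base: str) -> list[str]:
    """Return the target blocks for a (method, base) pair, or raise ValueError."""
    if method == "full":
        return ["block"]  # the "full" method does not depend on the base type
    try:
        return list(TARGET_BLOCKS[(method, base)])  # copy so callers cannot mutate the table
    except KeyError:
        raise ValueError(f"Unsupported IP-Adapter method/base: '{method}' / '{base}'.") from None


print(resolve_target_blocks("style_precise", "sdxl"))
```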

@@ -3,14 +3,13 @@ from typing import Any
import torch
from PIL.Image import Image
from pydantic import field_validator
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice
@@ -55,17 +54,10 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)
model_config = context.models.get_config(self.vllm_model)
with context.models.load(self.vllm_model).model_on_device() as (_, model):
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
model_abs_path = context.models.get_absolute_path(model_config)
processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
model = LlavaOnevisionPipeline(model, processor)
output = model.run(
with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),

@@ -42,9 +42,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = Field(
description="Method to apply IP Weights with"
)
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
@@ -154,10 +152,6 @@ GENERATION_MODES = Literal[
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
"cogview4_txt2img",
"cogview4_img2img",
"cogview4_inpaint",
"cogview4_outpaint",
]

@@ -39,17 +39,7 @@ from invokeai.app.invocations.model import (
VAEField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
BooleanCollectionOutput,
BooleanOutput,
FloatCollectionOutput,
FloatOutput,
IntegerCollectionOutput,
IntegerOutput,
LatentsOutput,
StringCollectionOutput,
StringOutput,
)
from invokeai.app.invocations.primitives import BooleanOutput, FloatOutput, IntegerOutput, LatentsOutput, StringOutput
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -1172,133 +1162,3 @@ class MetadataToT2IAdaptersInvocation(BaseInvocation, WithMetadata):
adapters = append_list(T2IAdapterField, i.t2i_adapter, adapters)
return MDT2IAdapterListOutput(t2i_adapter_list=adapters)
@invocation(
"metadata_to_string_collection",
title="Metadata To String Collection",
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Beta,
)
class MetadataToStringCollectionInvocation(BaseInvocation, WithMetadata):
"""Extracts a string collection value of a label from metadata"""
label: CORE_LABELS_STRING = InputField(
default=CUSTOM_LABEL,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
custom_label: Optional[str] = InputField(
default=None,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
default_value: list[str] = InputField(
description="The default string collection to use if not found in the metadata"
)
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
output = data.get(str(self.custom_label if self.label == CUSTOM_LABEL else self.label), self.default_value)
return StringCollectionOutput(collection=output)
@invocation(
"metadata_to_integer_collection",
title="Metadata To Integer Collection",
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Beta,
)
class MetadataToIntegerCollectionInvocation(BaseInvocation, WithMetadata):
"""Extracts an integer value Collection of a label from metadata"""
label: CORE_LABELS_INTEGER = InputField(
default=CUSTOM_LABEL,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
custom_label: Optional[str] = InputField(
default=None,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
default_value: list[int] = InputField(description="The default integer to use if not found in the metadata")
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
output = data.get(str(self.custom_label if self.label == CUSTOM_LABEL else self.label), self.default_value)
return IntegerCollectionOutput(collection=output)
@invocation(
"metadata_to_float_collection",
title="Metadata To Float Collection",
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Beta,
)
class MetadataToFloatCollectionInvocation(BaseInvocation, WithMetadata):
"""Extracts a Float value Collection of a label from metadata"""
label: CORE_LABELS_FLOAT = InputField(
default=CUSTOM_LABEL,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
custom_label: Optional[str] = InputField(
default=None,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
default_value: list[float] = InputField(description="The default float to use if not found in the metadata")
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
output = data.get(str(self.custom_label if self.label == CUSTOM_LABEL else self.label), self.default_value)
return FloatCollectionOutput(collection=output)
@invocation(
"metadata_to_bool_collection",
title="Metadata To Bool Collection",
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Beta,
)
class MetadataToBoolCollectionInvocation(BaseInvocation, WithMetadata):
"""Extracts a Boolean value Collection of a label from metadata"""
label: CORE_LABELS_BOOL = InputField(
default=CUSTOM_LABEL,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
custom_label: Optional[str] = InputField(
default=None,
description=FieldDescriptions.metadata_item_label,
input=Input.Direct,
)
default_value: list[bool] = InputField(description="The default bool to use if not found in the metadata")
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
output = data.get(str(self.custom_label if self.label == CUSTOM_LABEL else self.label), self.default_value)
return BooleanCollectionOutput(collection=output)
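
All four collection nodes in this hunk share one lookup: pick `custom_label` when `label` equals the `CUSTOM_LABEL` sentinel, otherwise use `label` itself, then read that key from the metadata dict with `default_value` as the fallback. A minimal sketch of that pattern with plain dicts follows. The sentinel string and the function name `lookup_metadata` are placeholders; the real `CUSTOM_LABEL` value is not shown in this diff.

```python
from typing import Any, Optional

CUSTOM_LABEL = "<custom>"  # placeholder sentinel; the actual value is not visible in the diff


def lookup_metadata(
    metadata: Optional[dict[str, Any]],
    label: str,
    custom_label: Optional[str],
    default_value: Any,
) -> Any:
    """Return metadata[key] where key is custom_label or label, else default_value."""
    data: dict[str, Any] = {} if metadata is None else metadata
    key = str(custom_label if label == CUSTOM_LABEL else label)
    return data.get(key, default_value)


meta = {"seeds": [1, 2, 3], "my_key": ["a", "b"]}
print(lookup_metadata(meta, "seeds", None, []))              # [1, 2, 3]
print(lookup_metadata(meta, CUSTOM_LABEL, "my_key", []))     # ['a', 'b']
print(lookup_metadata(None, "seeds", None, [0]))             # [0]
```

As in the hunk, the stored value is returned as-is, so a metadata entry of the wrong type is not caught at lookup time.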

@@ -68,11 +68,6 @@ class T5EncoderField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class GlmEncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')

@@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
CogView4ConditioningField,
ColorField,
ConditioningField,
DenoiseMaskField,
@@ -430,15 +429,6 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("flux_conditioning_collection_output")
class FluxConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
collection: list[FluxConditioningField] = OutputField(
description="The output conditioning tensors",
)
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
@@ -450,17 +440,6 @@ class SD3ConditioningOutput(BaseInvocationOutput):
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("cogview4_conditioning_output")
class CogView4ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
conditioning: CogView4ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "CogView4ConditioningOutput":
return cls(conditioning=CogView4ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

@@ -24,7 +24,7 @@ from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -263,10 +263,10 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: RectifiedFlowInpaintExtension | None = None
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = RectifiedFlowInpaintExtension(
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,

@@ -6,7 +6,7 @@ import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoProcessor
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -104,13 +104,14 @@ class SegmentAnythingInvocation(BaseInvocation):
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = SamModel.from_pretrained(
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)

@@ -1,3 +1,12 @@
import uvicorn
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
@@ -9,18 +18,9 @@ def get_app():
def run_app() -> None:
"""The main entrypoint for the app."""
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
# Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
# sources like `invokeai.yaml` or env vars.
# Parse the CLI arguments.
InvokeAIArgs.parse_args()
import uvicorn
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
# Load config.
app_config = get_config()
@@ -31,14 +31,6 @@ def run_app() -> None:
if app_config.pytorch_cuda_alloc_conf:
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
# This import must happen after configure_torch_cuda_allocator() is called, because the module imports torch.
from invokeai.app.invocations.baseinvocation import InvocationRegistry
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.backend.util.devices import TorchDevice
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
# Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
from invokeai.app.util.startup_utils import (
apply_monkeypatches,
@@ -68,15 +60,6 @@ def run_app() -> None:
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
# Check all invocations and ensure their outputs are registered.
for invocation in InvocationRegistry.get_invocation_classes():
invocation_type = invocation.get_type()
output_annotation = invocation.get_output_annotation()
if output_annotation not in InvocationRegistry.get_output_classes():
logger.warning(
f'Invocation "{invocation_type}" has unregistered output class "{output_annotation.__name__}"'
)
if app_config.dev_reload:
# load_custom_nodes seems to bypass jurigged's import sniffer, so be sure to call it *after* they're already
# imported.
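
Note on the import ordering in this hunk: the comments say that anything importing torch must load only after configure_torch_cuda_allocator() has run. A minimal sketch of that constraint follows, assuming the allocator is configured through the PYTORCH_CUDA_ALLOC_CONF environment variable. The helper name set_env_before_import is hypothetical and is not the project's implementation, which this diff does not show.

```python
import os
import sys


def set_env_before_import(env_var: str, value: str, module_name: str) -> None:
    """Set an environment variable, refusing if the target module is already imported.

    Hypothetical helper for illustration only; it is not the project's
    configure_torch_cuda_allocator().
    """
    if module_name in sys.modules:
        raise RuntimeError(f"{module_name} was imported before {env_var} could be set")
    os.environ[env_var] = value


# "sys" is always imported, so it stands in for "torch was imported too early".
try:
    set_env_before_import("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync", "sys")
    too_late = False
except RuntimeError:
    too_late = True

# A module that has not been imported yet: the variable is set without complaint.
set_env_before_import("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync", "module_not_imported_yet")
```

Moving torch-dependent imports back to module level, as one side of this hunk does, leaves nothing like this guard in place, so the ordering has to be kept correct by hand.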

@@ -14,14 +14,15 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -30,12 +31,17 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -43,6 +49,10 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -50,26 +60,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -79,55 +90,47 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
with self._db.transaction() as cursor:
params: list[str | bool] = []
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
AND board_images.board_id = ?
"""
params.append(board_id)
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
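
Note on the category filter in this hunk: the values are never interpolated into the SQL text. Only a run of `?` placeholders is, and the values travel separately as parameters. A self-contained sketch of the same technique against an in-memory database follows. The table layout and the helper name build_category_filter are assumptions made for the example, not the project's schema or API.

```python
import sqlite3


def build_category_filter(categories: list[str]) -> tuple[str, list[str]]:
    """Return an SQL fragment with one placeholder per unique category, plus its params."""
    unique = sorted(set(categories))  # de-duplicate, as the hunk does with set(categories)
    placeholders = ",".join("?" * len(unique))
    return f"AND image_category IN ( {placeholders} )", unique


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (image_name TEXT PRIMARY KEY, image_category TEXT)")
conn.executemany(
    "INSERT INTO images VALUES (?, ?)",
    [("a.png", "general"), ("b.png", "mask"), ("c.png", "control")],
)

clause, params = build_category_filter(["general", "mask", "general"])
stmt = "SELECT image_name FROM images WHERE 1=1 " + clause + " ORDER BY image_name;"
image_names = [row[0] for row in conn.execute(stmt, params)]
```

The `WHERE 1=1` prefix is what lets each optional filter be appended as a plain `AND ...` fragment without tracking whether it is the first condition.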
@@ -135,31 +138,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
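
Note on the two styles visible in this file's diff: one side wraps each method body in `with self._db.transaction() as cursor:`, the other repeats cursor/commit/rollback by hand. The diff does not show how SqliteDatabase.transaction() is implemented, so the following is only a sketch of what such a context manager might look like, written as a free function over a plain sqlite3 connection.

```python
import sqlite3
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def transaction(conn: sqlite3.Connection) -> Iterator[sqlite3.Cursor]:
    """Yield a cursor; commit on success, roll back and re-raise on any error.

    Illustrative stand-in only, not the project's SqliteDatabase.transaction().
    """
    cursor = conn.cursor()
    try:
        yield cursor
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        cursor.close()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE board_images (board_id TEXT, image_name TEXT PRIMARY KEY)")

# Successful block: the insert is committed when the block exits.
with transaction(conn) as cursor:
    cursor.execute("INSERT INTO board_images VALUES (?, ?)", ("b1", "a.png"))

# Failing block: the second insert violates the primary key, so both inserts are rolled back.
try:
    with transaction(conn) as cursor:
        cursor.execute("INSERT INTO board_images VALUES (?, ?)", ("b1", "b.png"))
        cursor.execute("INSERT INTO board_images VALUES (?, ?)", ("b2", "a.png"))
except sqlite3.IntegrityError:
    pass

rows = conn.execute("SELECT image_name FROM board_images ORDER BY image_name").fetchall()
```

With the hand-written variant, every write method has to remember both the commit() and the rollback(); the context manager keeps that in one place.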

@@ -20,57 +20,61 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def delete(self, board_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -80,43 +84,45 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -127,77 +133,78 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
WHERE archived = 0;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Execute count query
cursor.execute(count_query)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

@@ -24,6 +24,7 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -92,7 +93,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -175,7 +176,7 @@ class InvokeAIAppConfig(BaseSettings):
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
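
Note on the `device` field: one side of this hunk validates it with the pattern `^(auto|cpu|mps|cuda(:\d+)?)$`, the other with the fixed DEVICE Literal. The sketch below uses only the standard `re` module (not the pydantic validation the config class actually relies on) to show which values the two approaches disagree on. The candidate list is made up for the example.

```python
import re

# Pattern copied from the `device` field in the hunk above.
DEVICE_PATTERN = re.compile(r"^(auto|cpu|mps|cuda(:\d+)?)$")

# Values listed in the DEVICE Literal in the hunk above.
LITERAL_DEVICES = {"auto", "cpu", "cuda", "cuda:1", "mps"}


def is_valid_device(name: str) -> bool:
    """Return True if the name matches the regex form of the device setting."""
    return DEVICE_PATTERN.match(name) is not None


candidates = ["auto", "cpu", "cuda", "cuda:0", "cuda:1", "cuda:12", "mps", "cuda:", "gpu"]
accepted_by_pattern = [c for c in candidates if is_valid_device(c)]
pattern_only = [c for c in accepted_by_pattern if c not in LITERAL_DEVICES]
```

Here `pattern_only` holds the names that the regex form accepts but the Literal form rejects, namely any CUDA index other than 1.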

@@ -8,7 +8,6 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -336,14 +335,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
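
Note on the free-space check in this hunk: it compares `shutil.disk_usage(...).free` against the expected download size before any bytes are written. A standalone sketch of the same check follows. The function name ensure_free_space is hypothetical, and the current working directory stands in for the download directory.

```python
from pathlib import Path
from shutil import disk_usage

GB = 2**30


def ensure_free_space(target_dir: Path, total_bytes: int) -> None:
    """Raise RuntimeError if target_dir's filesystem has less free space than total_bytes."""
    free_space = disk_usage(target_dir).free
    if free_space < total_bytes:
        raise RuntimeError(
            f"Free disk space {free_space / GB:.2f} GB is not enough for download of {total_bytes / GB:.2f} GB."
        )


# A zero-byte download always fits.
ensure_free_space(Path.cwd(), 0)

# An absurdly large download is rejected before anything is written.
try:
    ensure_free_space(Path.cwd(), 2**80)
    rejected = False
except RuntimeError:
    rejected = True
```

The check is best-effort: free space can change between the check and the end of the download, so write errors still need handling.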

@@ -241,7 +241,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def build(
@@ -264,7 +263,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
credits=queue_item.credits,
)

@@ -5,7 +5,6 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -98,17 +97,3 @@ class ImageRecordStorageBase(ABC):
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

@@ -3,7 +3,7 @@ import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field, StrictBool, StrictStr
from pydantic import Field, StrictBool, StrictStr
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
@@ -207,16 +207,3 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
starred=starred,
has_workflow=has_workflow,
)
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
class ImageNamesResult(BaseModel):
"""Response containing ordered image names with metadata for optimistic updates."""
image_names: list[str] = Field(description="Ordered list of image names")
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
total_count: int = Field(description="Total number of images matching the query")

@@ -7,7 +7,6 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -24,22 +23,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def get(self, image_name: str) -> ImageRecord:
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,20 +46,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if not result:
raise ImageRecordNotFoundException
@@ -68,60 +64,64 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
except sqlite3.Error as e:
raise ImageRecordSaveException from e
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
def get_many(
self,
@@ -135,162 +135,166 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
query_params.append(is_intermediate)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
return count
def delete_intermediates(self) -> list[str]:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def save(
self,
@@ -306,165 +310,75 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(board_id,),
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
return deserialize_image_record(dict(result))
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
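
The category filter in the hunks above builds one `?` placeholder per distinct category and binds the values separately, so the values are never interpolated into the SQL text. A minimal sketch of that pattern against an in-memory database follows. The two-column `images` table is a stand-in and `select_names_in_categories` is a hypothetical helper, not part of the diffed code.

```python
import sqlite3


def select_names_in_categories(conn: sqlite3.Connection, categories: list[str]) -> list[str]:
    """Hypothetical helper: return image names whose category is in `categories`."""
    # De-duplicate first so the placeholder count matches the bound values.
    unique = sorted(set(categories))
    if not unique:
        return []
    # One "?" per value, e.g. "?,?" for two categories.
    placeholders = ",".join("?" * len(unique))
    query = (
        "SELECT image_name FROM images "
        f"WHERE image_category IN ({placeholders}) "
        "ORDER BY image_name"
    )
    # Only the placeholders are formatted into the SQL; the values are bound.
    return [row[0] for row in conn.execute(query, unique).fetchall()]


# Stand-in schema and data, assumed for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (image_name TEXT PRIMARY KEY, image_category TEXT)")
conn.executemany(
    "INSERT INTO images VALUES (?, ?)",
    [("a.png", "general"), ("b.png", "mask"), ("c.png", "control")],
)
conn.commit()
```

Calling `select_names_in_categories(conn, ["general", "mask", "general"])` binds two values, because the duplicate is collapsed before the placeholders are counted.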


@@ -6,7 +6,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -126,7 +125,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
@@ -148,17 +147,3 @@ class ImageServiceABC(ABC):
def delete_images_on_board(self, board_id: str):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass


@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.services.image_records.image_records_common import ImageRecord
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -39,27 +39,3 @@ def image_record_to_dto(
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ResultWithAffectedBoards(BaseModel):
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
class DeleteImagesResult(ResultWithAffectedBoards):
deleted_images: list[str] = Field(description="The names of the images that were deleted")
class StarredImagesResult(ResultWithAffectedBoards):
starred_images: list[str] = Field(description="The names of the images that were starred")
class UnstarredImagesResult(ResultWithAffectedBoards):
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
class AddImagesToBoardResult(ResultWithAffectedBoards):
added_images: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
removed_images: list[str] = Field(description="The image names that were removed from their board")


@@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -79,7 +78,7 @@ class ImageService(ImageServiceABC):
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -310,27 +309,3 @@ class ImageService(ImageServiceABC):
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting image names")
raise e


@@ -27,10 +27,6 @@ if TYPE_CHECKING:
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.app.services.names.names_base import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -58,8 +54,6 @@ class InvocationServices:
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
model_relationships: "ModelRelationshipsServiceABC",
model_relationship_records: "ModelRelationshipRecordStorageBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
@@ -87,8 +81,6 @@ class InvocationServices:
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.model_relationships = model_relationships
self.model_relationship_records = model_relationship_records
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue


@@ -60,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str) -> None:
def reset_stats(self):
"""Reset all stored statistics."""
pass


@@ -73,9 +73,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self, graph_execution_state_id: str) -> None:
self._stats.pop(graph_execution_state_id, None)
self._cache_stats.pop(graph_execution_state_id, None)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
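
The two `reset_stats` variants in this hunk differ in scope: one drops the entries for a single graph execution state, the other clears the stats of every session at once. A rough sketch of the difference, using a hypothetical `StatsStore` class that models only the reset behaviour and is not the service's actual implementation:

```python
class StatsStore:
    """Hypothetical stand-in for the stats service; only resets are modelled."""

    def __init__(self) -> None:
        self._stats: dict[str, int] = {}

    def record(self, session_id: str) -> None:
        # Count one node execution for the given session.
        self._stats[session_id] = self._stats.get(session_id, 0) + 1

    def reset_one(self, session_id: str) -> None:
        # pop() with a default is a no-op when the id is unknown.
        self._stats.pop(session_id, None)

    def reset_all(self) -> None:
        # Rebinding to a fresh dict discards the stats of every session.
        self._stats = {}

    def known_sessions(self) -> list[str]:
        return sorted(self._stats)
```

With several sessions in flight, `reset_one` leaves the stats of the other sessions in place, while `reset_all` would also discard stats that have not been logged yet.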


@@ -51,7 +51,6 @@ from invokeai.backend.model_manager.metadata import (
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
@@ -149,7 +148,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
self._logger.warning(f"Cancelling job {job.id}")
self._logger.warning("Cancelling job {job.id}")
self.cancel_job(job)
while True:
try:
@@ -648,18 +647,10 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -668,10 +659,6 @@ class ModelInstallService(ModelInstallServiceBase):
info = info or self._probe(model_path, config)
# Apply LoRA metadata if applicable
model_images_path = self.app_config.models_path / "model_images"
apply_lora_metadata(info, model_path.resolve(), model_images_path)
model_path = model_path.resolve()
# Models in the Invoke-managed models dir should use relative paths.


@@ -80,7 +80,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
file_size: Optional[int] = Field(description="Size of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(


@@ -78,6 +78,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -88,33 +93,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
raise e
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(config.key)
@@ -126,7 +136,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
with self._db.transaction() as cursor:
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
@@ -136,17 +147,22 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
with self._db.transaction() as cursor:
record = self.get_model(key)
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -158,6 +174,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -169,30 +189,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -204,15 +224,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -235,42 +255,43 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -292,68 +313,69 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
cursor = self._db.conn.cursor()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
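
The hunks in this file alternate between `with self._db.transaction() as cursor:` and a manual cursor with explicit `commit()` and `rollback()` calls. The implementation of `SqliteDatabase.transaction` is not shown in this diff. The sketch below is only an assumption about what a context manager with that shape could look like: it yields a cursor, commits on success and rolls back on any exception. The free function `transaction` and the two-column `models` table are hypothetical.

```python
import sqlite3
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def transaction(conn: sqlite3.Connection) -> Iterator[sqlite3.Cursor]:
    """Hypothetical helper: yield a cursor, commit on success, roll back on error."""
    cursor = conn.cursor()
    try:
        yield cursor
        conn.commit()
    except Exception:
        # Undo everything done since the last commit, then re-raise.
        conn.rollback()
        raise
    finally:
        cursor.close()


def count_models(conn: sqlite3.Connection) -> int:
    return int(conn.execute("SELECT COUNT(*) FROM models").fetchone()[0])


# Stand-in schema, assumed for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE models (id TEXT PRIMARY KEY, config TEXT)")

with transaction(conn) as cursor:
    cursor.execute("INSERT INTO models VALUES (?, ?)", ("k1", "{}"))

try:
    with transaction(conn) as cursor:
        cursor.execute("INSERT INTO models VALUES (?, ?)", ("k2", "{}"))
        # Duplicate primary key: raises sqlite3.IntegrityError.
        cursor.execute("INSERT INTO models VALUES (?, ?)", ("k1", "{}"))
except sqlite3.IntegrityError:
    pass
```

The second block fails on the duplicate key, so the insert of `k2` is rolled back together with it and only `k1` remains in the table.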


@@ -1,25 +0,0 @@
from abc import ABC, abstractmethod
class ModelRelationshipRecordStorageBase(ABC):
"""Abstract base class for model-to-model relationship record storage."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models by keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models by keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models given a list of keys."""
pass


@@ -1,55 +0,0 @@
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result
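
The storage class above keeps relationships symmetric by sorting the two keys before writing, so `(a, b)` and `(b, a)` map to the same row, and by reading from both columns with a `UNION`. A minimal sketch of that idea against an in-memory database follows. The table schema is an assumption (the migration that creates `model_relationships` is not part of this diff), and the free functions are hypothetical stand-ins for the methods.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    # Assumed schema; the real table is created by a migration not shown here.
    conn.execute(
        "CREATE TABLE model_relationships ("
        "model_key_1 TEXT NOT NULL, "
        "model_key_2 TEXT NOT NULL, "
        "PRIMARY KEY (model_key_1, model_key_2))"
    )
    return conn


def add_relationship(conn: sqlite3.Connection, key_1: str, key_2: str) -> None:
    if key_1 == key_2:
        raise ValueError("Cannot relate a model to itself.")
    # Sorting gives each unordered pair exactly one stored representation.
    a, b = sorted([key_1, key_2])
    conn.execute(
        "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
        (a, b),
    )
    conn.commit()


def related_keys(conn: sqlite3.Connection, key: str) -> list[str]:
    # The key may sit in either column, so query both sides and merge.
    rows = conn.execute(
        """
        SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
        UNION
        SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
        """,
        (key, key),
    ).fetchall()
    return sorted(row[0] for row in rows)
```

Adding the pair in both orders leaves a single row, because the second insert hits the primary key and is ignored.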


@@ -1,25 +0,0 @@
from abc import ABC, abstractmethod
class ModelRelationshipsServiceABC(ABC):
"""High-level service for managing model-to-model relationships."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models."""
pass


@@ -1,9 +0,0 @@
from datetime import datetime
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class ModelRelationship(BaseModelExcludeNull):
model_key_1: str
model_key_2: str
created_at: datetime


@@ -1,31 +0,0 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.backend.model_manager.config import AnyModelConfig
class ModelRelationshipsService(ModelRelationshipsServiceABC):
__invoker: Invoker
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2)
def get_related_model_keys(self, model_key: str) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key)
def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.add_model_relationship(model_1.key, model_2.key)
def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.remove_model_relationship(model_1.key, model_2.key)
def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]:
return self.get_related_model_keys(model.key)
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys)


@@ -1,4 +1,3 @@
import gc
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
@@ -211,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -440,12 +439,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
# allocation is well worth it.
gc.collect()
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)


@@ -10,8 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -19,6 +17,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -93,11 +92,6 @@ class SessionQueueBase(ABC):
"""Cancels a session queue item"""
pass
@abstractmethod
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
@@ -115,11 +109,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with the given batch destination"""
pass
@abstractmethod
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
@@ -130,11 +119,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
@@ -143,20 +127,10 @@ class SessionQueueBase(ABC):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""


@@ -148,7 +148,7 @@ class Batch(BaseModel):
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
except NodeNotFoundError:
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
if batch_data.field_name not in type(node).model_fields:
if batch_data.field_name not in node.model_fields:
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
@@ -205,10 +205,9 @@ class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItem(BaseModel):
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
@@ -252,7 +251,41 @@ class SessionQueueItem(BaseModel):
default=None,
description="The ID of the published workflow associated with this queue item",
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
)
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
@@ -332,7 +365,6 @@ class EnqueueBatchResult(BaseModel):
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
class RetryItemsResult(BaseModel):
@@ -364,18 +396,6 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
pass
class DeleteByDestinationResult(BaseModel):
"""Result of deleting by a destination"""
deleted: int = Field(..., description="Number of queue items deleted")
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
"""Result of deleting all except current"""
pass
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""


@@ -17,8 +17,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -26,6 +24,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
@@ -47,17 +46,22 @@ class SqliteSessionQueue(SessionQueueBase):
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
@@ -65,104 +69,102 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
self._conn.commit()
except Exception:
self._conn.rollback()
raise
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -170,40 +172,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -216,23 +218,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
""",
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
@@ -241,7 +228,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -249,34 +239,35 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -294,19 +285,24 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -325,28 +321,16 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
try:
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE item_id = ?
""",
(item_id,),
)
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
@@ -368,7 +352,8 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -378,8 +363,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -399,14 +382,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -415,8 +401,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -436,67 +420,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
params = (queue_id, destination)
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -504,8 +438,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -525,13 +457,21 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
@@ -554,25 +494,30 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -585,6 +530,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -594,45 +543,53 @@ class SqliteSessionQueue(SessionQueueBase):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
AND status = ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
@@ -640,52 +597,21 @@ class SqliteSessionQueue(SessionQueueBase):
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
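
`list_queue_items` pages through the queue by fetching `limit + 1` rows and using the extra row only to set `has_more`. The sketch below shows that technique against an in-memory SQLite table. The reduced three-column schema and the names `make_queue` and `list_page` are assumptions for illustration. One deliberate difference from the query text in the diff: the cursor condition is wrapped in an outer pair of parentheses, because `AND` binds tighter than `OR` in SQL and an unparenthesized `OR` branch would escape the `queue_id` filter.

```python
import sqlite3
from typing import Optional


def make_queue(rows: list[tuple[int, str, int]]) -> sqlite3.Connection:
    # Hypothetical, reduced schema: only the columns the pagination needs.
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE session_queue (item_id INTEGER PRIMARY KEY, queue_id TEXT, priority INTEGER)"
    )
    conn.executemany(
        "INSERT INTO session_queue (item_id, queue_id, priority) VALUES (?, ?, ?)", rows
    )
    conn.commit()
    return conn


def list_page(
    conn: sqlite3.Connection,
    queue_id: str,
    limit: int,
    priority: int = 0,
    cursor: Optional[int] = None,
) -> tuple[list[tuple[int, int]], bool]:
    query = "SELECT item_id, priority FROM session_queue WHERE queue_id = ?"
    params: list = [queue_id]
    if cursor is not None:
        # Resume after the last item seen: lower priority, or the same
        # priority with a higher item_id. Parenthesized as a single condition.
        query += " AND ((priority < ?) OR (priority = ? AND item_id > ?))"
        params.extend([priority, priority, cursor])
    query += " ORDER BY priority DESC, item_id ASC LIMIT ?"
    params.append(limit + 1)  # one extra row tells us whether another page exists
    rows = conn.execute(query, params).fetchall()
    has_more = len(rows) > limit
    return rows[:limit], has_more
```

A caller would pass the `item_id` and `priority` of the last row of one page as `cursor` and `priority` for the next page.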
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -701,20 +627,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -733,20 +659,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(
@@ -762,7 +688,8 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -813,6 +740,10 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,
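
Most hunks in this file swap between two ways of doing the same thing: a `with self._db.transaction() as cursor:` block, and an explicit `try` / `commit` / `except` / `rollback` sequence on `self._conn`. The sketch below shows how a context manager can wrap the explicit sequence. The `Database` class and its lock are hypothetical; the actual `SqliteDatabase.transaction` implementation is not shown in this diff and may differ.

```python
import sqlite3
import threading
from contextlib import contextmanager
from typing import Iterator


class Database:
    """Hypothetical stand-in for a database wrapper with a transaction helper."""

    def __init__(self) -> None:
        self.conn = sqlite3.connect(":memory:")
        self._lock = threading.RLock()  # assumption: serialize access to the connection

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        with self._lock:
            cursor = self.conn.cursor()
            try:
                yield cursor
                self.conn.commit()  # reached only if the block did not raise
            except Exception:
                self.conn.rollback()  # undo partial writes, then re-raise
                raise
            finally:
                cursor.close()
```

With a helper like this, call sites need neither the explicit `commit()` nor the `except` / `rollback()` / `raise` boilerplate, and a missed commit or rollback becomes harder to write by accident.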


@@ -2,12 +2,11 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
@@ -58,32 +57,17 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the input, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
def is_union_subtype(t1, t2):
@@ -440,7 +424,7 @@ class Graph(BaseModel):
)
# input fields are on the node
if edge.destination.field not in type(destination_node).model_fields:
if edge.destination.field not in destination_node.model_fields:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
@@ -803,22 +787,6 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
@@ -1007,11 +975,10 @@ class GraphExecutionState(BaseModel):
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = []
for source_node_id in next_node_parents:
prepared_nodes = self.source_prepared_mapping[source_node_id]
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)
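Note on the two forms of `all_iteration_mappings` in the hunk above: the explicit loop and the nested `itertools.chain(*...)` expression look interchangeable, but they may differ when a collect node has more than one parent. The sketch below uses made-up node IDs and a plain dict as a stand-in for `source_prepared_mapping` (both are assumptions for illustration, not data from the repository). In the chained form, `*` unpacking exhausts the outer generator before any inner generator runs, so the inner generators read `s` late and see only its final value.

```python
import itertools

# Hypothetical stand-ins for GraphExecutionState.source_prepared_mapping
# and next_node_parents; the IDs are invented for this sketch.
source_prepared_mapping = {"a": ["a-1", "a-2"], "b": ["b-1"]}
next_node_parents = ["a", "b"]


def mappings_with_loop(parents, mapping):
    """Builds (source_id, prepared_id) pairs; the source ID is bound eagerly."""
    result = []
    for source_node_id in parents:
        prepared_nodes = mapping[source_node_id]
        result.extend([(source_node_id, p) for p in prepared_nodes])
    return result


def mappings_with_chain(parents, mapping):
    """Same expression shape as the chained version in the diff.

    mapping[s] is evaluated when each inner generator is created, but the
    `s` inside the tuple is looked up only when the inner generator is
    consumed, after the outer generator has already finished.
    """
    return list(itertools.chain(*(((s, p) for p in mapping[s]) for s in parents)))


loop_result = mappings_with_loop(next_node_parents, source_prepared_mapping)
chain_result = mappings_with_chain(next_node_parents, source_prepared_mapping)

print(loop_result)
print(chain_result)
```

With a single parent the two results match. With several parents the chained form pairs every prepared node with the last parent, which is one reason to prefer the explicit loop here.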

@@ -18,10 +18,9 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -544,30 +543,6 @@ class ModelsInterface(InvocationContextInterface):
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
"""Gets the absolute path for a given model config or path.
For example, if the model's path is `flux/main/FLUX Dev.safetensors`, and the models path is
`/home/username/InvokeAI/models`, this method will return
`/home/username/InvokeAI/models/flux/main/FLUX Dev.safetensors`.
Args:
config_or_path: The model config or path.
Returns:
The absolute path to the model.
"""
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
if model_path.is_absolute():
return model_path.resolve()
base_models_path = self._services.configuration.models_path
joined_path = base_models_path / model_path
resolved_path = joined_path.resolve()
return resolved_path
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
@@ -607,7 +582,7 @@ class UtilInterface(InvocationContextInterface):
base_model: The base model for the current denoising step.
"""
diffusion_step_callback(
stable_diffusion_step_callback(
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=base_model,
@@ -625,10 +600,9 @@ class UtilInterface(InvocationContextInterface):
intermediate_state: The intermediate state of the diffusion pipeline.
"""
diffusion_step_callback(
flux_step_callback(
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=BaseModelType.Flux,
is_canceled=self.is_canceled,
)

@@ -1,7 +1,4 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -29,65 +26,46 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
self.logger = logger
self.db_path = db_path
self.verbose = verbose
if not self._db_path:
if not self.db_path:
logger.info("Initializing in-memory database")
else:
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
# Enable foreign key constraints
self._conn.execute("PRAGMA foreign_keys = ON;")
self.conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self._conn.execute("PRAGMA journal_mode = WAL;")
self.conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self._db_path:
if not self.db_path:
return
try:
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
self.logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()
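Note on the `transaction()` helper in the hunk above: it folds the lock, cursor creation, commit, rollback and cursor cleanup into one context manager, which is the boilerplate the other files in this comparison spell out by hand with `try` / `commit()` / `except` / `rollback()`. A minimal, self-contained sketch of that pattern follows. `TinyDatabase` and the `items` table are hypothetical names used only for illustration; the sketch is not the project's `SqliteDatabase` class.

```python
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager


class TinyDatabase:
    """Hypothetical stand-in for SqliteDatabase, reduced to the transaction helper."""

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._conn = sqlite3.connect(":memory:", check_same_thread=False)

    @contextmanager
    def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
        # Hold the lock for the whole unit of work so threads do not interleave.
        with self._lock:
            cursor = self._conn.cursor()
            try:
                yield cursor
                self._conn.commit()
            except BaseException:
                # Same reach as the bare `except:` in the diff.
                self._conn.rollback()
                raise
            finally:
                cursor.close()


db = TinyDatabase()

with db.transaction() as cursor:
    cursor.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")

with db.transaction() as cursor:
    cursor.execute("INSERT INTO items (name) VALUES (?)", ("kept",))

try:
    with db.transaction() as cursor:
        cursor.execute("INSERT INTO items (name) VALUES (?)", ("discarded",))
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass  # the helper rolled the insert back and re-raised

with db.transaction() as cursor:
    names = [row[0] for row in cursor.execute("SELECT name FROM items ORDER BY id")]

print(names)
```

The manual form holds no lock, so with a connection opened using `check_same_thread=False` and shared across services, concurrent callers can end up committing or rolling back each other's work.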

@@ -21,8 +21,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import build_migration_16
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -61,8 +59,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_16())
migrator.register_migration(build_migration_17())
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.run_migrations()
return db

@@ -1,37 +0,0 @@
import sqlite3
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
class Migration19Callback:
def __init__(self, app_config: InvokeAIAppConfig):
self.models_path = app_config.models_path
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._populate_size(cursor)
self._add_size_column(cursor)
def _add_size_column(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"ALTER TABLE models ADD COLUMN file_size INTEGER "
"GENERATED ALWAYS as (json_extract(config, '$.file_size')) VIRTUAL NOT NULL"
)
def _populate_size(self, cursor: sqlite3.Cursor) -> None:
all_models = cursor.execute("SELECT id, path FROM models;").fetchall()
for model_id, model_path in all_models:
mod = ModelOnDisk(self.models_path / model_path)
cursor.execute(
"UPDATE models SET config = json_set(config, '$.file_size', ?) WHERE id = ?", (mod.size(), model_id)
)
def build_migration_19(app_config: InvokeAIAppConfig) -> Migration:
return Migration(
from_version=18,
to_version=19,
callback=Migration19Callback(app_config),
)

@@ -1,37 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration20Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
-- many-to-many relationship table for models
CREATE TABLE IF NOT EXISTS model_relationships (
-- model_key_1 and model_key_2 are the same as the key(primary key) in the models table
model_key_1 TEXT NOT NULL,
model_key_2 TEXT NOT NULL,
created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (model_key_1, model_key_2),
-- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates
FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE,
FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""
-- Creates an index to keep performance equal when searching for model_key_1 or model_key_2
CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2
ON model_relationships(model_key_2)
"""
)
def build_migration_20() -> Migration:
return Migration(
from_version=19,
to_version=20,
callback=Migration20Callback(),
)
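Note on the `model_relationships` table above: the SQL comment says `model_key_1 < model_key_2` is meant to prevent duplicate pairs, but the table definition shown has no `CHECK` constraint, so the ordering presumably has to be applied by the caller. The sketch below shows one way that could work, assuming a simplified copy of the table (no foreign keys, no `created_at`) and a hypothetical helper `add_relationship` that is not part of the diff.

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# Simplified copy of the table from the migration; foreign keys and the
# created_at column are left out to keep the sketch self-contained.
conn.execute(
    """
    CREATE TABLE model_relationships (
        model_key_1 TEXT NOT NULL,
        model_key_2 TEXT NOT NULL,
        PRIMARY KEY (model_key_1, model_key_2)
    );
    """
)


def add_relationship(conn: sqlite3.Connection, key_a: str, key_b: str) -> None:
    """Hypothetical helper: stores the pair in sorted order so (a, b) and (b, a) collide."""
    key_1, key_2 = sorted((key_a, key_b))
    conn.execute(
        "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
        (key_1, key_2),
    )
    conn.commit()


# The same relationship added in both directions ends up as a single row.
add_relationship(conn, "model-b", "model-a")
add_relationship(conn, "model-a", "model-b")

rows = conn.execute("SELECT model_key_1, model_key_2 FROM model_relationships").fetchall()
print(rows)
```

Because each pair is stored once in sorted order, looking up a model's relationships has to match on either column, which is what the separate index on `model_key_2` appears to support.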

@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db._logger
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db._conn.cursor()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db._db_path is not None:
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db._conn.backup(backup_conn)
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
# exception is raised.
with self._db._conn as conn:
with self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(

@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,23 +25,24 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -59,11 +60,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -84,11 +90,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -111,10 +122,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from style_presets
@@ -122,41 +138,51 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
with self._db.transaction() as cursor:
# First delete all existing default style presets
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)

@@ -1,343 +0,0 @@
{
"name": "Text to Image - CogView4",
"author": "",
"description": "Generate an image from a prompt with CogView4.",
"version": "",
"contact": "",
"tags": "CogView4, Text to Image",
"notes": "",
"exposedFields": [],
"meta": { "category": "default", "version": "3.0.0" },
"id": "default_0e405a8e-ab5e-4e6c-bd99-b59deabd5591",
"form": {
"elements": {
"container-XSINSu999B": {
"id": "container-XSINSu999B",
"data": {
"layout": "column",
"children": [
"heading-N0TXlsboP5",
"text-PVw8AvXCTz",
"divider-5wmCOm9mqG",
"node-field-gPil4XSw8L",
"node-field-T2oYYNrAzH",
"node-field-SRj6Dn28lm"
]
},
"type": "container"
},
"node-field-gPil4XSw8L": {
"id": "node-field-gPil4XSw8L",
"type": "node-field",
"parentId": "container-XSINSu999B",
"data": {
"fieldIdentifier": {
"nodeId": "a4569d8b-6a43-44b9-8919-4ceec6682904",
"fieldName": "prompt"
},
"settings": {
"type": "string-field-config",
"component": "textarea"
},
"showDescription": false
}
},
"node-field-T2oYYNrAzH": {
"id": "node-field-T2oYYNrAzH",
"type": "node-field",
"parentId": "container-XSINSu999B",
"data": {
"fieldIdentifier": {
"nodeId": "acb26944-1208-4016-9929-ab8dd0860573",
"fieldName": "prompt"
},
"settings": {
"type": "string-field-config",
"component": "textarea"
},
"showDescription": false
}
},
"node-field-SRj6Dn28lm": {
"id": "node-field-SRj6Dn28lm",
"type": "node-field",
"parentId": "container-XSINSu999B",
"data": {
"fieldIdentifier": {
"nodeId": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"fieldName": "model"
},
"showDescription": false
}
},
"heading-N0TXlsboP5": {
"id": "heading-N0TXlsboP5",
"parentId": "container-XSINSu999B",
"type": "heading",
"data": { "content": "Text to Image - CogView4" }
},
"text-PVw8AvXCTz": {
"id": "text-PVw8AvXCTz",
"parentId": "container-XSINSu999B",
"type": "text",
"data": { "content": "Generate an image from a prompt with CogView4." }
},
"divider-5wmCOm9mqG": {
"id": "divider-5wmCOm9mqG",
"parentId": "container-XSINSu999B",
"type": "divider"
}
},
"rootElementId": "container-XSINSu999B"
},
"nodes": [
{
"id": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"type": "invocation",
"data": {
"id": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"version": "1.0.0",
"nodePack": "invokeai",
"label": "",
"notes": "",
"type": "cogview4_model_loader",
"inputs": {
"model": {
"name": "model",
"label": ""
}
},
"isOpen": true,
"isIntermediate": true,
"useCache": true
},
"position": { "x": -52.193850056888095, "y": 282.4721422789611 }
},
{
"id": "a4569d8b-6a43-44b9-8919-4ceec6682904",
"type": "invocation",
"data": {
"id": "a4569d8b-6a43-44b9-8919-4ceec6682904",
"version": "1.0.0",
"nodePack": "invokeai",
"label": "",
"notes": "",
"type": "cogview4_text_encoder",
"inputs": {
"prompt": {
"name": "prompt",
"label": "Positive Prompt",
"description": "",
"value": "A whimsical stuffed gnome sits on a golden sandy beach, its plush fabric slightly textured and well-worn. The gnome has a round, cheerful face with a fluffy white beard, a bulbous nose, and a tall, slightly floppy red hat with a few decorative stitching details. It wears a tiny blue vest over a soft, earthy-toned tunic, and its stubby arms grasp a ripe yellow banana with a few brown speckles. The ocean waves gently roll onto the shore in the background, with turquoise water reflecting the warm glow of the late afternoon sun. A few scattered seashells and driftwood pieces are near the gnome, while a colorful beach umbrella and footprints in the sand hint at a lively beach scene. The sky is a soft pastel blend of pink, orange, and light blue, with wispy clouds stretching across the horizon.\n"
},
"glm_encoder": {
"name": "glm_encoder",
"label": "",
"description": ""
}
},
"isOpen": true,
"isIntermediate": true,
"useCache": true
},
"position": { "x": 328.9380683664592, "y": 305.11768986950995 }
},
{
"id": "acb26944-1208-4016-9929-ab8dd0860573",
"type": "invocation",
"data": {
"id": "acb26944-1208-4016-9929-ab8dd0860573",
"version": "1.0.0",
"nodePack": "invokeai",
"label": "",
"notes": "",
"type": "cogview4_text_encoder",
"inputs": {
"prompt": {
"name": "prompt",
"label": "Negative Prompt",
"description": "",
"value": ""
},
"glm_encoder": {
"name": "glm_encoder",
"label": "",
"description": ""
}
},
"isOpen": true,
"isIntermediate": true,
"useCache": true
},
"position": { "x": 334.6799782744916, "y": 496.5882067536601 }
},
{
"id": "cdd72700-463d-4e10-8d76-3e842e4c0b49",
"type": "invocation",
"data": {
"id": "cdd72700-463d-4e10-8d76-3e842e4c0b49",
"version": "1.0.0",
"nodePack": "invokeai",
"label": "",
"notes": "",
"type": "cogview4_l2i",
"inputs": {
"board": {
"name": "board",
"label": "",
"description": "",
"value": "auto"
},
"metadata": { "name": "metadata", "label": "", "description": "" },
"latents": { "name": "latents", "label": "", "description": "" },
"vae": { "name": "vae", "label": "", "description": "" }
},
"isOpen": true,
"isIntermediate": false,
"useCache": true
},
"position": { "x": 1112.027247217991, "y": 294.1351498145327 }
},
{
"id": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"type": "invocation",
"data": {
"id": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"version": "1.0.0",
"nodePack": "invokeai",
"label": "",
"notes": "",
"type": "cogview4_denoise",
"inputs": {
"board": {
"name": "board",
"label": "",
"description": "",
"value": "auto"
},
"metadata": { "name": "metadata", "label": "", "description": "" },
"latents": { "name": "latents", "label": "", "description": "" },
"denoise_mask": {
"name": "denoise_mask",
"label": "",
"description": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"description": "",
"value": 0
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"description": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": "",
"description": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": "",
"description": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": "",
"description": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"description": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"description": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"description": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"description": "",
"value": 30
},
"seed": { "name": "seed", "label": "", "description": "", "value": 0 }
},
"isOpen": true,
"isIntermediate": true,
"useCache": false
},
"position": { "x": 720.8830004638692, "y": 332.66609681908415 }
}
],
"edges": [
{
"id": "reactflow__edge-7890507c-d346-4d13-bcb4-bc6d4850b2e3vae-cdd72700-463d-4e10-8d76-3e842e4c0b49vae",
"type": "default",
"source": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"target": "cdd72700-463d-4e10-8d76-3e842e4c0b49",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-7890507c-d346-4d13-bcb4-bc6d4850b2e3glm_encoder-a4569d8b-6a43-44b9-8919-4ceec6682904glm_encoder",
"type": "default",
"source": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"target": "a4569d8b-6a43-44b9-8919-4ceec6682904",
"sourceHandle": "glm_encoder",
"targetHandle": "glm_encoder"
},
{
"id": "reactflow__edge-7890507c-d346-4d13-bcb4-bc6d4850b2e3glm_encoder-acb26944-1208-4016-9929-ab8dd0860573glm_encoder",
"type": "default",
"source": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"target": "acb26944-1208-4016-9929-ab8dd0860573",
"sourceHandle": "glm_encoder",
"targetHandle": "glm_encoder"
},
{
"id": "reactflow__edge-a4569d8b-6a43-44b9-8919-4ceec6682904conditioning-e75e2ced-284e-4135-81dc-cdf06c7a409dpositive_conditioning",
"type": "default",
"source": "a4569d8b-6a43-44b9-8919-4ceec6682904",
"target": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-acb26944-1208-4016-9929-ab8dd0860573conditioning-e75e2ced-284e-4135-81dc-cdf06c7a409dnegative_conditioning",
"type": "default",
"source": "acb26944-1208-4016-9929-ab8dd0860573",
"target": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-e75e2ced-284e-4135-81dc-cdf06c7a409dlatents-cdd72700-463d-4e10-8d76-3e842e4c0b49latents",
"type": "default",
"source": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"target": "cdd72700-463d-4e10-8d76-3e842e4c0b49",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-7890507c-d346-4d13-bcb4-bc6d4850b2e3transformer-e75e2ced-284e-4135-81dc-cdf06c7a409dtransformer",
"type": "default",
"source": "7890507c-d346-4d13-bcb4-bc6d4850b2e3",
"target": "e75e2ced-284e-4135-81dc-cdf06c7a409d",
"sourceHandle": "transformer",
"targetHandle": "transformer"
}
]
}

@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,8 +51,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
with self._db.transaction() as cursor:
try:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -63,13 +64,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -78,13 +84,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -92,6 +103,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -106,108 +121,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -232,46 +247,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -281,51 +296,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -334,6 +350,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -348,7 +368,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -428,3 +449,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
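The hunks above build the `WHERE` clause from a list of condition strings and a parallel list of parameters, so user-supplied values only ever reach SQLite through `?` placeholders. A minimal, self-contained sketch of that pattern follows. It assumes a toy `workflow_library` table with only four columns; `build_where` and `count_pages` are hypothetical helper names, not functions from the repository.

```python
import sqlite3


def build_where(categories=None, tags=None, has_been_opened=None):
    """Return (where_sql, params); user values are bound via `?` only."""
    conditions: list[str] = []
    params: list[str] = []

    if categories:
        # One placeholder per category: category IN (?, ?, ...)
        placeholders = ", ".join("?" for _ in categories)
        conditions.append(f"category IN ({placeholders})")
        params.extend(categories)

    if tags:
        # Tags are stored as one free-form string, so match each with LIKE
        tag_conditions = " OR ".join("tags LIKE ?" for _ in tags)
        conditions.append(f"({tag_conditions})")
        params.extend(f"%{t.strip()}%" for t in tags)

    if has_been_opened:
        conditions.append("opened_at IS NOT NULL")
    elif has_been_opened is False:
        conditions.append("opened_at IS NULL")

    where_sql = (" WHERE " + " AND ".join(conditions)) if conditions else ""
    return where_sql, params


def count_pages(total: int, per_page: int) -> int:
    # Ceiling division, written the same way as in the diff
    return total // per_page + (total % per_page > 0)


# Toy schema and rows (assumed for illustration only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflow_library (name TEXT, category TEXT, tags TEXT, opened_at TEXT)")
conn.executemany(
    "INSERT INTO workflow_library VALUES (?, ?, ?, ?)",
    [
        ("a", "user", "upscale, sdxl", None),
        ("b", "user", "flux", "2025-04-04"),
        ("c", "default", "sdxl", None),
    ],
)

where_sql, params = build_where(categories=["user", "default"], tags=["sdxl"], has_been_opened=False)
total = conn.execute("SELECT COUNT(*) FROM workflow_library" + where_sql, params).fetchone()[0]
```

Only the number of placeholders is interpolated into the SQL text; the values themselves stay in `params`. SQLite's `LIKE` is case-insensitive for ASCII characters by default, which is what the "case-insensitive" comment in the diff relies on.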


@@ -230,86 +230,6 @@ def heuristic_resize(np_img: np.ndarray[Any, Any], size: tuple[int, int]) -> np.
return resized
# precompute common kernels
_KERNEL3 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
# directional masks for NMS
_DIRS = [
np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], np.uint8),
np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], np.uint8),
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.uint8),
np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], np.uint8),
]
def heuristic_resize_fast(np_img: np.ndarray, size: tuple[int, int]) -> np.ndarray:
h, w = np_img.shape[:2]
# early exit
if (w, h) == size:
return np_img
# separate alpha channel
img = np_img
alpha = None
if img.ndim == 3 and img.shape[2] == 4:
alpha, img = img[:, :, 3], img[:, :, :3]
# build small sample for unique-color & binary detection
flat = img.reshape(-1, img.shape[-1])
N = flat.shape[0]
# include four corners to avoid missing extreme values
corners = np.vstack([img[0, 0], img[0, w - 1], img[h - 1, 0], img[h - 1, w - 1]])
cnt = min(N, 100_000)
samp = np.vstack([corners, flat[np.random.choice(N, cnt, replace=False)]])
uc = np.unique(samp, axis=0).shape[0]
vmin, vmax = samp.min(), samp.max()
# detect binary edge map & one-pixel-edge case
is_binary = uc == 2 and vmin < 16 and vmax > 240
one_pixel_edge = False
if is_binary:
# single gray conversion
gray0 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
grad = cv2.morphologyEx(gray0, cv2.MORPH_GRADIENT, _KERNEL3)
cnt_edge = cv2.countNonZero(grad)
cnt_all = cv2.countNonZero((gray0 > 127).astype(np.uint8))
one_pixel_edge = (2 * cnt_edge) > cnt_all
# choose interp for color/seg/grayscale
area_new, area_old = size[0] * size[1], w * h
if 2 < uc < 200: # segmentation map
interp = cv2.INTER_NEAREST
elif area_new < area_old:
interp = cv2.INTER_AREA
else:
interp = cv2.INTER_CUBIC
# single resize pass on RGB
resized = cv2.resize(img, size, interpolation=interp)
if is_binary:
# convert to gray & apply NMS via C++ dilate
gray_r = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
nms = np.zeros_like(gray_r)
for K in _DIRS:
d = cv2.dilate(gray_r, K)
mask = d == gray_r
nms[mask] = gray_r[mask]
# threshold + thinning if needed
_, bw = cv2.threshold(nms, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
out_bin = cv2.ximgproc.thinning(bw) if one_pixel_edge else bw
# restore 3 channels
resized = np.stack([out_bin] * 3, axis=2)
# restore alpha with same interp as RGB for consistency
if alpha is not None:
am = cv2.resize(alpha, size, interpolation=interp)
am = (am > 127).astype(np.uint8) * 255
resized = np.dstack((resized, am))
return resized
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
@@ -324,7 +244,7 @@ def np_img_resize(
np_img = normalize_image_channel_count(np_img)
if resize_mode == "just_resize": # RESIZE
np_img = heuristic_resize_fast(np_img, (w, h))
np_img = heuristic_resize(np_img, (w, h))
np_img = clone_contiguous(np_img)
return np_img_to_torch(np_img, device), np_img
@@ -345,7 +265,7 @@ def np_img_resize(
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
@@ -355,7 +275,7 @@ def np_img_resize(
return np_img_to_torch(np_img, device), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)


@@ -12,9 +12,6 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
@@ -64,10 +61,6 @@ def get_openapi_func(
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
# Remove output_metadata that is only used on back-end from the schema
if "output_meta" in json_schema["properties"]:
json_schema["properties"].pop("output_meta")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
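The hunk above concerns dropping a back-end-only property from each output schema before it is published. A minimal sketch of that idea on a plain dictionary follows. `drop_property` is a hypothetical helper, and the extra cleanup of the `required` list is an assumption added for illustration; the diff itself only pops the key from `properties`.

```python
from typing import Any


def drop_property(json_schema: dict[str, Any], name: str) -> dict[str, Any]:
    """Remove `name` from a JSON schema's properties (and from `required`, if listed)."""
    # pop with a default so schemas without the property are left untouched
    json_schema.get("properties", {}).pop(name, None)

    # Assumption: also keep `required` consistent with the remaining properties
    required = json_schema.get("required")
    if required and name in required:
        json_schema["required"] = [r for r in required if r != name]
    return json_schema


schema = {
    "type": "object",
    "properties": {"image": {"type": "string"}, "output_meta": {"type": "object"}},
    "required": ["image", "output_meta"],
}
cleaned = drop_property(schema, "output_meta")
```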


@@ -10,7 +10,7 @@ def get_timestamp() -> int:
def get_iso_timestamp() -> str:
return datetime.datetime.now(datetime.timezone.utc).isoformat()
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
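The two `return` lines in this hunk differ in whether the timestamp carries a UTC offset. A small sketch of the difference follows; the naive value is derived with `replace(tzinfo=None)` as a stand-in for what `utcnow()` returns, so the snippet does not call the deprecated function itself.

```python
import datetime

# Timezone-aware "now" in UTC
aware = datetime.datetime.now(datetime.timezone.utc)

# Naive equivalent: same wall-clock fields, no tzinfo (what utcnow() would yield)
naive = aware.replace(tzinfo=None)

aware_str = aware.isoformat()  # ends with an explicit "+00:00" offset
naive_str = naive.isoformat()  # no offset, so the reader must assume UTC

# Parsing keeps (or loses) the timezone accordingly
aware_roundtrip = datetime.datetime.fromisoformat(aware_str)
naive_roundtrip = datetime.datetime.fromisoformat(naive_str)
```

An aware and a naive `datetime` cannot be compared or subtracted directly (Python raises `TypeError`), so code that parses these strings needs to treat both forms consistently.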


@@ -8,8 +8,6 @@ from invokeai.app.services.session_processor.session_processor_common import Can
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
# See scripts/generate_vae_linear_approximation.py for generating these factors.
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
@@ -74,32 +72,11 @@ FLUX_LATENT_RGB_FACTORS = [
[-0.1146, -0.0827, -0.0598],
]
COGVIEW4_LATENT_RGB_FACTORS = [
[0.00408832, -0.00082485, -0.00214816],
[0.00084172, 0.00132241, 0.00842067],
[-0.00466737, -0.00983181, -0.00699561],
[0.03698397, -0.04797235, 0.03585809],
[0.00234701, -0.00124326, 0.00080869],
[-0.00723903, -0.00388422, -0.00656606],
[-0.00970917, -0.00467356, -0.00971113],
[0.17292486, -0.03452463, -0.1457515],
[0.02330308, 0.02942557, 0.02704329],
[-0.00903131, -0.01499841, -0.01432564],
[0.01250298, 0.0019407, -0.02168986],
[0.01371188, 0.00498283, -0.01302135],
[0.42396525, 0.4280575, 0.42148206],
[0.00983825, 0.00613302, 0.00610316],
[0.00473307, -0.00889551, -0.00915924],
[-0.00955853, -0.00980067, -0.00977842],
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
if samples.dim() == 4:
samples = samples[0]
latent_image = samples.permute(1, 2, 0) @ latent_rgb_factors
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
@@ -123,11 +100,7 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
if total_steps == 0:
return 0.0
if order == 2:
# Prevent division by zero when total_steps is 1 (floor(1 / 2) == 0)
denominator = floor(total_steps / 2)
if denominator == 0:
return 0.0
return floor(step / 2) / denominator
return floor(step / 2) / floor(total_steps / 2)
# order == 1
return step / total_steps
@@ -135,7 +108,7 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
SignalProgressFunc: TypeAlias = Callable[[str, float | None, Image.Image | None, tuple[int, int] | None], None]
def diffusion_step_callback(
def stable_diffusion_step_callback(
signal_progress: SignalProgressFunc,
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
@@ -152,28 +125,39 @@ def diffusion_step_callback(
else:
sample = intermediate_state.latents
smooth_matrix: list[list[float]] | None = None
if base_model in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
latent_rgb_factors = SD1_5_LATENT_RGB_FACTORS
elif base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
latent_rgb_factors = SDXL_LATENT_RGB_FACTORS
smooth_matrix = SDXL_SMOOTH_MATRIX
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
elif base_model == BaseModelType.StableDiffusion3:
latent_rgb_factors = SD3_5_LATENT_RGB_FACTORS
elif base_model == BaseModelType.CogView4:
latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
elif base_model == BaseModelType.Flux:
latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
else:
raise ValueError(f"Unsupported base model: {base_model}")
latent_rgb_factors_torch = torch.tensor(latent_rgb_factors, dtype=sample.dtype, device=sample.device)
smooth_matrix_torch = (
torch.tensor(smooth_matrix, dtype=sample.dtype, device=sample.device) if smooth_matrix else None
)
image = sample_to_lowres_estimated_image(
samples=sample, latent_rgb_factors=latent_rgb_factors_torch, smooth_matrix=smooth_matrix_torch
)
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
width = image.width * 8
height = image.height * 8
percentage = calc_percentage(intermediate_state)
signal_progress("Denoising", percentage, image, (width, height))
def flux_step_callback(
signal_progress: SignalProgressFunc,
intermediate_state: PipelineIntermediateState,
is_canceled: Callable[[], bool],
) -> None:
if is_canceled():
raise CanceledException
sample = intermediate_state.latents
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
latent_image = latent_image_perm @ latent_rgb_factors
latents_ubyte = (
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
).to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(latents_ubyte.cpu().numpy())
width = image.width * 8
height = image.height * 8
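The `calc_percentage` hunk in this file differs only in whether the second-order branch guards its denominator. A minimal sketch of both variants follows. It assumes plain integer arguments instead of the `PipelineIntermediateState` object used in the diff, and the function names are hypothetical.

```python
from math import floor


def calc_percentage_guarded(step: int, total_steps: int, order: int) -> float:
    if total_steps == 0:
        return 0.0
    if order == 2:
        # Second-order schedulers take two steps per reported step
        denominator = floor(total_steps / 2)
        if denominator == 0:
            return 0.0
        return floor(step / 2) / denominator
    # order == 1
    return step / total_steps


def calc_percentage_unguarded(step: int, total_steps: int, order: int) -> float:
    if total_steps == 0:
        return 0.0
    if order == 2:
        return floor(step / 2) / floor(total_steps / 2)
    return step / total_steps


# With a single total step, the unguarded variant divides by floor(1 / 2) == 0
try:
    calc_percentage_unguarded(0, 1, 2)
    unguarded_raised = False
except ZeroDivisionError:
    unguarded_raised = True
```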


@@ -1,314 +0,0 @@
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_t5_prompt_embeds(
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str], None] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds = prompt_embeds.to(device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
)
class CudaTimerContext:
def __init__(self, times_arr):
self.times_arr = times_arr
def __enter__(self):
self.before_event = torch.cuda.Event(enable_timing=True)
self.after_event = torch.cuda.Event(enable_timing=True)
self.before_event.record()
def __exit__(self, type, value, traceback):
self.after_event.record()
torch.cuda.synchronize()
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
self.times_arr.append(elapsed_time)
def get_env_prefix():
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
if env == "AWS":
return "SM_CHANNEL"
elif env == "AZURE":
return "AZUREML_DATAREFERENCE"
raise Exception(f"Env {env} not supported")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def initialize_distributed():
# Initialize the process group for distributed training
dist.init_process_group("nccl")
# Get the current process's rank (ID) and the total number of processes (world size)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized distributed training: Rank {rank}/{world_size}")
def get_clip_prompt_embeds(
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
assert max_sequence_length == tokenizer.model_max_length
prompt = [prompt] if isinstance(prompt, str) else prompt
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
# We are always interested only in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, pooled_prompt_embeds
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
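The `use_real and repeat_interleave_real` branch of `get_1d_rotary_pos_embed`, which `FluxPosEmbed` calls per axis, can be illustrated without tensors. The following is a pure-Python sketch under simplifying assumptions: `linear_factor` and `ntk_factor` are fixed at 1.0, positions are a plain list, and `rotary_cos_sin` is a hypothetical name. It is meant to show the arithmetic, not to replace the torch implementation.

```python
import math


def rotary_cos_sin(dim: int, positions: list[float], theta: float = 10000.0):
    """Return (cos, sin) tables of shape [S][D] for the given positions."""
    assert dim % 2 == 0
    # One frequency per pair of channels: 1 / theta^(2i / dim), i = 0 .. D/2 - 1
    freqs = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]

    cos_rows, sin_rows = [], []
    for p in positions:
        angles = [p * f for f in freqs]  # outer product row for this position
        # Equivalent of repeat_interleave(2): [a0, a0, a1, a1, ...]
        cos_rows.append([math.cos(a) for a in angles for _ in range(2)])
        sin_rows.append([math.sin(a) for a in angles for _ in range(2)])
    return cos_rows, sin_rows


cos, sin = rotary_cos_sin(4, [0.0, 1.0])
```

For `dim=4` the two frequencies are `1.0` and `0.01`, so position 0 yields all-ones cosines and all-zero sines, and each value appears twice in a row because of the interleaving.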

Some files were not shown because too many files have changed in this diff.