Compare commits

..

1 Commits

Author SHA1 Message Date
Mary Hipp
59bd6b935d wip missing fields prototype 2025-02-19 16:00:33 -05:00
448 changed files with 8664 additions and 19146 deletions

6
.github/CODEOWNERS vendored
View File

@@ -1,12 +1,12 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
@@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername
# front ends
/invokeai/frontend/CLI @lstein @hipsterusername

View File

@@ -76,6 +76,9 @@ jobs:
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
suffix=-${{ matrix.gpu-driver }},onlatest=false
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
@@ -100,7 +103,7 @@ jobs:
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
# cache-from: |
# type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
# type=gha,scope=main-${{ matrix.gpu-driver }}
# cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
cache-from: |
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
type=gha,scope=main-${{ matrix.gpu-driver }}
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}

View File

@@ -44,12 +44,7 @@ jobs:
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:

View File

@@ -44,12 +44,7 @@ jobs:
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:

View File

@@ -43,12 +43,7 @@ jobs:
- name: check for changed python files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
@@ -67,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.9.9
run: pip install ruff==0.6.0
shell: bash
- name: ruff check

View File

@@ -77,12 +77,7 @@ jobs:
- name: check for changed python files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:

View File

@@ -42,12 +42,7 @@ jobs:
- name: check for changed files
if: ${{ inputs.always_run != true }}
id: changed-files
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
# See:
# - CVE-2025-30066
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
# - https://github.com/tj-actions/changed-files/issues/2463
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
uses: tj-actions/changed-files@v42
with:
files_yaml: |
src:

View File

@@ -13,63 +13,48 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
git
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
ARG BUILDPLATFORM
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
mkdir -p ${INVOKEAI_SRC} && \
chmod -R a+w /opt && \
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
chmod -R a+w /opt
USER ubuntu
# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
uv python install ${PYTHON_VERSION}
# Install python and create the venv
RUN uv python install ${PYTHON_VERSION} && \
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
WORKDIR ${INVOKEAI_SRC}
COPY invokeai ./invokeai
COPY pyproject.toml ./
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# Editable mode helps use the same image for development:
# the local working copy can be bind-mounted into the image
# at path defined by ${INVOKEAI_SRC}
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
fi && \
uv sync --no-install-project
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
#### Build the Web UI ------------------------------------
@@ -113,7 +98,6 @@ RUN apt update && apt install -y --no-install-recommends \
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
@@ -125,7 +109,7 @@ ENV CONTAINER_GID=${CONTAINER_GID:-1000}
# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root

View File

@@ -1,50 +1,41 @@
# Release Process
The Invoke application is published as a python package on [PyPI]. This includes both a source distribution and built distribution (a wheel).
The app is published in twice, in different build formats.
Most users install it with the [Launcher](https://github.com/invoke-ai/launcher/), others with `pip`.
The launcher uses GitHub as the source of truth for available releases.
## Broad Strokes
- Merge all changes and bump the version in the codebase.
- Tag the release commit.
- Wait for the release workflow to complete.
- Approve the PyPI publish jobs.
- Write GH release notes.
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
Make a developer call-out for PRs to merge. Merge and test things out.
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
## Release Workflow
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
It is triggered on **tag push**, when the tag matches `v*`.
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
### Triggering the Workflow
Ensure all commits that should be in the release are merged, and you have pulled them locally.
Run `make tag-release` to tag the current commit and kick off the workflow.
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run checks and tests, then two final publish jobs.
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
This job ensures that the `invokeai` python package version specifier matches the tag for the release. The version specifier is pulled from the `__version__` variable in `invokeai/version/invokeai_version.py`.
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
This job uses [samuelcolvin/check-python-version].
@@ -52,52 +43,62 @@ This job uses [samuelcolvin/check-python-version].
#### Check and Test Jobs
Next, these jobs run and must pass. They are the same jobs that are run for every PR.
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
- **`typegen-checks`**: ensures the frontend and backend types are synced
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the legacy install scripts
You don't need to download either of these files.
> The legacy install scripts are no longer used, but we haven't updated the workflow to skip building them.
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
#### Sanity Check & Smoke Test
At this point, the release workflow pauses as the remaining publish jobs require approval.
At this point, the release workflow pauses as the remaining publish jobs require approval. Time to test the installer.
It's possible to test the python package before it gets published to PyPI. We've never had problems with it, so it's not necessary to do this.
Because the installer pulls from PyPI, and we haven't published to PyPI yet, you will need to install from the wheel:
But, if you want to be extra-super careful, here's how to test it:
- Download and unzip `dist.zip` and the installer from the **Summary** tab of the workflow
- Run the installer script using the `--wheel` CLI arg, pointing at the wheel:
- Download the `dist.zip` build artifact from the `build-installer` job
- Unzip it and find the wheel file
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
- Test the app
```sh
./install.sh --wheel ../InvokeAI-4.0.0rc6-py3-none-any.whl
```
- Install to a temporary directory so you get the new user experience
- Download a model and generate
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation as if the installer got the wheel from PyPI.
##### Something isn't right
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?) and start over.
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?).
Now you can start from the top:
- Fix the issues and PR the fixes per usual
- Get the PR approved and merged per usual
- Switch to `main` and pull in the fixes
- Run `make tag-release` to move the tag to `HEAD` (which has the fixes) and kick off the release workflow again
- Re-do the sanity check
#### PyPI Publish Jobs
The publish jobs will not run if any of the previous jobs fail.
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi` - typically you select both)
- Select the environment (either `testpypi` or `pypi`)
- Click **Approve and deploy**
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
@@ -112,33 +113,46 @@ If there are no incidents, contact @hipsterusername or @lstein, who have owner a
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release for some reason:
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
- Approve this publish job without approving the prod publish
- Let it finish
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/), making sure to use the Test PyPI index URL: `https://test.pypi.org/simple/`
- Test the app
If approved and successful, you could try out the test release like this:
```sh
# Create a new virtual environment
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
# Install the distribution from Test PyPI
pip install --index-url https://test.pypi.org/simple/ invokeai
# Run and test the app
invokeai-web
# Cleanup
deactivate
rm -rf ~/.test-invokeai-dist
```
#### `publish-pypi` Job
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
It's a good idea to wait to approve and run this job until you have the release notes ready!
## Publish the GitHub Release with installer
## Prep and publish the GitHub Release
Once the release is published to PyPI, it's time to publish the GitHub release.
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
2. The **Generate release notes** button automatically inserts the changelog and new contributors. Make sure to select the correct tags for this release and the last stable release. GH often selects the wrong tags - do this manually.
3. Write the release notes, describing important changes. Contributions from community members should be shouted out. Use the GH-generated changelog to see all contributors. If there are Weblate translation updates, open that PR and shout out every person who contributed a translation.
4. Check **Set as a pre-release** if it's a pre-release.
5. Approve and wait for the `publish-pypi` job to finish if you haven't already.
6. Publish the GH release.
7. Post the release in Discord in the [releases](https://discord.com/channels/1020123559063990373/1149260708098359327) channel with abbreviated notes. For example:
> Invoke v5.7.0 (stable): <https://github.com/invoke-ai/InvokeAI/releases/tag/v5.7.0>
>
> It's a pretty big one - Form Builder, Metadata Nodes (thanks @SkunkWorxDark!), and much more.
8. Right click the message in releases and copy the link to it. Then, post that link in the [new-release-discussion](https://discord.com/channels/1020123559063990373/1149506274971631688) channel. For example:
> Invoke v5.7.0 (stable): <https://discord.com/channels/1020123559063990373/1149260708098359327/1344521744916021248>
1. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
1. Use `scripts/get_external_contributions.py` to get a list of external contributions to shout out in the release notes.
1. Upload the zip file created in **`build`** job into the Assets section of the release notes.
1. Check **Set as a pre-release** if it's a pre-release.
1. Check **Create a discussion for this release**.
1. Publish the release.
1. Announce the release in Discord.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manual Build
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
No checks are run, it just builds.
## Manual Release
@@ -146,10 +160,12 @@ The `release` workflow can be dispatched manually. You must dispatch the workflo
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
[Test PyPI]: https://test.pypi.org/
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
[ncipollo/release-action]: https://github.com/ncipollo/release-action
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version

View File

@@ -31,7 +31,6 @@ It is possible to fine-tune the settings for best performance or if you still ge
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
- Partial model loading (`enable_partial_loading`)
- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
- Working memory (`device_working_mem_gb`)
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
@@ -52,16 +51,6 @@ As described above, you can enable partial model loading by adding this line to
enable_partial_loading: true
```
### PyTorch CUDA allocator config
The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
```yaml
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
```
A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
### Dynamic RAM and VRAM cache sizes
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
@@ -86,26 +75,24 @@ But, if your GPU has enough VRAM to hold models fully, you might get a perf boos
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
# might be a good value to achieve aggressive model caching.
max_cache_ram_gb: 28
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
# consideration the VRAM used by other processes).
# You can override the default value by setting `max_cache_vram_gb`.
# CAUTION: Most users should not manually set this value. See warning below.
max_cache_vram_gb: 16
# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
# `device_working_mem_gb`.
# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
# auxiliary models, 18GB might be a good value.
max_cache_vram_gb: 18
```
!!! warning "Max safe value for `max_cache_vram_gb`"
!!! tip "Max safe value for `max_cache_vram_gb`"
Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
### Working memory
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.

View File

@@ -36,7 +36,6 @@ from invokeai.app.services.style_preset_images.style_preset_images_disk import S
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -84,7 +83,6 @@ class ApiDependencies:
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
workflow_thumbnails_folder = config.workflow_thumbnails_path
db = init_db(config=config, logger=logger, image_files=image_files)
@@ -122,7 +120,6 @@ class ApiDependencies:
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
services = InvocationServices(
board_image_records=board_image_records,
@@ -150,7 +147,6 @@ class ApiDependencies:
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -1,124 +0,0 @@
import json
import logging
from dataclasses import dataclass
from PIL import Image
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
@dataclass
class ExtractedMetadata:
invokeai_metadata: str | None
invokeai_workflow: str | None
invokeai_graph: str | None
def extract_metadata_from_image(
pil_image: Image.Image,
invokeai_metadata_override: str | None,
invokeai_workflow_override: str | None,
invokeai_graph_override: str | None,
logger: logging.Logger,
) -> ExtractedMetadata:
"""
Extracts the "invokeai_metadata", "invokeai_workflow", and "invokeai_graph" data embedded in the PIL Image.
These items are stored as stringified JSON in the image file's metadata, so we need to do some parsing to validate
them. Once parsed, the values are returned as they came (as strings), or None if they are not present or invalid.
In some situations, we may prefer to override the values extracted from the image file with some other values.
For example, when uploading an image via API, the client can optionally provide the metadata directly in the request,
as opposed to embedding it in the image file. In this case, the client-provided metadata will be used instead of the
metadata embedded in the image file.
Args:
pil_image: The PIL Image object.
invokeai_metadata_override: The metadata override provided by the client.
invokeai_workflow_override: The workflow override provided by the client.
invokeai_graph_override: The graph override provided by the client.
logger: The logger to use for debug logging.
Returns:
ExtractedMetadata: The extracted metadata, workflow, and graph.
"""
# The fallback value for metadata is None.
stringified_metadata: str | None = None
# Use the metadata override if provided, else attempt to extract it from the image file.
metadata_raw = invokeai_metadata_override or pil_image.info.get("invokeai_metadata", None)
# If the metadata is present in the image file, we will attempt to parse it as JSON. When we create images,
# we always store metadata as a stringified JSON dict. So, we expect it to be a string here.
if isinstance(metadata_raw, str):
try:
# Must be a JSON string
metadata_parsed = json.loads(metadata_raw)
# Must be a dict
if isinstance(metadata_parsed, dict):
# Looks good, overwrite the fallback value
stringified_metadata = metadata_raw
except Exception as e:
logger.debug(f"Failed to parse metadata for uploaded image, {e}")
pass
# We expect the workflow, if embedded in the image, to be a JSON-stringified WorkflowWithoutID. We will store it
# as a string.
workflow_raw: str | None = invokeai_workflow_override or pil_image.info.get("invokeai_workflow", None)
# The fallback value for workflow is None.
stringified_workflow: str | None = None
# If the workflow is present in the image file, we will attempt to parse it as JSON. When we create images, we
# always store workflows as a stringified JSON WorkflowWithoutID. So, we expect it to be a string here.
if isinstance(workflow_raw, str):
try:
# Validate the workflow JSON before storing it
WorkflowWithoutIDValidator.validate_json(workflow_raw)
# Looks good, overwrite the fallback value
stringified_workflow = workflow_raw
except Exception:
logger.debug("Failed to parse workflow for uploaded image")
pass
# We expect the workflow, if embedded in the image, to be a JSON-stringified Graph. We will store it as a
# string.
graph_raw: str | None = invokeai_graph_override or pil_image.info.get("invokeai_graph", None)
# The fallback value for graph is None.
stringified_graph: str | None = None
# If the graph is present in the image file, we will attempt to parse it as JSON. When we create images, we
# always store graphs as a stringified JSON Graph. So, we expect it to be a string here.
if isinstance(graph_raw, str):
try:
# TODO(psyche): Due to pydantic's handling of None values, it is possible for the graph to fail validation,
# even if it is a direct dump of a valid graph. Node fields in the graph are allowed to have be unset if
# they have incoming connections, but something about the ser/de process cannot adequately handle this.
#
# In lieu of fixing the graph validation, we will just do a simple check here to see if the graph is dict
# with the correct keys. This is not a perfect solution, but it should be good enough for now.
# FIX ME: Validate the graph JSON before storing it
# Graph.model_validate_json(graph_raw)
# Crappy workaround to validate JSON
graph_parsed = json.loads(graph_raw)
if not isinstance(graph_parsed, dict):
raise ValueError("Not a dict")
if not isinstance(graph_parsed.get("nodes", None), dict):
raise ValueError("'nodes' is not a dict")
if not isinstance(graph_parsed.get("edges", None), list):
raise ValueError("'edges' is not a list")
# Looks good, overwrite the fallback value
stringified_graph = graph_raw
except Exception as e:
logger.debug(f"Failed to parse graph for uploaded image, {e}")
pass
return ExtractedMetadata(
invokeai_metadata=stringified_metadata, invokeai_workflow=stringified_workflow, invokeai_graph=stringified_graph
)

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -88,9 +87,7 @@ async def delete_board(
try:
if include_images is True:
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id,
categories=None,
is_intermediate=None,
board_id=board_id
)
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
@@ -101,9 +98,7 @@ async def delete_board(
)
else:
deleted_board_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id,
categories=None,
is_intermediate=None,
board_id=board_id
)
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
return DeleteBoardResult(
@@ -147,14 +142,10 @@ async def list_boards(
)
async def list_all_board_image_names(
board_id: str = Path(description="The id of the board"),
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
) -> list[str]:
"""Gets a list of images for a board"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id,
categories,
is_intermediate,
)
return image_names

View File

@@ -6,10 +6,9 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
@@ -46,16 +45,18 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[str] = Body(
default=None,
description="The metadata to associate with the image, must be a stringified JSON dict",
embed=True,
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
_metadata = None
_workflow = None
_graph = None
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
@@ -66,13 +67,30 @@ async def upload_image(
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
extracted_metadata = extract_metadata_from_image(
pil_image=pil_image,
invokeai_metadata_override=metadata,
invokeai_workflow_override=None,
invokeai_graph_override=None,
logger=ApiDependencies.invoker.services.logger,
)
# TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass
try:
image_dto = ApiDependencies.invoker.services.images.create(
@@ -81,9 +99,9 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
board_id=board_id,
metadata=extracted_metadata.invokeai_metadata,
workflow=extracted_metadata.invokeai_workflow,
graph=extracted_metadata.invokeai_graph,
metadata=_metadata,
workflow=_workflow,
graph=_graph,
is_intermediate=is_intermediate,
)

View File

@@ -48,9 +48,7 @@ async def enqueue_batch(
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
@session_queue_router.get(

View File

@@ -1,10 +1,6 @@
import io
import traceback
from typing import Optional
from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from fastapi import APIRouter, Body, HTTPException, Path, Query
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
@@ -14,14 +10,11 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemWithThumbnailDTO,
WorkflowRecordListItemDTO,
WorkflowRecordOrderBy,
WorkflowRecordWithThumbnailDTO,
WorkflowWithoutID,
)
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import WorkflowThumbnailFileNotFoundException
IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@@ -29,17 +22,15 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
"/i/{workflow_id}",
operation_id="get_workflow",
responses={
200: {"model": WorkflowRecordWithThumbnailDTO},
200: {"model": WorkflowRecordDTO},
},
)
async def get_workflow(
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordWithThumbnailDTO:
) -> WorkflowRecordDTO:
"""Gets a workflow"""
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@@ -66,11 +57,6 @@ async def delete_workflow(
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except WorkflowThumbnailFileNotFoundException:
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
pass
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
@@ -92,7 +78,7 @@ async def create_workflow(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]},
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
},
)
async def list_workflows(
@@ -102,158 +88,10 @@ async def list_workflows(
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories of workflow to get"),
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
direction=direction,
page=page,
per_page=per_page,
query=query,
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
return ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by, direction=direction, page=page, per_page=per_page, query=query, category=category
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
WorkflowRecordListItemWithThumbnailDTO(
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
**workflow.model_dump(),
)
)
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
items=workflows_with_thumbnails,
total=workflows.total,
page=workflows.page,
pages=workflows.pages,
per_page=workflows.per_page,
)
@workflows_router.put(
"/i/{workflow_id}/thumbnail",
operation_id="set_workflow_thumbnail",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def set_workflow_thumbnail(
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.workflow_thumbnails.save(workflow_id, pil_image)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@workflows_router.delete(
"/i/{workflow_id}/thumbnail",
operation_id="delete_workflow_thumbnail",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def delete_workflow_thumbnail(
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e))
@workflows_router.get(
"/i/{workflow_id}/thumbnail",
operation_id="get_workflow_thumbnail",
responses={
200: {
"description": "The workflow thumbnail was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The workflow thumbnail could not be found"},
},
status_code=200,
)
async def get_workflow_thumbnail(
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
) -> FileResponse:
"""Gets a workflow's thumbnail image"""
try:
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
response = FileResponse(
path,
media_type="image/png",
filename=workflow_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
) -> dict[str, int]:
"""Counts workflows by tag"""
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
tags=tags, categories=categories, has_been_opened=has_been_opened
)
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
async def counts_by_category(
categories: list[WorkflowCategory] = Query(description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
) -> dict[str, int]:
"""Counts workflows by category"""
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
categories=categories, has_been_opened=has_been_opened
)
@workflows_router.put(
"/i/{workflow_id}/opened_at",
operation_id="update_opened_at",
)
async def update_opened_at(
workflow_id: str = Path(description="The workflow to update"),
) -> None:
"""Updates the opened_at field of a workflow"""
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)

View File

@@ -1,8 +1,12 @@
import asyncio
import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -11,7 +15,11 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
@@ -30,13 +38,31 @@ from invokeai.app.api.routers import (
from invokeai.app.api.sockets import SocketIO
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
app_config = get_config()
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
# the correct port when the server starts in the lifespan handler.
port = app_config.port
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -45,7 +71,7 @@ async def lifespan(app: FastAPI):
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
proto = "https" if app_config.ssl_certfile else "http"
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
# Logging this way ignores the logger's log level and _always_ logs the message
record = logger.makeRecord(
@@ -160,3 +186,73 @@ except RuntimeError:
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def invoke_api() -> None:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_port(port=port + 1)
else:
return port
if app_config.dev_reload:
try:
import jurigged
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
global port
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@@ -1,5 +1,33 @@
import shutil
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.services.config.config_default import get_config
custom_nodes_path = Path(get_config().custom_nodes_path)
custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
# copy our custom nodes __init__.py to the custom nodes directory
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
# set the same permissions as the destination directory, in case our source is read-only,
# so that the files are user-writable
for p in custom_nodes_path.glob("**/*"):
p.chmod(custom_nodes_path.stat().st_mode)
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
__all__ = [f.stem for f in python_files] # type: ignore

View File

@@ -44,6 +44,8 @@ if TYPE_CHECKING:
logger = InvokeAILogger.get_logger()
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
class InvalidVersionError(ValueError):
pass
@@ -238,11 +240,6 @@ class BaseInvocation(ABC, BaseModel):
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -417,7 +414,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
field.json_schema_extra.pop("ui_type")
return None
@@ -449,27 +446,8 @@ def invocation(
if re.compile(r"^\S+$").match(invocation_type) is None:
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The node being clobbered is a core node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
)
else:
# The node being clobbered is a custom node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
)
raise ValueError(f'Invocation type "{invocation_type}" already exists')
validate_fields(cls.model_fields, invocation_type)
@@ -479,7 +457,8 @@ def invocation(
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
uiconfig["node_pack"] = node_pack
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
if version is not None:
try:

View File

@@ -64,50 +64,13 @@ class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()
@invocation_output("image_generator_output")
class ImageGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of boards"""
images: list[ImageField] = OutputField(description="The generated images")
class ImageGeneratorField(BaseModel):
pass
@invocation(
"image_generator",
title="Image Generator",
tags=["primitives", "board", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageGenerator(BaseInvocation):
"""Generated a collection of images for use in a batched generation"""
generator: ImageGeneratorField = InputField(
description="The image generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> ImageGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"string_batch",
title="String Batch",

View File

@@ -40,10 +40,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"compel",
title="Prompt - SD1.5",
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.2.1",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@@ -233,10 +233,10 @@ class SDXLPromptInvocationBase:
@invocation(
"sdxl_compel_prompt",
title="Prompt - SDXL",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.2.1",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -327,10 +327,10 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@invocation(
"sdxl_refiner_compel_prompt",
title="Prompt - SDXL Refiner",
title="SDXL Refiner Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.1.2",
version="1.1.1",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -376,10 +376,10 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
@invocation(
"clip_skip",
title="Apply CLIP Skip - SD1.5, SDXL",
title="CLIP Skip",
tags=["clipskip", "clip", "skip"],
category="conditioning",
version="1.1.1",
version="1.1.0",
)
class CLIPSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
@@ -513,7 +513,7 @@ def log_tokenization_for_text(
usedTokens += 1
if usedTokens > 0:
print(f"\n>> [TOKENLOG] Tokens {display_label or ''} ({usedTokens}):")
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"{tokenized}\x1b[0m")
if discarded != "":

View File

@@ -87,7 +87,7 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""

View File

@@ -0,0 +1,58 @@
"""
Invoke-managed custom node loader. See README.md for more information.
"""
import sys
import traceback
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
loaded_count = 0
for d in Path(__file__).parent.iterdir():
# skip files
if not d.is_dir():
continue
# skip hidden directories
if d.name.startswith("_") or d.name.startswith("."):
continue
# skip directories without an `__init__.py`
init = d / "__init__.py"
if not init.exists():
continue
module_name = init.parent.stem
# skip if already imported
if module_name in globals():
continue
# load the module, appending adding a suffix to identify it as a custom node pack
spec = spec_from_file_location(module_name, init.absolute())
if spec is None or spec.loader is None:
logger.warn(f"Could not load {init}")
continue
logger.info(f"Loading node pack {module_name}")
try:
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
loaded_count += 1
except Exception:
full_error = traceback.format_exc()
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
del init, module_name
if loaded_count > 0:
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")

View File

@@ -127,10 +127,10 @@ def get_scheduler(
@invocation(
"denoise_latents",
title="Denoise - SD1.5, SDXL",
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.5.4",
version="1.5.3",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""

View File

@@ -57,8 +57,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
# endregion
# region Misc Field Types
@@ -154,7 +152,6 @@ class FieldDescriptions:
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
vllm_model = "VLLM model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
@@ -204,7 +201,6 @@ class FieldDescriptions:
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
class ImageField(BaseModel):
@@ -263,17 +259,6 @@ class FluxConditioningField(BaseModel):
)
class FluxReduxConditioningField(BaseModel):
"""A FLUX Redux conditioning tensor primitive value"""
conditioning: TensorField = Field(description="The Redux image conditioning tensor.")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -21,10 +21,10 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"flux_control_lora_loader",
title="Control LoRA - FLUX",
title="Flux Control LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.1",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):

View File

@@ -15,7 +15,6 @@ from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
@@ -47,7 +46,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
@@ -62,7 +61,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.2.3",
version="3.2.2",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -104,11 +103,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
redux_conditioning: FluxReduxConditioningField | list[FluxReduxConditioningField] | None = InputField(
default=None,
description="FLUX Redux conditioning tensor.",
input=Input.Connection,
)
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale_start_step: int = InputField(
default=0,
@@ -196,23 +190,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
context=context,
redux_cond_field=self.redux_conditioning,
packed_height=packed_h,
packed_width=packed_w,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
text_conditioning=pos_text_conditionings,
redux_conditioning=redux_conditionings,
img_seq_len=packed_h * packed_w,
pos_text_conditionings, img_seq_len=packed_h * packed_w
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(
text_conditioning=neg_text_conditionings, redux_conditioning=[], img_seq_len=packed_h * packed_w
)
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
if neg_text_conditionings
else None
)
@@ -418,42 +400,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return text_conditionings
def _load_redux_conditioning(
self,
context: InvocationContext,
redux_cond_field: FluxReduxConditioningField | list[FluxReduxConditioningField] | None,
packed_height: int,
packed_width: int,
device: torch.device,
dtype: torch.dtype,
) -> list[FluxReduxConditioning]:
# Normalize to a list of FluxReduxConditioningFields.
if redux_cond_field is None:
return []
redux_cond_list = (
[redux_cond_field] if isinstance(redux_cond_field, FluxReduxConditioningField) else redux_cond_field
)
redux_conditionings: list[FluxReduxConditioning] = []
for redux_cond_field in redux_cond_list:
# Load the Redux conditioning tensor.
redux_cond_data = context.tensors.load(redux_cond_field.conditioning.tensor_name)
redux_cond_data.to(device=device, dtype=dtype)
# Load the mask, if provided.
mask: Optional[torch.Tensor] = None
if redux_cond_field.mask is not None:
mask = context.tensors.load(redux_cond_field.mask.tensor_name)
mask = mask.to(device=device)
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, packed_height, packed_width, dtype, device
)
redux_conditionings.append(FluxReduxConditioning(redux_embeddings=redux_cond_data, mask=mask))
return redux_conditionings
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int

View File

@@ -37,10 +37,10 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
@invocation(
"flux_model_loader",
title="Main Model - FLUX",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.6",
version="1.0.5",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):

View File

@@ -1,119 +0,0 @@
from typing import Optional
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxReduxConditioningField,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager.starter_models import siglip
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
@invocation_output("flux_redux_output")
class FluxReduxOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Redux invocation."""
redux_cond: FluxReduxConditioningField = OutputField(
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
)
@invocation(
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="2.0.0",
classification=Classification.Prototype,
)
class FluxReduxInvocation(BaseInvocation):
"""Runs a FLUX Redux model to generate a conditioning tensor."""
image: ImageField = InputField(description="The FLUX Redux image prompt.")
mask: Optional[TensorField] = InputField(
default=None,
description="The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
"False, included regions should be set to True.",
)
redux_model: ModelIdentifierField = InputField(
description="The FLUX Redux model to use.",
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
)
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
encoded_x = self._siglip_encode(context, image)
redux_conditioning = self._flux_redux_encode(context, encoded_x)
tensor_name = context.tensors.save(redux_conditioning)
return FluxReduxOutput(
redux_cond=FluxReduxConditioningField(conditioning=TensorField(tensor_name=tensor_name), mask=self.mask)
)
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
@torch.no_grad()
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
with context.models.load(self.redux_model).model_on_device() as (_, flux_redux):
assert isinstance(flux_redux, FluxReduxModel)
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)
return flux_redux(encoded_x)
def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)
if not len(siglip_models) > 0:
context.logger.warning(
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
)
# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)
# Queue the job
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)
# Wait for up to 10 minutes - model is ~3.5GB
context._services.model_manager.install.wait_for_job(job, timeout=600)
siglip_models = context.models.search_by_attrs(
name=siglip.name,
base=BaseModelType.Any,
type=ModelType.SigLIP,
)
if len(siglip_models) == 0:
context.logger.error("Error while fetching SigLIP for FLUX Redux")
assert len(siglip_models) == 1
return siglip_models[0]

View File

@@ -26,10 +26,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
@invocation(
"flux_text_encoder",
title="Prompt - FLUX",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.1.2",
version="1.1.1",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):

View File

@@ -22,10 +22,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_decode",
title="Latents to Image - FLUX",
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.2",
version="1.0.1",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -41,11 +41,16 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision).
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
scaling_constant = 1090 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
# We add a 20% buffer to the working memory estimate to be safe.
working_memory = working_memory * 1.2
return int(working_memory)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:

View File

@@ -19,10 +19,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_encode",
title="Image to Latents - FLUX",
title="FLUX Image to Latents",
tags=["latents", "image", "vae", "i2l", "flux"],
category="latents",
version="1.0.1",
version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
"""Encodes an image into latents."""

View File

@@ -19,9 +19,9 @@ class IdealSizeOutput(BaseInvocationOutput):
@invocation(
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.5",
version="1.0.4",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""

View File

@@ -31,10 +31,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"i2l",
title="Image to Latents - SD1.5, SDXL",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.1.1",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""

View File

@@ -69,13 +69,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
}
@invocation(
"ip_adapter",
title="IP-Adapter - SD1.5, SDXL",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.5.1",
)
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.5.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""

View File

@@ -31,10 +31,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image - SD1.5, SDXL",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.2",
version="1.3.1",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -60,7 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 2200 # Determined experimentally.
scaling_constant = 960 # Determined experimentally.
if use_tiling:
tile_size = self.tile_size
@@ -84,7 +84,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
return int(working_memory)
# We add 20% to the working memory estimate to be safe.
working_memory = int(working_memory * 1.2)
return working_memory
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -1,83 +0,0 @@
import logging
import shutil
import sys
import traceback
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
def load_custom_nodes(custom_nodes_path: Path, logger: logging.Logger):
"""
Loads all custom nodes from the custom_nodes_path directory.
If custom_nodes_path does not exist, it creates it.
It also copies the custom_nodes/README.md file to the custom_nodes_path directory. Because this file may change,
it is _always_ copied to the custom_nodes_path directory.
Then, it crawls the custom_nodes_path directory and imports all top-level directories as python modules.
If the directory does not contain an __init__.py file or starts with an `_` or `.`, it is skipped.
"""
# create the custom nodes directory if it does not exist
custom_nodes_path.mkdir(parents=True, exist_ok=True)
# Copy the README file to the custom nodes directory
source_custom_nodes_readme_path = Path(__file__).parent / "custom_nodes/README.md"
target_custom_nodes_readme_path = Path(custom_nodes_path) / "README.md"
# copy our custom nodes README to the custom nodes directory
shutil.copy(source_custom_nodes_readme_path, target_custom_nodes_readme_path)
loaded_packs: list[str] = []
failed_packs: list[str] = []
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
for d in custom_nodes_path.iterdir():
# skip files
if not d.is_dir():
continue
# skip hidden directories
if d.name.startswith("_") or d.name.startswith("."):
continue
# skip directories without an `__init__.py`
init = d / "__init__.py"
if not init.exists():
continue
module_name = init.parent.stem
# skip if already imported
if module_name in globals():
continue
# load the module
spec = spec_from_file_location(module_name, init.absolute())
if spec is None or spec.loader is None:
logger.warning(f"Could not load {init}")
continue
logger.info(f"Loading node pack {module_name}")
try:
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
loaded_packs.append(module_name)
except Exception:
failed_packs.append(module_name)
full_error = traceback.format_exc()
logger.error(f"Failed to load node pack {module_name} (may have partially loaded):\n{full_error}")
del init, module_name
loaded_count = len(loaded_packs)
if loaded_count > 0:
logger.info(
f"Loaded {loaded_count} node pack{'s' if loaded_count != 1 else ''} from {custom_nodes_path}: {', '.join(loaded_packs)}"
)

View File

@@ -284,7 +284,6 @@ class CoreMetadataInvocation(BaseInvocation):
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Deprecated,
)
class MetadataFieldExtractorInvocation(BaseInvocation):
"""Extracts the text value from an image's metadata given a key.

File diff suppressed because it is too large Load Diff

View File

@@ -122,10 +122,10 @@ class ModelIdentifierOutput(BaseInvocationOutput):
@invocation(
"model_identifier",
title="Any Model",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.1",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
@@ -144,10 +144,10 @@ class ModelIdentifierInvocation(BaseInvocation):
@invocation(
"main_model_loader",
title="Main Model - SD1.5",
title="Main Model",
tags=["model"],
category="model",
version="1.0.4",
version="1.0.3",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@@ -244,7 +244,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Model - SD1.5", tags=["model"], category="model", version="1.0.2")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
@@ -257,9 +257,7 @@ class LoRASelectorInvocation(BaseInvocation):
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation(
"lora_collection_loader", title="LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.1"
)
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.1.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
@@ -322,10 +320,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"sdxl_lora_loader",
title="LoRA Model - SDXL",
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.4",
version="1.0.3",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -402,10 +400,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
@invocation(
"sdxl_lora_collection_loader",
title="LoRA Collection - SDXL",
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.1.1",
version="1.1.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
@@ -471,9 +469,7 @@ class SDXLLoRACollectionLoader(BaseInvocation):
return output
@invocation(
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
)
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
@@ -500,10 +496,10 @@ class SeamlessModeOutput(BaseInvocationOutput):
@invocation(
"seamless",
title="Apply Seamless - SD1.5, SDXL",
title="Seamless",
tags=["seamless", "model"],
category="model",
version="1.0.2",
version="1.0.1",
)
class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE."""
@@ -543,7 +539,7 @@ class SeamlessModeInvocation(BaseInvocation):
return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2")
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.1")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):

View File

@@ -72,10 +72,10 @@ class NoiseOutput(BaseInvocationOutput):
@invocation(
"noise",
title="Create Latent Noise",
title="Noise",
tags=["latents", "noise"],
category="latents",
version="1.0.3",
version="1.0.2",
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""

View File

@@ -265,9 +265,13 @@ class ImageInvocation(BaseInvocation):
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.get_dto(self.image.image_name)
image = context.images.get_pil(self.image.image_name)
return ImageOutput.build(image_dto=image_dto)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@invocation(

View File

@@ -32,10 +32,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="Denoise - SD3",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.1.1",
version="1.1.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -21,10 +21,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_i2l",
title="Image to Latents - SD3",
title="SD3 Image to Latents",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.1",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -24,10 +24,10 @@ from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_l2i",
title="Latents to Image - SD3",
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.2",
version="1.3.1",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -43,11 +43,16 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision).
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
scaling_constant = 1230 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
# We add a 20% buffer to the working memory estimate to be safe.
working_memory = working_memory * 1.2
return int(working_memory)
@torch.no_grad()

View File

@@ -30,10 +30,10 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
@invocation(
"sd3_model_loader",
title="Main Model - SD3",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.1",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):

View File

@@ -29,10 +29,10 @@ SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
title="Prompt - SD3",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.1",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):

View File

@@ -24,7 +24,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="Main Model - SDXL", tags=["model", "sdxl"], category="model", version="1.0.4")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
@@ -58,10 +58,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
@invocation(
"sdxl_refiner_model_loader",
title="Refiner Model - SDXL",
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.4",
version="1.0.3",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""

View File

@@ -185,9 +185,9 @@ class SegmentAnythingInvocation(BaseInvocation):
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
assert bounding_boxes is not None, (
"Bounding boxes must be provided to use the 'highest_box_score' mask filter."
)
assert (
bounding_boxes is not None
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
assert len(masks) == len(bounding_boxes)
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most

View File

@@ -45,11 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter",
title="T2I-Adapter - SD1.5, SDXL",
tags=["t2i_adapter", "control"],
category="t2i_adapter",
version="1.0.4",
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""

View File

@@ -53,11 +53,11 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise - SD1.5, SDXL",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.1",
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.

View File

@@ -9,6 +9,6 @@ def validate_weights(weights: Union[float, list[float]]) -> None:
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than or equal to end_step_percent"""
if begin_step_percent > end_step_percent:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")

View File

@@ -1,82 +1,12 @@
import uvicorn
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
"""
from invokeai.app.api_app import app, loop
return app, loop
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
def run_app() -> None:
"""The main entrypoint for the app."""
# Parse the CLI arguments.
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
InvokeAIArgs.parse_args()
# Load config.
app_config = get_config()
from invokeai.app.api_app import invoke_api
logger = InvokeAILogger.get_logger(config=app_config)
# Configure the torch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
if app_config.pytorch_cuda_alloc_conf:
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
# Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
from invokeai.app.util.startup_utils import (
apply_monkeypatches,
check_cudnn,
enable_dev_reload,
find_open_port,
register_mime_types,
)
# Find an open port, and modify the config accordingly.
orig_config_port = app_config.port
app_config.port = find_open_port(app_config.port)
if orig_config_port != app_config.port:
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)
# Initialize the app and event loop.
app, loop = get_app()
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
# Start the server.
config = uvicorn.Config(
app=app,
host=app_config.host,
port=app_config.port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())
invoke_api()

View File

@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.image_records.image_records_common import ImageCategory
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@@ -28,8 +26,6 @@ class BoardImageRecordStorageBase(ABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
"""Gets all board images for a board, as a list of the image names."""
pass

View File

@@ -1,20 +1,23 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
deserialize_image_record,
)
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def add_image_to_board(
self,
@@ -22,8 +25,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -35,14 +38,16 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
@@ -53,6 +58,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
@@ -61,108 +68,96 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT image_name
FROM board_images
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_board_for_image(
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.image_records.image_records_common import ImageCategory
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@@ -28,8 +26,6 @@ class BoardImagesServiceABC(ABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
"""Gets all board images for a board, as a list of the image names."""
pass

View File

@@ -1,7 +1,6 @@
from typing import Optional
from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.invoker import Invoker
@@ -27,14 +26,8 @@ class BoardImagesService(BoardImagesServiceABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories,
is_intermediate,
)
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
def get_board_for_image(
self,

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@@ -18,14 +19,20 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
@@ -33,9 +40,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -43,8 +55,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
@@ -55,6 +67,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
@@ -62,8 +76,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
@@ -72,9 +86,12 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -85,10 +102,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
@@ -99,7 +117,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
@@ -110,7 +128,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
@@ -123,6 +141,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
@@ -133,10 +153,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
try:
self._lock.acquire()
# Build base query
base_query = """
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
@@ -144,67 +165,81 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
# Execute count query
self._cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
try:
self._lock.acquire()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
self._cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@@ -63,11 +63,7 @@ class BulkDownloadService(BulkDownloadBase):
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._image_handler(image_names)
def generate_item_id(self, board_id: Optional[str]) -> str:

View File

@@ -72,7 +72,6 @@ class InvokeAIAppConfig(BaseSettings):
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
workflow_thumbnails_dir: Path to directory for workflow thumbnails.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@@ -92,7 +91,6 @@ class InvokeAIAppConfig(BaseSettings):
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -143,7 +141,6 @@ class InvokeAIAppConfig(BaseSettings):
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
workflow_thumbnails_dir: Path = Field(default=Path("workflow_thumbnails"), description="Path to directory for workflow thumbnails.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@@ -172,9 +169,6 @@ class InvokeAIAppConfig(BaseSettings):
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
# PyTorch Memory Allocator
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
@@ -306,11 +300,6 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def workflow_thumbnails_path(self) -> Path:
"""Path to the workflow thumbnails directory, resolved to an absolute path.."""
return self._resolve(self.workflow_thumbnails_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""
@@ -483,9 +472,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert config.schema_version == CONFIG_SCHEMA_VERSION, (
f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from datetime import datetime
from typing import Optional, Union, cast
@@ -21,14 +22,21 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
@@ -36,9 +44,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
@@ -47,8 +58,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
@@ -56,7 +68,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result:
raise ImageRecordNotFoundException
@@ -65,7 +77,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update(
self,
@@ -73,10 +88,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
@@ -87,7 +102,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
@@ -98,7 +113,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
@@ -109,7 +124,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET starred = ?
@@ -122,6 +137,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
@@ -135,104 +152,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
@@ -243,48 +266,58 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
placeholders = ",".join("?" for _ in image_names)
self._lock.acquire()
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
return count
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, self._cursor.fetchone()[0])
self._conn.commit()
return count
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
@@ -295,6 +328,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -311,8 +346,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
@@ -345,7 +380,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
self._conn.commit()
cursor.execute(
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
@@ -354,30 +389,34 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None

View File

@@ -265,11 +265,7 @@ class ImageService(ImageServiceABC):
def delete_images_on_board(self, board_id: str):
try:
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
for image_name in image_names:
self.__invoker.services.image_files.delete(image_name)
self.__invoker.services.image_records.delete_many(image_names)
@@ -282,7 +278,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Failed to delete image files")
raise
except Exception as e:
self.__invoker.services.logger.error(f"Problem deleting image records and files: {str(e)}")
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e
def delete_intermediates(self) -> int:

View File

@@ -32,7 +32,6 @@ if TYPE_CHECKING:
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls.urls_base import UrlServiceBase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -66,7 +65,6 @@ class InvocationServices:
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -93,4 +91,3 @@ class InvocationServices:
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails

View File

@@ -78,6 +78,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger
@property
@@ -95,38 +96,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(config.key)
@@ -138,21 +139,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
@@ -163,23 +164,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -191,33 +192,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
@@ -226,15 +227,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -282,18 +284,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
with self._db.lock:
self._cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = self._cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -312,28 +313,34 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def list_models(
@@ -349,32 +356,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
# query2: fetch key fields
self._cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Coroutine, Optional
from typing import Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -33,7 +33,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
"""Enqueues all permutations of a batch for execution."""
pass

View File

@@ -1,7 +1,7 @@
import datetime
import json
from itertools import chain, product
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import (
AliasChoices,
@@ -406,143 +406,61 @@ class IsFullResult(BaseModel):
# region Util
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
"""
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
field_values_json for each session.
Populates the given graph with the given batch data items.
"""
graph_clone = graph.model_copy(deep=True)
for item in node_field_values:
node = graph_clone.get_node(item.node_path)
if node is None:
continue
setattr(node, item.field_name, item.value)
graph_clone.update_node(item.node_path, node)
return graph_clone
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
the field.
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
Each inner list now represents the substitution values for a single permutation (session).
- For each permutation, substitute the values into the graph
This function is optimized for performance, as it is used to generate a large number of sessions at once.
Args:
batch: The batch to generate sessions from
maximum: The maximum number of sessions to generate
Returns:
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
generator will stop early if the maximum number of sessions is reached.
def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
that was applied to the graph.
"""
# TODO: Should this be a class method on Batch?
data: list[list[tuple[dict]]] = []
data: list[list[tuple[NodeFieldValue]]] = []
batch_data_collection = batch.data if batch.data is not None else []
for batch_datum_list in batch_data_collection:
node_field_values_to_zip: list[list[dict]] = []
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
node_field_values_to_zip: list[list[NodeFieldValue]] = []
for batch_datum in batch_datum_list:
node_field_values = [
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
# Zip the dicts together to create a list of dicts for each permutation
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# We serialize the graph and session once, then mutate the graph dict in place for each session.
#
# This sounds scary, but it's actually fine.
#
# The batch prep logic injects field values into the same fields for each generated session.
#
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
# [
# (
# {"node_path": "1", "field_name": "a", "value": 1},
# {"node_path": "2", "field_name": "b", "value": 2},
# {"node_path": "3", "field_name": "c", "value": 3},
# ),
# (
# {"node_path": "1", "field_name": "a", "value": 4},
# {"node_path": "2", "field_name": "b", "value": 5},
# {"node_path": "3", "field_name": "c", "value": 6},
# )
# ]
#
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
# No matter the complexity of the batch, this property holds true.
#
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
#
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
# but this was also slow.
#
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
# objects for each session.
#
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
# dict as the session's graph.
# Dump the batch's graph to a dict once
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
# overwritten for each session by the mutated graph_as_dict.
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
# create generator to yield session,nfv tuples
count = 0
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
# still limited by the maximum number of sessions.
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
# We've reached the maximum number of sessions we may generate
return
# Flatten the list of lists of dicts into a single list of dicts
# TODO(psyche): Is the a more efficient way to do this?
flat_node_field_values = list(chain.from_iterable(d))
# Need a fresh ID for each session
session_id = uuid_string()
# Mutate the session dict in place
session_dict["id"] = session_id
# Substitute the values into the graph
for nfv in flat_node_field_values:
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
# Mutate the session dict in place
session_dict["graph"] = graph_as_dict
# Serialize the session and field values
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
session_json = json.dumps(session_dict, default=to_jsonable_python)
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
# Yield the session_id, session_json, and field_values_json
yield (session_id, session_json, field_values_json)
# Increment the count so we know when to stop
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
Calculates the number of sessions that would be created by the batch, without incurring
the overhead of actually generating them. Adapted from `create_sessions().
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
@@ -558,78 +476,42 @@ def calc_session_count(batch: Batch) -> int:
return len(data_product) * batch.runs
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)
str, # session_id
str, # batch_id
str | None, # field_values (optional, as stringified JSON)
int, # priority
str | None, # workflow (optional, as stringified JSON)
str | None, # origin (optional)
str | None, # destination (optional)
int | None, # retried_from_item_id (optional, this is always None for new items)
]
"""A type alias for the tuple of values to insert into the session queue table.
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""
**If you change this, be sure to update the `enqueue_batch` and `retry_items_by_id` methods in the session queue service!**
"""
# Careful with the ordering of this - it must match the insert statement
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
destination: str | None
retried_from_item_id: int | None = None
def prepare_values_to_insert(
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
) -> list[ValueToInsertTuple]:
"""
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
`executemany` statement to insert multiple rows at once.
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
Args:
queue_id: The ID of the queue to insert the items into
batch: The batch to prepare the values for
priority: The priority of the queue items
max_new_queue_items: The maximum number of queue items to insert
Returns:
A list of tuples to insert into the session queue table. Each tuple contains the following values:
- queue_id
- session (as stringified JSON)
- session_id
- batch_id
- field_values (optional, as stringified JSON)
- priority
- workflow (optional, as stringified JSON)
- origin (optional)
- destination (optional)
- retried_from_item_id (optional, this is always None for new items)
"""
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), the
# this difference becomes noticeable.
#
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
values_to_insert: list[ValueToInsertTuple] = []
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
# not support by default. Apparently there are sets somewhere in the graph.
# The same workflow is used for all sessions in the batch - serialize it once
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
(
queue_id,
session_json,
session_id,
batch.batch_id,
field_values_json,
priority,
workflow_json,
batch.origin,
batch.destination,
None,
SessionQueueValueToInsert(
queue_id=queue_id,
session=session.model_dump_json(warnings=False, exclude_none=True), # as json
session_id=session.id,
batch_id=batch.batch_id,
# must use pydantic_encoder bc field_values is a list of models
field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None, # as json
priority=priority,
workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None, # as json
origin=batch.origin,
destination=batch.destination,
)
)
return values_to_insert

View File

@@ -1,6 +1,6 @@
import asyncio
import json
import sqlite3
import threading
from typing import Optional, Union, cast
from pydantic_core import to_jsonable_python
@@ -27,7 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
SessionQueueValueToInsert,
calc_session_count,
prepare_values_to_insert,
)
@@ -38,6 +38,9 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.RLock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
@@ -53,7 +56,9 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -61,8 +66,8 @@ class SqliteSessionQueue(SessionQueueBase):
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -70,13 +75,14 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -86,12 +92,11 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
return cast(int, self.__cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
self.__cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
@@ -101,14 +106,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
self.__lock.acquire()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -130,17 +133,19 @@ class SqliteSessionQueue(SessionQueueBase):
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
cursor.executemany(
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -152,19 +157,25 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -172,40 +183,52 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -219,8 +242,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
@@ -228,10 +251,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -239,36 +264,48 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -276,8 +313,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
"""--sql
DELETE
FROM session_queue
@@ -285,16 +322,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
@@ -304,7 +342,8 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'canceled'
)
"""
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -312,8 +351,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
DELETE
FROM session_queue
@@ -321,10 +360,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -353,8 +394,8 @@ class SqliteSessionQueue(SessionQueueBase):
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
@@ -365,7 +406,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id] + batch_ids
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -373,8 +414,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -382,18 +423,20 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
@@ -403,7 +446,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = (queue_id, destination)
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -411,8 +454,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -420,18 +463,20 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByDestinationResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
@@ -440,7 +485,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id]
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -448,8 +493,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -457,7 +502,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
@@ -465,19 +510,21 @@ class SqliteSessionQueue(SessionQueueBase):
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -485,8 +532,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -494,35 +541,43 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
@@ -530,10 +585,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items(
@@ -544,71 +601,83 @@ class SqliteSessionQueue(SessionQueueBase):
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
try:
item_id = cursor
self.__lock.acquire()
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
@@ -627,23 +696,29 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return BatchStatus(
batch_id=batch_id,
@@ -659,18 +734,24 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -689,8 +770,9 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
values_to_insert: list[ValueToInsertTuple] = []
self.__lock.acquire()
values_to_insert: list[SessionQueueValueToInsert] = []
retried_item_ids: list[int] = []
for item_id in item_ids:
@@ -716,23 +798,23 @@ class SqliteSessionQueue(SessionQueueBase):
else queue_item.item_id
)
value_to_insert: ValueToInsertTuple = (
queue_item.queue_id,
cloned_session_json,
cloned_session.id,
queue_item.batch_id,
field_values_json,
queue_item.priority,
workflow_json,
queue_item.origin,
queue_item.destination,
retried_from_item_id,
value_to_insert = SessionQueueValueToInsert(
queue_id=queue_item.queue_id,
batch_id=queue_item.batch_id,
destination=queue_item.destination,
field_values=field_values_json,
origin=queue_item.origin,
priority=queue_item.priority,
workflow=workflow_json,
session=cloned_session_json,
session_id=cloned_session.id,
retried_from_item_id=retried_from_item_id,
)
values_to_insert.append(value_to_insert)
# TODO(psyche): Handle max queue size?
cursor.executemany(
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
@@ -740,10 +822,12 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -9,7 +9,6 @@ from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.board_records.board_records_common import BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
@@ -17,7 +16,6 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
@@ -104,9 +102,7 @@ class BoardsInterface(InvocationContextInterface):
Returns:
A list of all boards.
"""
return self._services.boards.get_all(
order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
)
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""Adds an image to a board.
@@ -126,11 +122,7 @@ class BoardsInterface(InvocationContextInterface):
Returns:
A list of all image names for the board.
"""
return self._services.board_images.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
@@ -291,7 +283,7 @@ class ImagesInterface(InvocationContextInterface):
Returns:
The local path of the image or thumbnail.
"""
return Path(self._services.images.get_path(image_name, thumbnail))
return self._services.images.get_path(image_name, thumbnail)
class TensorsInterface(InvocationContextInterface):

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from logging import Logger
from pathlib import Path
@@ -37,20 +38,14 @@ class SqliteDatabase:
self.logger.info(f"Initializing database at {self.db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
@@ -58,14 +53,15 @@ class SqliteDatabase:
# No need to clean in-memory database
if not self.db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise
with self.lock:
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise

View File

@@ -19,8 +19,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import build_migration_16
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -57,8 +55,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.register_migration(build_migration_16())
migrator.register_migration(build_migration_17())
migrator.register_migration(build_migration_18())
migrator.run_migrations()
return db

View File

@@ -1,35 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration17Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_workflows_tags_col(cursor)
def _add_workflows_tags_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `tags` column to the workflow_library table. It is a generated column that extracts the tags from the
workflow JSON.
"""
cursor.execute(
"ALTER TABLE workflow_library ADD COLUMN tags TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.tags')) VIRTUAL;"
)
def build_migration_17() -> Migration:
"""
Build the migration from database version 16 to 17.
This migration does the following:
- Adds `tags` column to the workflow_library table. It is a generated column that extracts the tags from the
workflow JSON.
"""
migration_17 = Migration(
from_version=16,
to_version=17,
callback=Migration17Callback(),
)
return migration_17

View File

@@ -1,47 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration18Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._make_workflow_opened_at_nullable(cursor)
def _make_workflow_opened_at_nullable(self, cursor: sqlite3.Cursor) -> None:
"""
Make the `opened_at` column nullable in the `workflow_library` table. This is accomplished by:
- Dropping the existing `idx_workflow_library_opened_at` index (must be done before dropping the column)
- Dropping the existing `opened_at` column
- Adding a new nullable column `opened_at` (no data migration needed, all values will be NULL)
- Adding a new `idx_workflow_library_opened_at` index on the `opened_at` column
"""
# For index renaming in SQLite, we need to drop and recreate
cursor.execute("DROP INDEX IF EXISTS idx_workflow_library_opened_at;")
# Rename existing column to deprecated
cursor.execute("ALTER TABLE workflow_library DROP COLUMN opened_at;")
# Add new nullable column - all values will be NULL - no migration of data needed
cursor.execute("ALTER TABLE workflow_library ADD COLUMN opened_at DATETIME;")
# Create new index on the new column
cursor.execute(
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
)
def build_migration_18() -> Migration:
"""
Build the migration from database version 17 to 18.
This migration does the following:
- Make the `opened_at` column nullable in the `workflow_library` table. This is accomplished by:
- Dropping the existing `idx_workflow_library_opened_at` index (must be done before dropping the column)
- Dropping the existing `opened_at` column
- Adding a new nullable column `opened_at` (no data migration needed, all values will be NULL)
- Adding a new `idx_workflow_library_opened_at` index on the `opened_at` column
"""
migration_18 = Migration(
from_version=17,
to_version=18,
callback=Migration18Callback(),
)
return migration_18

View File

@@ -43,45 +43,46 @@ class SqliteMigrator:
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
with self._db.lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
@@ -107,26 +108,27 @@ class SqliteMigrator:
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
with self._db.lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:

View File

@@ -17,7 +17,9 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,25 +27,31 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -64,16 +72,18 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
cursor.execute(
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -94,15 +104,17 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
@@ -113,7 +125,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
# Change the preset data for a style preset
if changes.preset_data is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
@@ -126,12 +138,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
@@ -142,38 +156,46 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
@@ -183,8 +205,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)

View File

@@ -18,8 +18,3 @@ class UrlServiceBase(ABC):
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass
@abstractmethod
def get_workflow_thumbnail_url(self, workflow_id: str) -> str:
"""Gets the URL for a workflow thumbnail"""
pass

View File

@@ -22,6 +22,3 @@ class LocalUrlService(UrlServiceBase):
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
def get_workflow_thumbnail_url(self, workflow_id: str) -> str:
return f"{self._base_url}/workflows/i/{workflow_id}/thumbnail"

View File

@@ -1,11 +1,10 @@
{
"id": "default_686bb1d0-d086-4c70-9fa3-2f600b922023",
"name": "Upscaler - SD1.5, ESRGAN",
"name": "ESRGAN Upscaling with Canny ControlNet",
"author": "InvokeAI",
"description": "Sample workflow for using ESRGAN to upscale with ControlNet with SD1.5",
"description": "Sample workflow for using Upscaling with ControlNet with SD1.5",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, upscaling, control",
"tags": "upscale, controlnet, default",
"notes": "",
"exposedFields": [
{
@@ -185,7 +184,14 @@
},
"control_model": {
"name": "control_model",
"label": "Control Model (select Canny)"
"label": "Control Model (select Canny)",
"value": {
"key": "a7b9c76f-4bc5-42aa-b918-c1c458a5bb24",
"hash": "blake3:260c7f8e10aefea9868cfc68d89970e91033bd37132b14b903e70ee05ebf530e",
"name": "sd-controlnet-canny",
"base": "sd-1",
"type": "controlnet"
}
},
"control_weight": {
"name": "control_weight",
@@ -288,7 +294,14 @@
"inputs": {
"model": {
"name": "model",
"label": ""
"label": "",
"value": {
"key": "5cd43ca0-dd0a-418d-9f7e-35b2b9d5e106",
"hash": "blake3:6987f323017f597213cc3264250edf57056d21a40a0a85d83a1a33a7d44dc41a",
"name": "Deliberate_v5",
"base": "sd-1",
"type": "main"
}
}
},
"isOpen": true,
@@ -835,4 +848,4 @@
"targetHandle": "image_resolution"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88",
"name": "Image to Image - FLUX",
"name": "FLUX Image to Image",
"author": "InvokeAI",
"description": "A simple image-to-image workflow using a FLUX dev model. ",
"version": "1.1.0",
"contact": "",
"tags": "flux, image to image",
"tags": "image2image, flux, image-to-image",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
"exposedFields": [
{
@@ -201,15 +200,36 @@
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
"label": "",
"value": {
"key": "d18d5575-96b6-4da3-b3d8-eb58308d6705",
"hash": "random:f2f9ed74acdfb4bf6fec200e780f6c25f8dd8764a35e65d425d606912fdf573a",
"name": "t5_bnb_int8_quantized_encoder",
"base": "any",
"type": "t5_encoder"
}
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
"label": "",
"value": {
"key": "5a19d7e5-8d98-43cd-8a81-87515e4b3b4e",
"hash": "random:4bd08514c08fb6ff04088db9aeb45def3c488e8b5fd09a35f2cc4f2dc346f99f",
"name": "clip-vit-large-patch14",
"base": "any",
"type": "clip_embed"
}
},
"vae_model": {
"name": "vae_model",
"label": ""
"label": "",
"value": {
"key": "9172beab-5c1d-43f0-b2f0-6e0b956710d9",
"hash": "random:c54dde288e5fa2e6137f1c92e9d611f598049e6f16e360207b6d96c9f5a67ba0",
"name": "FLUX.1-schnell_ae",
"base": "flux",
"type": "vae"
}
}
}
},

View File

@@ -1,11 +1,10 @@
{
"id": "default_dec5a2e9-f59c-40d9-8869-a056751d79b8",
"name": "Face Detailer - SD1.5",
"name": "Face Detailer with IP-Adapter & Canny (See Note in Details)",
"author": "kosmoskatten",
"description": "A workflow to add detail to and improve faces. This workflow is most effective when used with a model that creates realistic outputs. ",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, reference image, control",
"tags": "face detailer, IP-Adapter, Canny",
"notes": "Set this image as the blur mask: https://i.imgur.com/Gxi61zP.png",
"exposedFields": [
{
@@ -136,7 +135,14 @@
},
"control_model": {
"name": "control_model",
"label": "Control Model (select canny)"
"label": "Control Model (select canny)",
"value": {
"key": "5bdaacf7-a7a3-4fb8-b394-cc0ffbb8941d",
"hash": "blake3:260c7f8e10aefea9868cfc68d89970e91033bd37132b14b903e70ee05ebf530e",
"name": "sd-controlnet-canny",
"base": "sd-1",
"type": "controlnet"
}
},
"control_weight": {
"name": "control_weight",
@@ -190,7 +196,14 @@
},
"ip_adapter_model": {
"name": "ip_adapter_model",
"label": "IP-Adapter Model (select IP Adapter Face)"
"label": "IP-Adapter Model (select IP Adapter Face)",
"value": {
"key": "1cc210bb-4d0a-4312-b36c-b5d46c43768e",
"hash": "blake3:3d669dffa7471b357b4df088b99ffb6bf4d4383d5e0ef1de5ec1c89728a3d5a5",
"name": "ip_adapter_sd15",
"base": "sd-1",
"type": "ip_adapter"
}
},
"clip_vision_model": {
"name": "clip_vision_model",
@@ -1432,4 +1445,4 @@
"targetHandle": "vae"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_444fe292-896b-44fd-bfc6-c0b5d220fffc",
"name": "Text to Image - FLUX",
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
"version": "1.1.0",
"contact": "",
"tags": "flux, text to image",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
@@ -169,15 +168,36 @@
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
"label": "",
"value": {
"key": "d18d5575-96b6-4da3-b3d8-eb58308d6705",
"hash": "random:f2f9ed74acdfb4bf6fec200e780f6c25f8dd8764a35e65d425d606912fdf573a",
"name": "t5_bnb_int8_quantized_encoder",
"base": "any",
"type": "t5_encoder"
}
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
"label": "",
"value": {
"key": "5a19d7e5-8d98-43cd-8a81-87515e4b3b4e",
"hash": "random:4bd08514c08fb6ff04088db9aeb45def3c488e8b5fd09a35f2cc4f2dc346f99f",
"name": "clip-vit-large-patch14",
"base": "any",
"type": "clip_embed"
}
},
"vae_model": {
"name": "vae_model",
"label": ""
"label": "",
"value": {
"key": "9172beab-5c1d-43f0-b2f0-6e0b956710d9",
"hash": "random:c54dde288e5fa2e6137f1c92e9d611f598049e6f16e360207b6d96c9f5a67ba0",
"name": "FLUX.1-schnell_ae",
"base": "flux",
"type": "vae"
}
}
}
},

View File

@@ -1,11 +1,10 @@
{
"id": "default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2",
"name": "Text to Image - SD1.5, Control",
"name": "Multi ControlNet (Canny & Depth)",
"author": "InvokeAI",
"description": "A sample workflow using canny & depth ControlNets to guide the generation process. ",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, control, text to image",
"tags": "ControlNet, canny, depth",
"notes": "",
"exposedFields": [
{
@@ -217,7 +216,14 @@
},
"control_model": {
"name": "control_model",
"label": "Control Model (select canny)"
"label": "Control Model (select canny)",
"value": {
"key": "5bdaacf7-a7a3-4fb8-b394-cc0ffbb8941d",
"hash": "blake3:260c7f8e10aefea9868cfc68d89970e91033bd37132b14b903e70ee05ebf530e",
"name": "sd-controlnet-canny",
"base": "sd-1",
"type": "controlnet"
}
},
"control_weight": {
"name": "control_weight",
@@ -364,7 +370,14 @@
},
"control_model": {
"name": "control_model",
"label": "Control Model (select depth)"
"label": "Control Model (select depth)",
"value": {
"key": "87e8855c-671f-4c9e-bbbb-8ed47ccb4aac",
"hash": "blake3:2550bf22a53942dfa28ab2fed9d10d80851112531f44d977168992edf9d0534c",
"name": "control_v11f1p_sd15_depth",
"base": "sd-1",
"type": "controlnet"
}
},
"control_weight": {
"name": "control_weight",
@@ -1001,4 +1014,4 @@
"targetHandle": "image_resolution"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_f96e794f-eb3e-4d01-a960-9b4e43402bcf",
"name": "Upscaler - SD1.5, MultiDiffusion",
"name": "MultiDiffusion SD1.5",
"author": "Invoke",
"description": "A workflow to upscale an input image with tiled upscaling, using SD1.5 based models.",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, upscaling",
"tags": "tiled, upscaling, sdxl",
"notes": "",
"exposedFields": [
{
@@ -53,6 +52,7 @@
"version": "3.0.0",
"category": "default"
},
"id": "e5b5fb01-8906-463a-963a-402dbc42f79b",
"nodes": [
{
"id": "33fe76a0-5efd-4482-a7f0-e2abf1223dc2",
@@ -135,7 +135,14 @@
"inputs": {
"model": {
"name": "model",
"label": ""
"label": "",
"value": {
"key": "e7b402e5-62e5-4acb-8c39-bee6bdb758ab",
"hash": "c8659e796168d076368256b57edbc1b48d6dafc1712f1bb37cc57c7c06889a6b",
"name": "526mix",
"base": "sd-1",
"type": "main"
}
}
}
},
@@ -377,11 +384,21 @@
},
"image": {
"name": "image",
"label": "Image to Upscale"
"label": "Image to Upscale",
"value": {
"image_name": "ee7009f7-a35d-488b-a2a6-21237ef5ae05.png"
}
},
"image_to_image_model": {
"name": "image_to_image_model",
"label": ""
"label": "",
"value": {
"key": "38bb1a29-8ede-42ba-b77f-64b3478896eb",
"hash": "blake3:e52fdbee46a484ebe9b3b20ea0aac0a35a453ab6d0d353da00acfd35ce7a91ed",
"name": "4xNomosWebPhoto_esrgan",
"base": "sdxl",
"type": "spandrel_image_to_image"
}
},
"tile_size": {
"name": "tile_size",
@@ -420,7 +437,14 @@
"inputs": {
"model": {
"name": "model",
"label": "ControlNet Model - Choose a Tile ControlNet"
"label": "ControlNet Model - Choose a Tile ControlNet",
"value": {
"key": "20645e4d-ef97-4c5a-9243-b834a3483925",
"hash": "f0812e13758f91baf4e54b7dbb707b70642937d3b2098cd2b94cc36d3eba308e",
"name": "tile",
"base": "sd-1",
"type": "controlnet"
}
}
}
},
@@ -1403,4 +1427,4 @@
"targetHandle": "noise"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_35658541-6d41-4a20-8ec5-4bf2561faed0",
"name": "Upscaler - SDXL, MultiDiffusion",
"name": "MultiDiffusion SDXL",
"author": "Invoke",
"description": "A workflow to upscale an input image with tiled upscaling, using SDXL based models.",
"version": "1.1.0",
"contact": "invoke@invoke.ai",
"tags": "sdxl, upscaling",
"tags": "tiled, upscaling, sdxl",
"notes": "",
"exposedFields": [
{
@@ -57,6 +56,7 @@
"version": "3.0.0",
"category": "default"
},
"id": "dd607062-9e1b-48b9-89ad-9762cdfbb8f4",
"nodes": [
{
"id": "71a116e1-c631-48b3-923d-acea4753b887",
@@ -341,7 +341,14 @@
"inputs": {
"model": {
"name": "model",
"label": "ControlNet Model - Choose a Tile ControlNet"
"label": "ControlNet Model - Choose a Tile ControlNet",
"value": {
"key": "74f4651f-0ace-4b7b-b616-e98360257797",
"hash": "blake3:167a5b84583aaed3e5c8d660b45830e82e1c602743c689d3c27773c6c8b85b4a",
"name": "controlnet-tile-sdxl-1.0",
"base": "sdxl",
"type": "controlnet"
}
}
}
},
@@ -794,7 +801,14 @@
"inputs": {
"vae_model": {
"name": "vae_model",
"label": ""
"label": "",
"value": {
"key": "ff926845-090e-4d46-b81e-30289ee47474",
"hash": "9705ab1c31fa96b308734214fb7571a958621c7a9247eed82b7d277145f8d9fa",
"name": "VAEFix",
"base": "sdxl",
"type": "vae"
}
}
}
},
@@ -818,7 +832,14 @@
"inputs": {
"model": {
"name": "model",
"label": "SDXL Model"
"label": "SDXL Model",
"value": {
"key": "ab191f73-68d2-492c-8aec-b438a8cf0f45",
"hash": "blake3:2d50e940627e3bf555f015280ec0976d5c1fa100f7bc94e95ffbfc770e98b6fe",
"name": "CustomXLv7",
"base": "sdxl",
"type": "main"
}
}
}
},
@@ -1621,4 +1642,4 @@
"targetHandle": "noise"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f",
"name": "Text to Image - SD1.5, Prompt from File",
"name": "Prompt from File",
"author": "InvokeAI",
"description": "Sample workflow using Prompt from File node",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, text to image",
"tags": "text2image, prompt from file, default",
"notes": "",
"exposedFields": [
{
@@ -513,4 +512,4 @@
"targetHandle": "vae"
}
]
}
}

View File

@@ -3,14 +3,12 @@
Workflows placed in this directory will be synced to the `workflow_library` as
_default workflows_ on app startup.
- Default workflows must have an id that starts with "default\_". The ID must be retained when the workflow is updated. You may need to do this manually.
- Default workflows are not editable by users. If they are loaded and saved,
they will save as a copy of the default workflow.
- Default workflows must have the `meta.category` property set to `"default"`.
An exception will be raised during sync if this is not set correctly.
- Default workflows appear on the "Default Workflows" tab of the Workflow
Library.
- Default workflows should not reference any resources that are user-created or installed. That includes images and models. For example, if a default workflow references Juggernaut as an SDXL model, when a user loads the workflow, even if they have a version of Juggernaut installed, it will have a different UUID. They may see a warning. So, it's best to ship default workflows without any references to these types of resources.
After adding or updating default workflows, you **must** start the app up and
load them to ensure:

View File

@@ -1,375 +1,382 @@
{
"id": "default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd",
"name": "Text to Image - SD3.5",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 3.5",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "SD3.5, text to image",
"name": "SD3.5 Text to Image",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 3.5",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD3.5, default",
"notes": "",
"exposedFields": [
{
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"fieldName": "model"
"exposedFields": [
{
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"fieldName": "model"
},
{
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"fieldName": "prompt"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
{
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"fieldName": "prompt"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "invocation",
"data": {
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
"nodes": [
{
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "sd3_model_loader",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"model": {
"name": "model",
"label": ""
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_l_model": {
"name": "clip_l_model",
"label": ""
},
"clip_g_model": {
"name": "clip_g_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
"type": "invocation",
"data": {
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "sd3_model_loader",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"model": {
"name": "model",
"label": "",
"value": {
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
"hash": "placeholder",
"name": "stable-diffusion-3.5-medium",
"base": "sd-3",
"type": "main"
}
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_l_model": {
"name": "clip_l_model",
"label": ""
},
"clip_g_model": {
"name": "clip_g_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": -55.58689609637031,
"y": -111.53602444662268
}
},
"position": {
"x": -55.58689609637031,
"y": -111.53602444662268
}
},
{
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "invocation",
"data": {
{
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"nodePack": "invokeai",
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
"type": "invocation",
"data": {
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"nodePack": "invokeai",
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 470.45870147220353,
"y": 350.3141781644303
}
},
"position": {
"x": 470.45870147220353,
"y": 350.3141781644303
}
},
{
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "invocation",
"data": {
{
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "sd3_l2i",
"version": "1.3.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
"type": "invocation",
"data": {
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "sd3_l2i",
"version": "1.3.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1192.3097009334897,
"y": -366.0994675072209
}
},
"position": {
"x": 1192.3097009334897,
"y": -366.0994675072209
}
},
{
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "invocation",
"data": {
{
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
"type": "invocation",
"data": {
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 408.16054647924784,
"y": 65.06415352118786
}
},
"position": {
"x": 408.16054647924784,
"y": 65.06415352118786
}
},
{
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "invocation",
"data": {
{
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
"type": "invocation",
"data": {
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 378.9283412440941,
"y": -302.65777497352553
}
},
"position": {
"x": 378.9283412440941,
"y": -302.65777497352553
}
},
{
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "invocation",
"data": {
{
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "sd3_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"value": 30
},
"seed": {
"name": "seed",
"label": "",
"value": 0
"type": "invocation",
"data": {
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "sd3_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"value": 30
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 813.7814762740603,
"y": -142.20529727605867
}
},
"position": {
"x": 813.7814762740603,
"y": -142.20529727605867
}
}
],
"edges": [
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
"type": "default",
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
"type": "default",
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
"type": "default",
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
"type": "default",
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
}
]
}
],
"edges": [
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
"type": "default",
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
"type": "default",
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
"type": "default",
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
"type": "default",
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
}
]
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_7dde3e36-d78f-4152-9eea-00ef9c8124ed",
"name": "Text to Image - SD1.5",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "SD1.5, text to image",
"tags": "text2image, SD1.5, SD2, default",
"notes": "",
"exposedFields": [
{
@@ -417,4 +416,4 @@
"targetHandle": "vae"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_5e8b008d-c697-45d0-8883-085a954c6ace",
"name": "Text to Image - SDXL",
"author": "InvokeAI",
"description": "Sample text to image workflow for SDXL",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "SDXL, text to image",
"tags": "text2image, SDXL, default",
"notes": "",
"exposedFields": [
{
@@ -46,7 +45,14 @@
"inputs": {
"vae_model": {
"name": "vae_model",
"label": "VAE (use the FP16 model)"
"label": "VAE (use the FP16 model)",
"value": {
"key": "f20f9e5c-1bce-4c46-a84d-34ebfa7df069",
"hash": "blake3:9705ab1c31fa96b308734214fb7571a958621c7a9247eed82b7d277145f8d9fa",
"name": "sdxl-vae-fp16-fix",
"base": "sdxl",
"type": "vae"
}
}
},
"isOpen": true,
@@ -196,7 +202,14 @@
"inputs": {
"model": {
"name": "model",
"label": ""
"label": "",
"value": {
"key": "4a63b226-e8ff-4da4-854e-0b9f04b562ba",
"hash": "blake3:d279309ea6e5ee6e8fd52504275865cc280dac71cbf528c5b07c98b888bddaba",
"name": "dreamshaper-xl-v2-turbo",
"base": "sdxl",
"type": "main"
}
}
},
"isOpen": true,
@@ -701,4 +714,4 @@
"targetHandle": "style"
}
]
}
}

View File

@@ -1,11 +1,10 @@
{
"id": "default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1",
"name": "Text to Image - SD1.5, LoRA",
"name": "Text to Image with LoRA",
"author": "InvokeAI",
"description": "Simple text to image workflow with a LoRA",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, text to image, lora",
"tags": "text to image, lora, default",
"notes": "",
"exposedFields": [
{

View File

@@ -1,11 +1,10 @@
{
"id": "default_43b0d7f7-6a12-4dcf-a5a4-50c940cbee29",
"name": "Upscaler - SD1.5, Tiled",
"name": "Tiled Upscaling (Beta)",
"author": "Invoke",
"description": "A workflow to upscale an input image with tiled upscaling. ",
"version": "2.1.0",
"contact": "invoke@invoke.ai",
"tags": "sd1.5, upscaling",
"tags": "tiled, upscaling, sd1.5",
"notes": "",
"exposedFields": [
{
@@ -86,7 +85,14 @@
},
"ip_adapter_model": {
"name": "ip_adapter_model",
"label": "IP-Adapter Model (select ip_adapter_sd15)"
"label": "IP-Adapter Model (select ip_adapter_sd15)",
"value": {
"key": "1cc210bb-4d0a-4312-b36c-b5d46c43768e",
"hash": "blake3:3d669dffa7471b357b4df088b99ffb6bf4d4383d5e0ef1de5ec1c89728a3d5a5",
"name": "ip_adapter_sd15",
"base": "sd-1",
"type": "ip_adapter"
}
},
"clip_vision_model": {
"name": "clip_vision_model",
@@ -194,7 +200,14 @@
},
"control_model": {
"name": "control_model",
"label": "Control Model (select control_v11f1e_sd15_tile)"
"label": "Control Model (select contro_v11f1e_sd15_tile)",
"value": {
"key": "773843c8-db1f-4502-8f65-59782efa7960",
"hash": "blake3:f0812e13758f91baf4e54b7dbb707b70642937d3b2098cd2b94cc36d3eba308e",
"name": "control_v11f1e_sd15_tile",
"base": "sd-1",
"type": "controlnet"
}
},
"control_weight": {
"name": "control_weight",
@@ -1802,4 +1815,4 @@
"targetHandle": "unet"
}
]
}
}

View File

@@ -41,36 +41,10 @@ class WorkflowRecordsStorageBase(ABC):
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
categories: Optional[list[WorkflowCategory]],
category: WorkflowCategory,
page: int,
per_page: Optional[int],
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@abstractmethod
def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@abstractmethod
def counts_by_tag(
self,
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass
@abstractmethod
def update_opened_at(self, workflow_id: str) -> None:
"""Open a workflow."""
pass

View File

@@ -1,6 +1,6 @@
import datetime
from enum import Enum
from typing import Any, Optional, Union
from typing import Any, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
@@ -36,7 +36,9 @@ class WorkflowCategory(str, Enum, metaclass=MetaEnum):
class WorkflowMeta(BaseModel):
version: str = Field(description="The version of the workflow schema.")
category: WorkflowCategory = Field(description="The category of the workflow (user or default).")
category: WorkflowCategory = Field(
default=WorkflowCategory.User, description="The category of the workflow (user or default)."
)
@field_validator("version")
def validate_version(cls, version: str):
@@ -60,13 +62,9 @@ class WorkflowWithoutID(BaseModel):
notes: str = Field(description="The notes of the workflow.")
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
meta: WorkflowMeta = Field(description="The meta of the workflow.")
# TODO(psyche): nodes, edges and form are very loosely typed - they are strictly modeled and checked on the frontend.
# TODO: nodes and edges are very loosely typed
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
# TODO(psyche): We have a crapload of workflows that have no form, bc it was added after we introduced workflows.
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
# it is None.
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
model_config = ConfigDict(extra="ignore")
@@ -98,9 +96,7 @@ class WorkflowRecordDTOBase(BaseModel):
name: str = Field(description="The name of the workflow.")
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
opened_at: Union[datetime.datetime, str] = Field(description="The opened timestamp of the workflow.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):
@@ -118,15 +114,6 @@ WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
description: str = Field(description="The description of the workflow.")
category: WorkflowCategory = Field(description="The description of the workflow.")
tags: str = Field(description="The tags of the workflow.")
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
class WorkflowRecordWithThumbnailDTO(WorkflowRecordDTO):
thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")
class WorkflowRecordListItemWithThumbnailDTO(WorkflowRecordListItemDTO):
thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")

View File

@@ -14,18 +14,18 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowRecordOrderBy,
WorkflowValidator,
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string
SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,28 +33,42 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
try:
# Only user workflows may be created by this method
assert workflow.meta.category is WorkflowCategory.User
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
workflow_id,
@@ -68,15 +82,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
@@ -88,15 +101,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
@@ -107,27 +119,27 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
categories: Optional[list[WorkflowCategory]],
category: WorkflowCategory,
page: int = 0,
per_page: Optional[int] = None,
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
try:
self._lock.acquire()
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
@@ -135,222 +147,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
description,
created_at,
updated_at,
opened_at,
tags
opened_at
FROM workflow_library
WHERE category = ?
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
self._cursor.execute(main_query, main_params)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
self._cursor.execute(count_query, count_params)
total = self._cursor.fetchone()[0]
conditions.append(category_condition)
params.extend(category_params)
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
def counts_by_tag(
self,
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -366,68 +207,27 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
"""
try:
cursor = self._conn.cursor()
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
self._lock.acquire()
workflows: list[Workflow] = []
workflows_dir = Path(__file__).parent / Path("default_workflows")
workflow_paths = workflows_dir.glob("*.json")
for path in workflow_paths:
bytes_ = path.read_bytes()
workflow_from_file = WorkflowValidator.validate_json(bytes_)
assert workflow_from_file.id.startswith("default_"), (
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
)
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
f"Invalid default workflow category: {workflow_from_file.meta.category}"
)
workflows_from_file.append(workflow_from_file)
try:
workflow_from_db = self.get(workflow_from_file.id).workflow
if workflow_from_file != workflow_from_db:
self._invoker.services.logger.debug(
f"Updating library workflow {workflow_from_file.name} ({workflow_from_file.id})"
)
workflows_to_update.append(workflow_from_file)
continue
except WorkflowNotFoundError:
self._invoker.services.logger.debug(
f"Adding missing default workflow {workflow_from_file.name} ({workflow_from_file.id})"
)
workflows_to_add.append(workflow_from_file)
continue
library_workflows_from_db = self.get_many(
order_by=WorkflowRecordOrderBy.Name,
direction=SQLiteDirection.Ascending,
categories=[WorkflowCategory.Default],
).items
workflows_from_file_ids = [w.id for w in workflows_from_file]
for w in library_workflows_from_db:
if w.workflow_id not in workflows_from_file_ids:
self._invoker.services.logger.debug(
f"Deleting obsolete default workflow {w.name} ({w.workflow_id})"
)
# We cannot use the `delete` method here, as it only deletes non-default workflows
cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ?;
""",
(w.workflow_id,),
)
for w in workflows_to_add:
# We cannot use the `create` method here, as it only creates non-default workflows
cursor.execute(
workflow_without_id = WorkflowWithoutIDValidator.validate_json(bytes_)
workflow = Workflow(**workflow_without_id.model_dump(), id=uuid_string())
workflows.append(workflow)
# Only default workflows may be managed by this method
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
self._cursor.execute(
"""--sql
DELETE FROM workflow_library
WHERE category = 'default';
"""
)
for w in workflows:
self._cursor.execute(
"""--sql
INSERT INTO workflow_library (
INSERT OR REPLACE INTO workflow_library (
workflow_id,
workflow
)
@@ -435,19 +235,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.id, w.model_dump_json()),
)
for w in workflows_to_update:
# We cannot use the `update` method here, as it only updates non-default workflows
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ?;
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@@ -1,28 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL import Image
class WorkflowThumbnailServiceBase(ABC):
"""Base class for workflow thumbnail services"""
@abstractmethod
def get_path(self, workflow_id: str, with_hash: bool = True) -> Path:
"""Gets the path to a workflow thumbnail"""
pass
@abstractmethod
def get_url(self, workflow_id: str, with_hash: bool = True) -> str | None:
"""Gets the URL of a workflow thumbnail"""
pass
@abstractmethod
def save(self, workflow_id: str, image: Image.Image) -> None:
"""Saves a workflow thumbnail"""
pass
@abstractmethod
def delete(self, workflow_id: str) -> None:
"""Deletes a workflow thumbnail"""
pass

View File

@@ -1,22 +0,0 @@
class WorkflowThumbnailFileNotFoundException(Exception):
"""Raised when a workflow thumbnail file is not found"""
def __init__(self, message: str = "Workflow thumbnail file not found"):
self.message = message
super().__init__(self.message)
class WorkflowThumbnailFileSaveException(Exception):
"""Raised when a workflow thumbnail file cannot be saved"""
def __init__(self, message: str = "Workflow thumbnail file cannot be saved"):
self.message = message
super().__init__(self.message)
class WorkflowThumbnailFileDeleteException(Exception):
"""Raised when a workflow thumbnail file cannot be deleted"""
def __init__(self, message: str = "Workflow thumbnail file cannot be deleted"):
self.message = message
super().__init__(self.message)

View File

@@ -1,87 +0,0 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowCategory
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import (
WorkflowThumbnailFileDeleteException,
WorkflowThumbnailFileNotFoundException,
WorkflowThumbnailFileSaveException,
)
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class WorkflowThumbnailFileStorageDisk(WorkflowThumbnailServiceBase):
def __init__(self, thumbnails_path: Path):
self._workflow_thumbnail_folder = thumbnails_path
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, workflow_id: str) -> PILImageType:
try:
path = self.get_path(workflow_id)
return Image.open(path)
except FileNotFoundError as e:
raise WorkflowThumbnailFileNotFoundException from e
def save(self, workflow_id: str, image: PILImageType) -> None:
try:
self._validate_storage_folders()
image_path = self._workflow_thumbnail_folder / (workflow_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise WorkflowThumbnailFileSaveException from e
def get_path(self, workflow_id: str, with_hash: bool = True) -> Path:
workflow = self._invoker.services.workflow_records.get(workflow_id).workflow
if workflow.meta.category is WorkflowCategory.Default:
default_thumbnails_dir = Path(__file__).parent / Path("default_workflow_thumbnails")
path = default_thumbnails_dir / (workflow_id + ".png")
else:
path = self._workflow_thumbnail_folder / (workflow_id + ".webp")
return path
def get_url(self, workflow_id: str, with_hash: bool = True) -> str | None:
path = self.get_path(workflow_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_workflow_thumbnail_url(workflow_id)
# The image URL never changes, so we must add random query string to it to prevent caching
if with_hash:
url += f"?{uuid_string()}"
return url
def delete(self, workflow_id: str) -> None:
try:
path = self.get_path(workflow_id)
if not self._validate_path(path):
raise WorkflowThumbnailFileNotFoundException
path.unlink()
except WorkflowThumbnailFileNotFoundException as e:
raise WorkflowThumbnailFileNotFoundException from e
except Exception as e:
raise WorkflowThumbnailFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._workflow_thumbnail_folder.mkdir(parents=True, exist_ok=True)

View File

@@ -1,64 +0,0 @@
import logging
import mimetypes
import socket
import torch
def find_open_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_open_port(port=port + 1)
else:
return port
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def enable_dev_reload() -> None:
"""Enable hot reloading on python file changes during development."""
from invokeai.backend.util.logging import InvokeAILogger
try:
import jurigged
except ImportError as e:
raise RuntimeError(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
def apply_monkeypatches() -> None:
"""Apply monkeypatches to fix issues with third-party libraries."""
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
def register_mime_types() -> None:
"""Register additional mime types for windows."""
# Fix for windows mimetypes registry entries being borked.
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

View File

@@ -1,52 +0,0 @@
import logging
import os
import sys
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger):
"""Configure the PyTorch CUDA memory allocator. See
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
configurations.
"""
if "torch" in sys.modules:
raise RuntimeError("configure_torch_cuda_allocator() must be called before importing torch.")
# Log a warning if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
if prev_cuda_alloc_conf is not None:
if prev_cuda_alloc_conf == pytorch_cuda_alloc_conf:
logger.info(
f"PYTORCH_CUDA_ALLOC_CONF is already set to '{pytorch_cuda_alloc_conf}'. Skipping configuration."
)
return
else:
logger.warning(
f"Attempted to configure the PyTorch CUDA memory allocator with '{pytorch_cuda_alloc_conf}', but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'. Skipping configuration."
)
return
# Configure the PyTorch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf
import torch
# Relevant docs: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
if not torch.cuda.is_available():
raise RuntimeError(
"Attempted to configure the PyTorch CUDA memory allocator, but no CUDA devices are available."
)
# Verify that the torch allocator was properly configured.
allocator_backend = torch.cuda.get_allocator_backend()
expected_backend = "cudaMallocAsync" if "cudaMallocAsync" in pytorch_cuda_alloc_conf else "native"
if allocator_backend != expected_backend:
raise RuntimeError(
f"Failed to configure the PyTorch CUDA memory allocator. Expected backend: '{expected_backend}', but got "
f"'{allocator_backend}'. Verify that 1) the pytorch_cuda_alloc_conf is set correctly, and 2) that torch is "
"not imported before calling configure_torch_cuda_allocator()."
)
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")

View File

@@ -3,11 +3,7 @@ from typing import Optional
import torch
import torchvision
from invokeai.backend.flux.text_conditioning import (
FluxReduxConditioning,
FluxRegionalTextConditioning,
FluxTextConditioning,
)
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
@@ -36,19 +32,14 @@ class RegionalPromptingExtension:
return order[block_index % len(order)]
@classmethod
def from_text_conditioning(
cls,
text_conditioning: list[FluxTextConditioning],
redux_conditioning: list[FluxReduxConditioning],
img_seq_len: int,
):
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
"""Create a RegionalPromptingExtension from a list of text conditionings.
Args:
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning, redux_conditioning)
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
@@ -211,7 +202,6 @@ class RegionalPromptingExtension:
def _concat_regional_text_conditioning(
cls,
text_conditionings: list[FluxTextConditioning],
redux_conditionings: list[FluxReduxConditioning],
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
@@ -227,26 +217,17 @@ class RegionalPromptingExtension:
global_clip_embedding = text_conditioning.clip_embeddings
break
# Handle T5 text embeddings.
cur_t5_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)
image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
# Handle Redux embeddings.
for redux_conditioning in redux_conditionings:
concat_t5_embeddings.append(redux_conditioning.redux_embeddings)
concat_t5_embedding_ranges.append(
Range(
start=cur_t5_embedding_len, end=cur_t5_embedding_len + redux_conditioning.redux_embeddings.shape[1]
)
)
image_masks.append(redux_conditioning.mask)
cur_t5_embedding_len += redux_conditioning.redux_embeddings.shape[1]
image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)

View File

@@ -1,17 +0,0 @@
import torch
# This model definition is based on:
# https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/modules/image_embedders.py#L66
class FluxReduxModel(torch.nn.Module):
def __init__(self, redux_dim: int = 1152, txt_in_features: int = 4096) -> None:
super().__init__()
self.redux_dim = redux_dim
self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3)
self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.redux_down(torch.nn.functional.silu(self.redux_up(x)))

Some files were not shown because too many files have changed in this diff Show More