Compare commits

..

15 Commits

Author SHA1 Message Date
Ryan Dick
3ed6e65a6e Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack. 2024-12-12 22:41:50 +00:00
Ryan Dick
52c9646f84 (minor) Rename num_layers -> num_loras in unit tests. 2024-12-12 22:41:50 +00:00
Ryan Dick
7662f0522b Add test_apply_smart_lora_patches_to_partially_loaded_model(...). 2024-12-12 22:41:50 +00:00
Ryan Dick
e50fe69839 Add LoRAPatcher.smart_apply_lora_patches() 2024-12-12 22:41:50 +00:00
Ryan Dick
5a9f884620 Refactor LoRAPatcher slightly in preparation for a 'smart' patcher. 2024-12-12 22:41:46 +00:00
Ryan Dick
edc72d1739 Fix LoRAPatcher.apply_lora_wrapper_patches(...) 2024-12-12 22:33:07 +00:00
Ryan Dick
23f521dc7c Finish consolidating LoRA sidecar wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3d6b93efdd Begin to consolidate the LoRA sidecar and LoRA layer wrapper implementations. 2024-12-12 22:33:07 +00:00
Ryan Dick
3f28d3afad Fix bias handling in LoRAModuleWrapper and add unit test that checks that all LoRA patching methods produce the same outputs. 2024-12-12 22:33:07 +00:00
Ryan Dick
9353bfbdd6 Add LoRA wrapper patching to LoRAPatcher. 2024-12-12 22:33:07 +00:00
Ryan Dick
93f2bc6118 Add LoRA wrapper layer. 2024-12-12 22:33:07 +00:00
Ryan Dick
9019026d6d Fixes to get FLUX Control LoRA working. 2024-12-12 00:19:39 +00:00
Brandon Rising
c195b326ec Lots of updates centered around using the lora patcher rather than changing the modules in the transformer model 2024-12-11 14:14:50 -05:00
Brandon Rising
2f460d2a45 Support bnb quantized nf4 flux models, Use controlnet vae, only support 1 structural lora per transformer. various other refractors and bugfixes 2024-12-10 03:26:29 -05:00
Brandon Rising
4473cba512 Initial setup for flux tools control loras 2024-12-09 16:01:29 -05:00
675 changed files with 12272 additions and 33736 deletions

View File

@@ -76,6 +76,9 @@ jobs:
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
suffix=-${{ matrix.gpu-driver }},onlatest=false
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
@@ -100,7 +103,7 @@ jobs:
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
# cache-from: |
# type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
# type=gha,scope=main-${{ matrix.gpu-driver }}
# cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
cache-from: |
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
type=gha,scope=main-${{ matrix.gpu-driver }}
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}

View File

@@ -1,85 +0,0 @@
# Runs typegen schema quality checks.
# Frontend types should match the server.
#
# Checks for changes to files before running the checks.
# If always_run is true, always runs the checks.
name: 'typegen checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
jobs:
typegen-checks:
runs-on: ubuntu-22.04
timeout-minutes: 15 # expected run time: <5 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed files
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
src:
- 'pyproject.toml'
- 'invokeai/**'
- name: setup python
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install python dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: pip3 install --use-pep517 --editable="."
- name: install frontend dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: ./.github/actions/install-frontend-deps
- name: copy schema
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: cp invokeai/frontend/web/src/services/api/schema.ts invokeai/frontend/web/src/services/api/schema_orig.ts
shell: bash
- name: generate schema
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: make frontend-typegen
shell: bash
- name: compare files
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: |
if ! diff invokeai/frontend/web/src/services/api/schema.ts invokeai/frontend/web/src/services/api/schema_orig.ts; then
echo "Files are different!";
exit 1;
fi
shell: bash

1
.nvmrc
View File

@@ -1 +0,0 @@
v22.12.0

View File

@@ -30,12 +30,51 @@ Invoke is available in two editions:
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
# Installation
</div>
To get started with Invoke, [Download the Installer](https://www.invoke.com/downloads).
## Quick Start
For detailed step by step instructions, or for instructions on manual/docker installations, visit our documentation on [Installation and Updates][installation docs]
1. Download and unzip the installer from the bottom of the [latest release][latest release link].
2. Run the installer script.
- **Windows**: Double-click on the `install.bat` script.
- **macOS**: Open a Terminal window, drag the file `install.sh` from Finder into the Terminal, and press enter.
- **Linux**: Run `install.sh`.
3. When prompted, enter a location for the install and select your GPU type.
4. Once the install finishes, find the directory you selected during install. The default location is `C:\Users\Username\invokeai` for Windows or `~/invokeai` for Linux/macOS.
5. Run the launcher script (`invoke.bat` for Windows, `invoke.sh` for macOS and Linux) the same way you ran the installer script in step 2.
6. Select option 1 to start the application. Once it starts up, open your browser and go to <http://localhost:9090>.
7. Open the model manager tab to install a starter model and then you'll be ready to generate.
More detail, including hardware requirements and manual install instructions, are available in the [installation documentation][installation docs].
## Docker Container
We publish official container images in Github Container Registry: https://github.com/invoke-ai/InvokeAI/pkgs/container/invokeai. Both CUDA and ROCm images are available. Check the above link for relevant tags.
> [!IMPORTANT]
> Ensure that Docker is set up to use the GPU. Refer to [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] documentation.
### Generate!
Run the container, modifying the command as necessary:
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
Then open `http://localhost:9090` and install some models using the Model Manager tab to begin generating.
For ROCm, add `--device /dev/kfd --device /dev/dri` to the `docker run` command.
### Persist your data
You will likely want to persist your workspace outside of the container. Use the `--volume /home/myuser/invokeai:/invokeai` flag to mount some local directory (using its **absolute** path) to the `/invokeai` path inside the container. Your generated images and models will reside there. You can use this directory with other InvokeAI installations, or switch between runtime directories as needed.
### DIY
Build your own image and customize the environment to match your needs using our `docker-compose` stack. See [README.md](./docker/README.md) in the [docker](./docker) directory.
## Troubleshooting, FAQ and Support

View File

@@ -13,67 +13,52 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
git
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
ARG BUILDPLATFORM
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
mkdir -p ${INVOKEAI_SRC} && \
chmod -R a+w /opt && \
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
chmod -R a+w /opt
USER ubuntu
# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
uv python install ${PYTHON_VERSION}
# Install python and create the venv
RUN uv python install ${PYTHON_VERSION} && \
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
WORKDIR ${INVOKEAI_SRC}
COPY invokeai ./invokeai
COPY pyproject.toml ./
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# Editable mode helps use the same image for development:
# the local working copy can be bind-mounted into the image
# at path defined by ${INVOKEAI_SRC}
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
fi && \
uv sync --no-install-project
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
#### Build the Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
@@ -113,7 +98,6 @@ RUN apt update && apt install -y --no-install-recommends \
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
@@ -125,7 +109,7 @@ ENV CONTAINER_GID=${CONTAINER_GID:-1000}
# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root

View File

@@ -1,50 +1,41 @@
# Release Process
The Invoke application is published as a python package on [PyPI]. This includes both a source distribution and built distribution (a wheel).
The app is published in twice, in different build formats.
Most users install it with the [Launcher](https://github.com/invoke-ai/launcher/), others with `pip`.
The launcher uses GitHub as the source of truth for available releases.
## Broad Strokes
- Merge all changes and bump the version in the codebase.
- Tag the release commit.
- Wait for the release workflow to complete.
- Approve the PyPI publish jobs.
- Write GH release notes.
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
Make a developer call-out for PRs to merge. Merge and test things out.
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
## Release Workflow
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
It is triggered on **tag push**, when the tag matches `v*`.
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
### Triggering the Workflow
Ensure all commits that should be in the release are merged, and you have pulled them locally.
Run `make tag-release` to tag the current commit and kick off the workflow.
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run checks and tests, then two final publish jobs.
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
This job ensures that the `invokeai` python package version specifier matches the tag for the release. The version specifier is pulled from the `__version__` variable in `invokeai/version/invokeai_version.py`.
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
This job uses [samuelcolvin/check-python-version].
@@ -52,52 +43,62 @@ This job uses [samuelcolvin/check-python-version].
#### Check and Test Jobs
Next, these jobs run and must pass. They are the same jobs that are run for every PR.
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
- **`typegen-checks`**: ensures the frontend and backend types are synced
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the legacy install scripts
You don't need to download either of these files.
> The legacy install scripts are no longer used, but we haven't updated the workflow to skip building them.
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
#### Sanity Check & Smoke Test
At this point, the release workflow pauses as the remaining publish jobs require approval.
At this point, the release workflow pauses as the remaining publish jobs require approval. Time to test the installer.
It's possible to test the python package before it gets published to PyPI. We've never had problems with it, so it's not necessary to do this.
Because the installer pulls from PyPI, and we haven't published to PyPI yet, you will need to install from the wheel:
But, if you want to be extra-super careful, here's how to test it:
- Download and unzip `dist.zip` and the installer from the **Summary** tab of the workflow
- Run the installer script using the `--wheel` CLI arg, pointing at the wheel:
- Download the `dist.zip` build artifact from the `build-installer` job
- Unzip it and find the wheel file
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
- Test the app
```sh
./install.sh --wheel ../InvokeAI-4.0.0rc6-py3-none-any.whl
```
- Install to a temporary directory so you get the new user experience
- Download a model and generate
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation as if the installer got the wheel from PyPI.
##### Something isn't right
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?) and start over.
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?).
Now you can start from the top:
- Fix the issues and PR the fixes per usual
- Get the PR approved and merged per usual
- Switch to `main` and pull in the fixes
- Run `make tag-release` to move the tag to `HEAD` (which has the fixes) and kick off the release workflow again
- Re-do the sanity check
#### PyPI Publish Jobs
The publish jobs will not run if any of the previous jobs fail.
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi` - typically you select both)
- Select the environment (either `testpypi` or `pypi`)
- Click **Approve and deploy**
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
@@ -112,33 +113,46 @@ If there are no incidents, contact @hipsterusername or @lstein, who have owner a
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release for some reason:
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
- Approve this publish job without approving the prod publish
- Let it finish
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/), making sure to use the Test PyPI index URL: `https://test.pypi.org/simple/`
- Test the app
If approved and successful, you could try out the test release like this:
```sh
# Create a new virtual environment
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
# Install the distribution from Test PyPI
pip install --index-url https://test.pypi.org/simple/ invokeai
# Run and test the app
invokeai-web
# Cleanup
deactivate
rm -rf ~/.test-invokeai-dist
```
#### `publish-pypi` Job
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
It's a good idea to wait to approve and run this job until you have the release notes ready!
## Publish the GitHub Release with installer
## Prep and publish the GitHub Release
Once the release is published to PyPI, it's time to publish the GitHub release.
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
2. The **Generate release notes** button automatically inserts the changelog and new contributors. Make sure to select the correct tags for this release and the last stable release. GH often selects the wrong tags - do this manually.
3. Write the release notes, describing important changes. Contributions from community members should be shouted out. Use the GH-generated changelog to see all contributors. If there are Weblate translation updates, open that PR and shout out every person who contributed a translation.
4. Check **Set as a pre-release** if it's a pre-release.
5. Approve and wait for the `publish-pypi` job to finish if you haven't already.
6. Publish the GH release.
7. Post the release in Discord in the [releases](https://discord.com/channels/1020123559063990373/1149260708098359327) channel with abbreviated notes. For example:
> Invoke v5.7.0 (stable): <https://github.com/invoke-ai/InvokeAI/releases/tag/v5.7.0>
>
> It's a pretty big one - Form Builder, Metadata Nodes (thanks @SkunkWorxDark!), and much more.
8. Right click the message in releases and copy the link to it. Then, post that link in the [new-release-discussion](https://discord.com/channels/1020123559063990373/1149506274971631688) channel. For example:
> Invoke v5.7.0 (stable): <https://discord.com/channels/1020123559063990373/1149260708098359327/1344521744916021248>
1. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
1. Use `scripts/get_external_contributions.py` to get a list of external contributions to shout out in the release notes.
1. Upload the zip file created in **`build`** job into the Assets section of the release notes.
1. Check **Set as a pre-release** if it's a pre-release.
1. Check **Create a discussion for this release**.
1. Publish the release.
1. Announce the release in Discord.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manual Build
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
No checks are run, it just builds.
## Manual Release
@@ -146,10 +160,12 @@ The `release` workflow can be dispatched manually. You must dispatch the workflo
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
[Test PyPI]: https://test.pypi.org/
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
[ncipollo/release-action]: https://github.com/ncipollo/release-action
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version

View File

@@ -39,7 +39,7 @@ It has two sections - one for internal use and one for user settings:
```yaml
# Internal metadata - do not edit:
schema_version: 4.0.2
schema_version: 4
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
host: 0.0.0.0 # serve the app on your local network
@@ -83,10 +83,6 @@ A subset of settings may be specified using CLI args:
- `--root`: specify the root directory
- `--config`: override the default `invokeai.yaml` file location
### Low-VRAM Mode
See the [Low-VRAM mode docs][low-vram] for details on enabling this feature.
### All Settings
Following the table are additional explanations for certain settings.
@@ -118,10 +114,6 @@ remote_api_tokens:
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.
!!! tip "HuggingFace Models"
If you get an error when installing a HF model using a URL instead of repo id, you may need to [set up a HF API token](https://huggingface.co/settings/tokens) and add an entry for it under `remote_api_tokens`. Use `huggingface.co` for `url_regex`.
#### Model Hashing
Models are hashed during installation, providing a stable identifier for models across all platforms. Hashing is a one-time operation.
@@ -189,4 +181,3 @@ The `log_format` option provides several alternative formats:
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
[Model Marketplace API Keys]: #model-marketplace-api-keys
[low-vram]: ./features/low-vram.md

View File

@@ -1364,6 +1364,7 @@ the in-memory loaded model:
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
### get_model_by_key(key, [submodel]) -> LoadedModel

View File

@@ -1,10 +1,12 @@
# Dev Environment
To make changes to Invoke's backend, frontend or documentation, you'll need to set up a dev environment.
To make changes to Invoke's backend, frontend, or documentation, you'll need to set up a dev environment.
If you only want to make changes to the docs site, you can skip the frontend dev environment setup as described in the below guide.
If you just want to use Invoke, you should use the [installer][installer link].
If you just want to use Invoke, you should use the [launcher][launcher link].
!!! info "Why do I need the frontend toolchain?"
The repo doesn't contain a build of the frontend. You'll be responsible for rebuilding it every time you pull in new changes, or run it in dev mode (which incurs a substantial performance penalty).
!!! warning
@@ -15,66 +17,84 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
## Setup
1. Run through the [requirements][requirements link].
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
- Add `-e` after the `install` operation to make this an [editable install][editable install link]. That means your changes to the python code will be reflected when you restart the Invoke server.
- When installing the `invokeai` package, add the `dev`, `test` and `docs` package options to the package specifier. You may or may not need the `xformers` option - follow the manual install guide to figure that out. So, your package specifier will be either `".[dev,test,docs]"` or `".[dev,test,docs,xformers]"`. Note the quotes!
With the modifications made, the install command should look something like this:
4. Create a python virtual environment inside the directory you just created:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
python3 -m venv .venv --prompt InvokeAI-Dev
```
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
6. Install the frontend dev toolchain:
- [`nodejs`](https://nodejs.org/) (v20+)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
7. Do a production build of the frontend:
5. Activate the venv (you'll need to do this every time you want to run the app):
```sh
cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
source .venv/bin/activate
```
6. Install the repo as an [editable install][editable install link]:
```sh
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
Refer to the [manual installation][manual install link] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
7. Install the frontend dev toolchain:
- [`nodejs`](https://nodejs.org/) (recommend v20 LTS)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
8. Do a production build of the frontend:
```sh
cd PATH_TO_INVOKEAI_REPO/invokeai/frontend/web
pnpm i
pnpm build
```
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
9. Start the application:
```sh
cd PATH_TO_INVOKEAI_REPO
python scripts/invokeai-web.py
```
10. Access the UI at `localhost:9090`.
## Updating the UI
You'll need to run `pnpm build` every time you pull in new changes.
Another option is to skip the build and instead run the UI in dev mode:
You'll need to run `pnpm build` every time you pull in new changes. Another option is to skip the build and instead run the app in dev mode:
```sh
pnpm dev
```
This starts a vite dev server for the UI at `127.0.0.1:5173`, which you will use instead of `127.0.0.1:9090`.
This starts a dev server at `localhost:5173`, which you will use instead of `localhost:9090`.
The dev mode is substantially slower than the production build but may be more convenient if you just need to test things out. It will hot-reload the UI as you make changes to the frontend code. Sometimes the hot-reload doesn't work, and you need to manually refresh the browser tab.
The dev mode is substantially slower than the production build but may be more convenient if you just need to test things out.
## Documentation
The documentation is built with `mkdocs`. It provides a hot-reload dev server for the docs. Start it with `mkdocs serve`.
The documentation is built with `mkdocs`. To preview it locally, you need a additional set of packages installed.
[launcher link]: ../installation/quick_start.md
```sh
# after activating the venv
pip install -e ".[docs]"
```
Then, you can start a live docs dev server, which will auto-refresh when you edit the docs:
```sh
mkdocs serve
```
On macOS and Linux, there is a `make` target for this:
```sh
make docs
```
[installer link]: ../installation/installer.md
[forking link]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo
[requirements link]: ../installation/requirements.md
[repo link]: https://github.com/invoke-ai/InvokeAI

View File

@@ -1,18 +1,26 @@
# FAQ
!!! info "How to Reinstall"
Many issues can be resolved by re-installing the application. You won't lose any data by re-installing. We suggest downloading the [latest release](https://github.com/invoke-ai/InvokeAI/releases/latest) and using it to re-install the application. Consult the [installer guide](./installation/installer.md) for more information.
When you run the installer, you'll have an option to select the version to install. If you aren't ready to upgrade, you choose the current version to fix a broken install.
If the troubleshooting steps on this page don't get you up and running, please either [create an issue] or hop on [discord] for help.
## How to Install
Follow the [Quick Start guide](./installation/quick_start.md) to install Invoke.
You can download the latest installers [here](https://github.com/invoke-ai/InvokeAI/releases).
Note that any releases marked as _pre-release_ are in a beta state. You may experience some issues, but we appreciate your help testing those! For stable/reliable installations, please install the [latest release].
## Downloading models and using existing models
The Model Manager tab in the UI provides a few ways to install models, including using your already-downloaded models. You'll see a popup directing you there on first startup. For more information, see the [model install docs].
## Missing models after updating from v3
## Missing models after updating to v4
If you find some models are missing after updating from v3, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
If you find some models are missing after updating to v4, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
You can use the `Scan Folder` tab in the Model Manager UI to fix this. The models will either be in the old, now-unused `autoimport` folder, or your `models` folder.
@@ -29,27 +37,115 @@ Follow the same steps to scan and import the missing models.
## Slow generation
- Check the [system requirements] to ensure that your system is capable of generating images.
- Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
- Check that your generations are happening on your GPU (if you have one). Invoke will log what is being used for generation upon startup. If your GPU isn't used, re-install to and ensure you select the appropriate GPU option.
- If you are on Windows with an Nvidia GPU, you may have exceeded your GPU's VRAM capacity and are triggering Nvidia's "sysmem fallback". There's a guide to opt out of this behaviour in the [Low-VRAM mode guide](./features/low-vram.md).
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 is generally works well for GPUs without 16GB or more VRAM. Even on a 24GB card, the default works well.
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behaviour in the linked FAQ entry.
## Shared GPU Memory (Windows)
!!! tip "Nvidia GPUs with driver 536.40"
This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
When the GPU doesn't have enough VRAM for a task, Windows is able to allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might no have enough VRAM.
When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
Here's how to get the python path required in the linked guide:
- Run `invoke.bat`.
- Select option 2 for developer console.
- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
## Installer cannot find python (Windows)
Ensure that you checked **Add python.exe to PATH** when installing Python. This can be found at the bottom of the Python Installer window. If you already have Python installed, you can re-run the python installer, choose the Modify option and check the box.
## Triton error on startup
This can be safely ignored. Invoke doesn't use Triton, but if you are on Linux and wish to dismiss the error, you can install Triton.
This can be safely ignored. InvokeAI doesn't use Triton, but if you are on Linux and wish to dismiss the error, you can install Triton.
## Unable to Copy on Firefox
## Updated to 3.4.0 and xformers cant load C++/CUDA
Firefox does not allow Invoke to directly access the clipboard by default. As a result, you may be unable to use certain copy functions. You can fix this by configuring Firefox to allow access to write to the clipboard:
An issue occurred with your PyTorch update. Follow these steps to fix :
- Go to `about:config` and click the Accept button
- Search for `dom.events.asyncClipboard.clipboardItem`
- Set it to `true` by clicking the toggle button
- Restart Firefox
1. Launch your invoke.bat / invoke.sh and select the option to open the developer console
2. Run:`pip install ".[xformers]" --upgrade --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121`
- If you run into an error with `typing_extensions`, re-open the developer console and run: `pip install -U typing-extensions`
Note that v3.4.0 is an old, unsupported version. Please upgrade to the [latest release].
## Install failed and says `pip` is out of date
An out of date `pip` typically won't cause an installation to fail. The cause of the error can likely be found above the message that says `pip` is out of date.
If you saw that warning but the install went well, don't worry about it (but you can update `pip` afterwards if you'd like).
## Replicate image found online
Most example images with prompts that you'll find on the internet have been generated using different software, so you can't expect to get identical results. In order to reproduce an image, you need to replicate the exact settings and processing steps, including (but not limited to) the model, the positive and negative prompts, the seed, the sampler, the exact image size, any upscaling steps, etc.
## OSErrors on Windows while installing dependencies
During a zip file installation or an update, installation stops with an error like this:
![broken-dependency-screenshot](./assets/troubleshooting/broken-dependency.png){:width="800px"}
To resolve this, re-install the application as described above.
## HuggingFace install failed due to invalid access token
Some HuggingFace models require you to authenticate using an [access token].
Invoke doesn't manage this token for you, but it's easy to set it up:
- Follow the instructions in the link above to create an access token. Copy it.
- Run the launcher script.
- Select option 2 (developer console).
- Paste the following command:
```sh
python -c "import huggingface_hub; huggingface_hub.login()"
```
- Paste your access token when prompted and press Enter. You won't see anything when you paste it.
- Type `n` if prompted about git credentials.
If you get an error, try the command again - maybe the token didn't paste correctly.
Once your token is set, start Invoke and try downloading the model again. The installer will automatically use the access token.
If the install still fails, you may not have access to the model.
## Stable Diffusion XL generation fails after trying to load UNet
InvokeAI is working in other respects, but when trying to generate
images with Stable Diffusion XL you get a "Server Error". The text log
in the launch window contains this log line above several more lines of
error messages:
`INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet`
This failure mode occurs when there is a network glitch during
downloading the very large SDXL model.
To address this, first go to the Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then, click the HuggingFace tab,
paste the Repo ID stabilityai/stable-diffusion-xl-base-1.0 and install
the model.
## Package dependency conflicts during installation or update
If you have previously installed InvokeAI or another Stable Diffusion
package, the installer may occasionally pick up outdated libraries and
either the installer or `invoke` will fail with complaints about
library conflicts.
To resolve this, re-install the application as described above.
## Invalid configuration file
Everything seems to install ok, you get a `ValidationError` when starting up the app.
@@ -58,9 +154,64 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## Out of Memory Errors
## `ModuleNotFoundError: No module named 'controlnet_aux'`
The models are large, VRAM is expensive, and you may find yourself faced with Out of Memory errors when generating images. Follow our [Low-VRAM mode guide](./features/low-vram.md) to configure Invoke to prevent these.
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer][latest release], selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself
faced with Out of Memory errors when generating images. Here are some
tips to reduce the problem:
!!! info "Optimizing for GPU VRAM"
=== "4GB VRAM GPU"
This should be adequate for 512x512 pixel images using Stable Diffusion 1.5
and derived models, provided that you do not use the NSFW checker. It won't be loaded unless you go into the UI settings and turn it on.
If you are on a CUDA-enabled GPU, we will automatically use xformers or torch-sdp to reduce VRAM requirements, though you can explicitly configure this. See the [configuration docs].
=== "6GB VRAM GPU"
This is a border case. Using the SD 1.5 series you should be able to
generate images up to 640x640 with the NSFW checker enabled, and up to
1024x1024 with it disabled.
If you run into persistent memory issues there are a series of
environment variables that you can set before launching InvokeAI that
alter how the PyTorch machine learning library manages memory. See
<https://pytorch.org/docs/stable/notes/cuda.html#memory-management> for
a list of these tweaks.
=== "12GB VRAM GPU"
This should be sufficient to generate larger images up to about 1280x1280.
## Checkpoint Models Load Slowly or Use Too Much RAM
The difference between diffusers models (a folder containing multiple
subfolders) and checkpoint models (a file ending with .safetensors or
.ckpt) is that InvokeAI is able to load diffusers models into memory
incrementally, while checkpoint models must be loaded all at
once. With very large models, or systems with limited RAM, you may
experience slowdowns and other memory-related issues when loading
checkpoint models.
To solve this, go to the Model Manager tab (the cube), select the
checkpoint model that's giving you trouble, and press the "Convert"
button in the upper right of your browser window. This will convert the
checkpoint into a diffusers model, after which loading should be
faster and less memory-intensive.
## Memory Leak (Linux)
@@ -102,6 +253,8 @@ Note the differences between memory allocated as chunks in an arena vs. memory a
[model install docs]: ./installation/models.md
[system requirements]: ./installation/requirements.md
[latest release]: https://github.com/invoke-ai/InvokeAI/releases/latest
[create an issue]: https://github.com/invoke-ai/InvokeAI/issues
[discord]: https://discord.gg/ZmtBAhwWhy
[configuration docs]: ./configuration.md
[access token]: https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens

Binary file not shown.

Before

Width:  |  Height:  |  Size: 72 KiB

View File

@@ -1,176 +0,0 @@
---
title: Low-VRAM mode
---
As of v5.6.0, Invoke has a low-VRAM mode. It works on systems with dedicated GPUs (Nvidia GPUs on Windows/Linux and AMD GPUs on Linux).
This allows you to generate even if your GPU doesn't have enough VRAM to hold full models. Most users should be able to run even the beefiest models - like the ~24GB unquantised FLUX dev model.
## Enabling Low-VRAM mode
To enable Low-VRAM mode, add this line to your `invokeai.yaml` configuration file, then restart Invoke:
```yaml
enable_partial_loading: true
```
**Windows users should also [disable the Nvidia sysmem fallback](#disabling-nvidia-sysmem-fallback-windows-only)**.
It is possible to fine-tune the settings for best performance or if you still get out-of-memory errors (OOMs).
!!! tip "How to find `invokeai.yaml`"
The `invokeai.yaml` configuration file lives in your install directory. To access it, run the **Invoke Community Edition** launcher and click the install location. This will open your install directory in a file explorer window.
You'll see `invokeai.yaml` there and can edit it with any text editor. After making changes, restart Invoke.
If you don't see `invokeai.yaml`, launch Invoke once. It will create the file on its first startup.
## Details and fine-tuning
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
- Partial model loading (`enable_partial_loading`)
- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
- Working memory (`device_working_mem_gb`)
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
### Partial model loading
Invoke's partial model loading works by streaming model "layers" between RAM and VRAM as they are needed.
When an operation needs layers that are not in VRAM, but there isn't enough room to load them, inactive layers are offloaded to RAM to make room.
#### Enabling partial model loading
As described above, you can enable partial model loading by adding this line to `invokeai.yaml`:
```yaml
enable_partial_loading: true
```
### PyTorch CUDA allocator config
The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
```yaml
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
```
A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
### Dynamic RAM and VRAM cache sizes
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
By default, Invoke manages these caches' sizes dynamically for best performance.
#### Fine-tuning cache sizes
Prior to v5.6.0, the cache sizes were static, and for best performance, many users needed to manually fine-tune the `ram` and `vram` settings in `invokeai.yaml`.
As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings are no longer used, and new settings are added to configure the cache.
**Most users will not need to fine-tune the cache sizes.**
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
```yaml
# The default max cache RAM size is logged on InvokeAI startup. It is determined based on your system RAM / VRAM.
# You can override the default value by setting `max_cache_ram_gb`.
# Increasing `max_cache_ram_gb` will increase the amount of RAM used to cache inactive models, resulting in faster model
# reloads for the cached models.
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
# might be a good value to achieve aggressive model caching.
max_cache_ram_gb: 28
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
# consideration the VRAM used by other processes).
# You can override the default value by setting `max_cache_vram_gb`.
# CAUTION: Most users should not manually set this value. See warning below.
max_cache_vram_gb: 16
```
!!! warning "Max safe value for `max_cache_vram_gb`"
Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
### Working memory
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
Invoke reserves 3GB VRAM as working memory by default, which is enough for most use-cases. However, it is possible to fine-tune this setting if you still get OOMs.
#### Fine-tuning working memory
You can increase the working memory size in `invokeai.yaml` to prevent OOMs:
```yaml
# The default is 3GB - bump it up to 4GB to prevent OOMs.
device_working_mem_gb: 4
```
!!! tip "Operations may request more working memory"
For some operations, we can determine VRAM requirements in advance and allocate additional working memory to prevent OOMs.
VAE decoding is one such operation. This operation converts the generation process's output into an image. For large image outputs, this might use more than the default working memory size of 3GB.
During this decoding step, Invoke calculates how much VRAM will be required to decode and requests that much VRAM from the model manager. If the amount exceeds the working memory size, the model manager will offload cached model layers from VRAM until there's enough VRAM to decode.
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
### Keeping a RAM weight copy
Invoke has the option of keeping a RAM copy of all model weights, even when they are loaded onto the GPU. This optimization is _on_ by default, and enables faster model switching and LoRA patching. Disabling this feature will reduce the average RAM load while running Invoke (peak RAM likely won't change), at the cost of slower model switching and LoRA patching. If you have limited RAM, you can disable this optimization:
```yaml
# Set to false to reduce the average RAM usage at the cost of slower model switching and LoRA patching.
keep_ram_copy_of_weights: false
```
### Disabling Nvidia sysmem fallback (Windows only)
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
It is strongly suggested to disable this feature:
- Open the **NVIDIA Control Panel** app.
- Expand **3D Settings** on the left panel.
- Click **Manage 3D Settings** in the left panel.
- Find **CUDA - Sysmem Fallback Policy** in the right panel and set it to **Prefer No Sysmem Fallback**.
![cuda-sysmem-fallback](./cuda-sysmem-fallback.png)
!!! tip "Invoke does the same thing, but better"
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
## Troubleshooting
### Windows page file
Invoke has high virtual memory (a.k.a. 'committed memory') requirements. This can cause issues on Windows if the page file size limits are hit. (See this issue for the technical details on why this happens: https://github.com/invoke-ai/InvokeAI/issues/7563).
If you run out of page file space, InvokeAI may crash. Often, these crashes will happen with one of the following errors:
- InvokeAI exits with Windows error code `3221225477`
- InvokeAI crashes without an error, but `eventvwr.msc` reveals an error with code `0xc0000005` (the hex equivalent of `3221225477`)
If you are running out of page file space, try the following solutions:
- Make sure that you have sufficient disk space for the page file to grow. Watch your disk usage as Invoke runs. If it climbs near 100% leading up to the crash, then this is very likely the source of the issue. Clear out some disk space to resolve the issue.
- Make sure that your page file is set to "System managed size" (this is the default) rather than a custom size. Under the "System managed size" policy, the page file will grow dynamically as needed.

View File

@@ -50,9 +50,11 @@ title: Invoke
## Installation
The [Invoke Launcher](installation/quick_start.md) is the easiest way to install, update and run Invoke on Windows, macOS and Linux.
The [installer script](installation/installer.md) is the easiest way to install and update the application.
You can also install Invoke as [python package](installation/manual.md) or with [docker](installation/docker.md).
You can also install Invoke as python package [via PyPI](installation/manual.md) or [docker](installation/docker.md).
See the [installation section](./installation/index.md) for more information.
## Help

View File

@@ -4,7 +4,7 @@ title: Docker
!!! warning "macOS users"
Docker can not access the GPU on macOS, so your generation speeds will be slow. Use the [launcher](./quick_start.md) instead.
Docker can not access the GPU on macOS, so your generation speeds will be slow. Use the [installer](./installer.md) instead.
!!! tip "Linux and Windows Users"

View File

@@ -0,0 +1,36 @@
# Installation and Updating Overview
Before installing, review the [installation requirements](./requirements.md) to ensure your system is set up properly.
See the [FAQ](../faq.md) for frequently-encountered installation issues.
If you need more help, join our [discord](https://discord.gg/ZmtBAhwWhy) or [create a GitHub issue](https://github.com/invoke-ai/InvokeAI/issues).
## Automated Installer & Updates
✅ The automated [installer](./installer.md) is the best way to install Invoke.
⬆️ The same installer is also the best way to update Invoke - simply rerun it for the same folder you installed to.
The installation process simply manages installation for the core libraries & application dependencies that run Invoke.
Models, images, or other assets in the Invoke root folder won't be affected by the installation process.
## Manual Install
If you are familiar with python and want more control over the packages that are installed, you can [install Invoke manually via PyPI](./manual.md).
Updates are managed by reinstalling the latest version through PyPi.
## Developer Install
If you want to contribute to InvokeAI, you'll need to set up a [dev environment](../contributing/dev-environment.md).
## Docker
Invoke publishes docker images. See the [docker installation guide](./docker.md) for details.
## Other Installation Guides
- [PyPatchMatch](./patchmatch.md)
- [Installing Models](./models.md)

View File

@@ -1,10 +1,4 @@
# Legacy Scripts
!!! warning "Legacy Scripts"
We recommend using the Invoke Launcher to install and update Invoke. It's a desktop application for Windows, macOS and Linux. It takes care of a lot of nitty gritty details for you.
Follow the [quick start guide](./quick_start.md) to get started.
# Automatic Install & Updates
!!! tip "Use the installer to update"

View File

@@ -4,11 +4,11 @@
**Python experience is mandatory.**
If you want to use Invoke locally, you should probably use the [launcher](./quick_start.md).
If you want to use Invoke locally, you should probably use the [installer](./installer.md).
If you want to contribute to Invoke or run the app on the latest dev branch, instead follow the [dev environment](../contributing/dev-environment.md) guide.
If you want to contribute to Invoke, instead follow the [dev environment](../contributing/dev-environment.md) guide.
InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the launcher that you'll need to manage manually, described in this guide.
InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer and launcher that you'll need to manage manually, described in this guide.
## Requirements
@@ -16,39 +16,43 @@ Before you start, go through the [installation requirements](./requirements.md).
## Walkthrough
We'll use [`uv`](https://github.com/astral-sh/uv) to install python and create a virtual environment, then install the `invokeai` package. `uv` is a modern, very fast alternative to `pip`.
The following commands vary depending on the version of Invoke being installed and the system onto which it is being installed.
1. Install `uv` as described in its [docs](https://docs.astral.sh/uv/getting-started/installation/#standalone-installer). We suggest using the standalone installer method.
Run `uv --version` to confirm that `uv` is installed and working. After installation, you may need to restart your terminal to get access to `uv`.
2. Create a directory for your installation, typically in your home directory (e.g. `~/invokeai` or `$Home/invokeai`):
1. Create a directory to contain your InvokeAI library, configuration files, and models. This is known as the "runtime" or "root" directory, and typically lives in your home directory under the name `invokeai`.
=== "Linux/macOS"
```bash
mkdir ~/invokeai
cd ~/invokeai
```
=== "Windows (PowerShell)"
```bash
mkdir $Home/invokeai
cd $Home/invokeai
```
3. Create a virtual environment in that directory:
1. Enter the root directory and create a virtual Python environment within it named `.venv`.
```sh
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
```
!!! warning "Virtual Environment Location"
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
While you may create the virtual environment anywhere in the file system, we recommend that you create it within the root directory as shown here. This allows the application to automatically detect its data directories.
4. Activate the virtual environment:
If you choose a different location for the venv, then you _must_ set the `INVOKEAI_ROOT` environment variable or specify the root directory using the `--root` CLI arg.
=== "Linux/macOS"
```bash
cd ~/invokeai
python3 -m venv .venv --prompt InvokeAI
```
=== "Windows (PowerShell)"
```bash
cd $Home/invokeai
python3 -m venv .venv --prompt InvokeAI
```
1. Activate the new environment:
=== "Linux/macOS"
@@ -56,48 +60,41 @@ The following commands vary depending on the version of Invoke being installed a
source .venv/bin/activate
```
=== "Windows (PowerShell)"
=== "Windows"
```ps
.venv\Scripts\activate
```
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).
!!! info "Permissions Error (Windows)"
6. Determine the package package specifier to use when installing. This is a performance optimization.
If you get a permissions error at this point, run this command and try again.
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
`Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser`
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
The command-line prompt should change to to show `(InvokeAI)`, indicating the venv is active.
=== "Invoke v5 or later"
1. Make sure that pip is installed in your virtual environment and up to date:
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
```bash
python3 -m pip install --upgrade pip
```
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
- If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not strictly necessary. PyTorch includes an implementation of the SDP attention algorithm with similar performance for most GPUs.
```bash
pip install "InvokeAI[xformers]" --use-pep517
```
1. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
=== "Linux/macOS"
@@ -105,31 +102,17 @@ The following commands vary depending on the version of Invoke being installed a
deactivate && source .venv/bin/activate
```
=== "Windows (PowerShell)"
=== "Windows"
```ps
deactivate
.venv\Scripts\activate
```
10. Run the application, specifying the directory you created earlier as the root directory:
1. Run the application:
=== "Linux/macOS"
Run `invokeai-web` to start the UI. You must activate the virtual environment before running the app.
```bash
invokeai-web --root ~/invokeai
```
!!! warning
=== "Windows (PowerShell)"
```bash
invokeai-web --root $Home/invokeai
```
## Headless Install and Launch Scripts
If you run Invoke on a headless server, you might want to install and run Invoke on the command line.
We do not plan to maintain scripts to do this moving forward, instead focusing our dev resources on the GUI [launcher](../installation/quick_start.md).
You can create your own scripts for this by copying the handful of commands in this guide. `uv`'s [`pip` interface docs](https://docs.astral.sh/uv/reference/cli/#uv-pip-install) may be useful.
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@@ -1,128 +0,0 @@
# Invoke Community Edition Quick Start
Welcome to Invoke! Follow these steps to install, update, and get started creating.
## Step 1: System Requirements
Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
Hardware requirements vary significantly depending on model and image output size. The requirements below are rough guidelines.
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.
!!! info "Hardware Requirements (Windows/Linux)"
=== "SD1.5 - 512×512"
- GPU: Nvidia 10xx series or later, 4GB+ VRAM.
- Memory: At least 8GB RAM.
- Disk: 10GB for base installation plus 30GB for models.
=== "SDXL - 1024×1024"
- GPU: Nvidia 20xx series or later, 8GB+ VRAM.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 100GB for models.
=== "FLUX - 1024×1024"
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
- Memory: At least 32GB RAM.
- Disk: 10GB for base installation plus 200GB for models.
More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
- [Download for Linux](https://download.invoke.ai/Invoke%20Community%20Edition.AppImage)
## Step 3: Install or Update
Run the launcher you just downloaded, click **Install** and follow the instructions to get set up.
If you have an existing Invoke installation, you can select it and let the launcher manage the install. You'll be able to update or launch the installation.
!!! warning "Problem running the launcher on macOS"
macOS may not allow you to run the launcher. We are working to resolve this by signing the launcher executable. Until that is done, you can either use the [legacy scripts](./legacy_scripts.md) to install, or manually flag the launcher as safe:
- Open the **Invoke-Installer-mac-arm64.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
You should now be able to run the launcher.
## Step 4: Launch
Once installed, click **Finish**, then **Launch** to start Invoke.
The very first run after an installation or update will take a few extra moments to get ready.
!!! tip "Server Mode"
The launcher runs Invoke as a desktop application. You can enable **Server Mode** in the launcher's settings to disable this and instead access the UI through your web browser.
## Step 5: Install Models
With Invoke started up, you'll need to install some models.
The quickest way to get started is to install a **Starter Model** bundle. If you already have a model collection, Invoke can use it.
!!! info "Install Models"
=== "Install a Starter Model bundle"
1. Go to the **Models** tab.
2. Click **Starter Models** on the right.
3. Click one of the bundles to install its models. Refer to the [system requirements](#step-1-confirm-system-requirements) if you're unsure which model architecture will work for your system.
=== "Use my model collection"
4. Go to the **Models** tab.
5. Click **Scan Folder** on the right.
6. Paste the path to your models collection and click **Scan Folder**.
7. With **In-place install** enabled, Invoke will leave the model files where they are. If you disable this, **Invoke will move the models into its own folders**.
Youre now ready to start creating!
## Step 6: Learn the Basics
We recommend watching our [Getting Started Playlist](https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO). It covers essential features and workflows, including:
- Generating your first image.
- Using control layers and reference guides.
- Refining images with advanced workflows.
## Troubleshooting
If installation fails, retrying the install in Repair Mode may fix it. There's a checkbox to enable this on the Review step of the install flow.
If that doesn't fix it, [clearing the `uv` cache](https://docs.astral.sh/uv/reference/cli/#uv-cache-clean) might do the trick:
- Open and start the dev console (button at the bottom-left of the launcher).
- Run `uv cache clean`.
- Retry the installation. Enable Repair Mode for good measure.
If you are still unable to install, try installing to a different location and see if that works.
If you still have problems, ask for help on the Invoke [discord](https://discord.gg/ZmtBAhwWhy).
## Other Installation Methods
- You can install the Invoke application as a python package. See our [manual install](./manual.md) docs.
- You can run Invoke with docker. See our [docker install](./docker.md) docs.
- You can still use our legacy scripts to install and run Invoke. See the [legacy scripts](./legacy_scripts.md) docs.
## Need Help?
- Visit our [Support Portal](https://support.invoke.ai).
- Watch the [Getting Started Playlist](https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO).
- Join the conversation on [Discord][discord link].
[discord link]: https://discord.gg/ZmtBAhwWhy

View File

@@ -1,35 +1,90 @@
# Requirements
Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
## GPU
## Hardware
!!! warning "Problematic Nvidia GPUs"
Hardware requirements vary significantly depending on model and image output size.
We do not recommend these GPUs. They cannot operate with half precision, but have insufficient VRAM to generate 512x512 images at full precision.
The requirements below are rough guidelines for best performance. GPUs with less VRAM typically still work, if a bit slower. Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
- NVIDIA 10xx series cards such as the 1080 TI
- GTX 1650 series cards
- GTX 1660 series cards
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.
Invoke runs best with a dedicated GPU, but will fall back to running on CPU, albeit much slower. You'll need a beefier GPU for SDXL.
!!! info "Hardware Requirements (Windows/Linux)"
!!! example "Stable Diffusion 1.5"
=== "SD1.5 - 512×512"
=== "Nvidia"
- GPU: Nvidia 10xx series or later, 4GB+ VRAM.
- Memory: At least 8GB RAM.
- Disk: 10GB for base installation plus 30GB for models.
```
Any GPU with at least 4GB VRAM.
```
=== "SDXL - 1024×1024"
=== "AMD"
- GPU: Nvidia 20xx series or later, 8GB+ VRAM.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 100GB for models.
```
Any GPU with at least 4GB VRAM. Linux only.
```
=== "FLUX - 1024×1024"
=== "Mac"
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
- Memory: At least 32GB RAM.
- Disk: 10GB for base installation plus 200GB for models.
```
Any Apple Silicon Mac with at least 8GB memory.
```
!!! example "Stable Diffusion XL"
=== "Nvidia"
```
Any GPU with at least 8GB VRAM.
```
=== "AMD"
```
Any GPU with at least 16GB VRAM. Linux only.
```
=== "Mac"
```
Any Apple Silicon Mac with at least 16GB memory.
```
## RAM
At least 12GB of RAM.
## Disk
SSDs will, of course, offer the best performance.
The base application disk usage depends on the torch backend.
!!! example "Disk"
=== "Nvidia (CUDA)"
```
~6.5GB
```
=== "AMD (ROCm)"
```
~12GB
```
=== "Mac (MPS)"
```
~3.5GB
```
You'll need to set aside some space for images, depending on how much you generate. A couple GB is enough to get started.
You'll need a good chunk of space for models. Even if you only install the most popular models and the usual support models (ControlNet, IP Adapter ,etc), you will quickly hit 50GB of models.
!!! info "`tmpfs` on Linux"
@@ -37,32 +92,26 @@ The requirements below are rough guidelines for best performance. GPUs with less
## Python
!!! tip "The launcher installs python for you"
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
Check that your system has an up-to-date Python installed by running `python --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
!!! info "Installing Python"
<h3>Installing Python (Windows)</h3>
=== "Windows"
- Install python 3.11 with [an official installer].
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
- You may need to install [Microsoft Visual C++ Redistributable].
- Install python 3.11 with [an official installer].
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
- You may need to install [Microsoft Visual C++ Redistributable].
<h3>Installing Python (macOS)</h3>
=== "macOS"
- Install python 3.11 with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
- Install python 3.11 with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
<h3>Installing Python (Linux)</h3>
=== "Linux"
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
- Follow the [linux install instructions], being sure to install python 3.11.
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
## Drivers
@@ -126,4 +175,7 @@ An alternative to installing ROCm locally is to use a [ROCm docker container] to
[ROCm Documentation]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html
[cuDNN support matrix]: https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html
[Nvidia Container Runtime]: https://developer.nvidia.com/container-runtime
[linux install instructions]: https://docs.python-guide.org/starting/install3/linux/
[Microsoft Visual C++ Redistributable]: https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170
[an official installer]: https://www.python.org/downloads/
[CUDA Toolkit Downloads]: https://developer.nvidia.com/cuda-downloads

View File

@@ -49,7 +49,6 @@ To use a community workflow, download the `.json` node graph file and load it in
+ [BriaAI Background Remove](#briaai-remove-background)
+ [Remove Background](#remove-background)
+ [Retroize](#retroize)
+ [Stereogram](#stereogram-nodes)
+ [Size Stepper Nodes](#size-stepper-nodes)
+ [Simple Skin Detection](#simple-skin-detection)
+ [Text font to Image](#text-font-to-image)
@@ -527,16 +526,6 @@ View:
<img src="https://github.com/Ar7ific1al/InvokeAI_nodes_retroize/assets/2306586/de8b4fa6-324c-4c2d-b36c-297600c73974" width="500" />
--------------------------------
### Stereogram Nodes
**Description:** A set of custom nodes for InvokeAI to create cross-view or parallel-view stereograms. Stereograms are 2D images that, when viewed properly, reveal a 3D scene. Check out [r/crossview](https://www.reddit.com/r/CrossView/) for tutorials.
**Node Link:** https://github.com/simonfuhrmann/invokeai-stereo
**Example Workflow and Output**
</br><img src="https://raw.githubusercontent.com/simonfuhrmann/invokeai-stereo/refs/heads/main/docs/example_promo_03.jpg" width="600" />
--------------------------------
### Simple Skin Detection

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -32,7 +31,7 @@ class DeleteBoardResult(BaseModel):
response_model=BoardDTO,
)
async def create_board(
board_name: str = Query(description="The name of the board to create", max_length=300),
board_name: str = Query(description="The name of the board to create"),
is_private: bool = Query(default=False, description="Whether the board is private"),
) -> BoardDTO:
"""Creates a board"""
@@ -88,9 +87,7 @@ async def delete_board(
try:
if include_images is True:
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id,
categories=None,
is_intermediate=None,
board_id=board_id
)
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
@@ -101,9 +98,7 @@ async def delete_board(
)
else:
deleted_board_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id,
categories=None,
is_intermediate=None,
board_id=board_id
)
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
return DeleteBoardResult(
@@ -147,14 +142,10 @@ async def list_boards(
)
async def list_all_board_image_names(
board_id: str = Path(description="The id of the board"),
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
) -> list[str]:
"""Gets a list of images for a board"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id,
categories,
is_intermediate,
)
return image_names

View File

@@ -4,6 +4,7 @@
import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
@@ -20,6 +21,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -35,7 +37,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -846,6 +848,74 @@ async def get_starter_models() -> StarterModelResponse:
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value
@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram
# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")
config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")
try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old
raise RuntimeError("Failed to update cache size") from e
return value
@model_manager_router.get(
"/stats",
operation_id="get_stats",
@@ -858,18 +928,6 @@ async def get_stats() -> Optional[CacheStats]:
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
@model_manager_router.post(
"/empty_model_cache",
operation_id="empty_model_cache",
status_code=200,
)
async def empty_model_cache() -> None:
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
# Request 1000GB of room in order to force the cache to drop all models.
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30)
class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"

View File

@@ -10,13 +10,11 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
@@ -48,9 +46,7 @@ async def enqueue_batch(
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
@session_queue_router.get(
@@ -98,18 +94,6 @@ async def Pause(
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
"/{queue_id}/cancel_all_except_current",
operation_id="cancel_all_except_current",
responses={200: {"model": CancelAllExceptCurrentResult}},
)
async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
"/{queue_id}/cancel_by_batch_ids",
operation_id="cancel_by_batch_ids",
@@ -138,19 +122,6 @@ async def cancel_by_destination(
)
@session_queue_router.put(
"/{queue_id}/retry_items_by_id",
operation_id="retry_items_by_id",
responses={200: {"model": RetryItemsResult}},
)
async def retry_items_by_id(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",

View File

@@ -25,7 +25,6 @@ async def parse_dynamicprompts(
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
seed: int | None = Body(None, description="The seed to use for random generation. Only used if not combinatorial"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
max_prompts = min(max_prompts, 10000)
@@ -36,7 +35,7 @@ async def parse_dynamicprompts(
generator = CombinatorialPromptGenerator()
prompts = generator.generate(prompt, max_prompts=max_prompts)
else:
generator = RandomPromptGenerator(seed=seed)
generator = RandomPromptGenerator()
prompts = generator.generate(prompt, num_images=max_prompts)
except ParseException as e:
prompts = [prompt]

View File

@@ -1,8 +1,12 @@
import asyncio
import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -11,7 +15,11 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
@@ -30,10 +38,24 @@ from invokeai.app.api.routers import (
from invokeai.app.api.sockets import SocketIO
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
app_config = get_config()
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@@ -42,23 +64,6 @@ loop = asyncio.new_event_loop()
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
proto = "https" if app_config.ssl_certfile else "http"
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
# Logging this way ignores the logger's log level and _always_ logs the message
record = logger.makeRecord(
name=logger.name,
level=logging.INFO,
fn="",
lno=0,
msg=msg,
args=(),
exc_info=None,
)
logger.handle(record)
yield
# Shut down threads
ApiDependencies.shutdown()
@@ -160,3 +165,73 @@ except RuntimeError:
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def invoke_api() -> None:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_port(port=port + 1)
else:
return port
if app_config.dev_reload:
try:
import jurigged
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop="asyncio",
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
log = InvokeAILogger.get_logger(logname)
log.handlers.clear()
for ch in logger.handlers:
log.addHandler(ch)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@@ -1,5 +1,33 @@
import shutil
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.services.config.config_default import get_config
custom_nodes_path = Path(get_config().custom_nodes_path)
custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
# copy our custom nodes __init__.py to the custom nodes directory
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
# set the same permissions as the destination directory, in case our source is read-only,
# so that the files are user-writable
for p in custom_nodes_path.glob("**/*"):
p.chmod(custom_nodes_path.stat().st_mode)
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
__all__ = [f.stem for f in python_files] # type: ignore

View File

@@ -44,6 +44,8 @@ if TYPE_CHECKING:
logger = InvokeAILogger.get_logger()
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
class InvalidVersionError(ValueError):
pass
@@ -238,11 +240,6 @@ class BaseInvocation(ABC, BaseModel):
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -449,27 +446,8 @@ def invocation(
if re.compile(r"^\S+$").match(invocation_type) is None:
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The node being clobbered is a core node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
)
else:
# The node being clobbered is a custom node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
)
raise ValueError(f'Invocation type "{invocation_type}" already exists')
validate_fields(cls.model_fields, invocation_type)
@@ -479,7 +457,8 @@ def invocation(
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
uiconfig["node_pack"] = node_pack
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
if version is not None:
try:

View File

@@ -1,274 +0,0 @@
from typing import Literal
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import (
FloatOutput,
ImageOutput,
IntegerOutput,
StringOutput,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]
class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)
pass
class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)
def __init__(self):
raise NotExecutableNodeError()
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()
@invocation_output("image_generator_output")
class ImageGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of boards"""
images: list[ImageField] = OutputField(description="The generated images")
class ImageGeneratorField(BaseModel):
pass
@invocation(
"image_generator",
title="Image Generator",
tags=["primitives", "board", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageGenerator(BaseInvocation):
"""Generated a collection of images for use in a batched generation"""
generator: ImageGeneratorField = InputField(
description="The image generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> ImageGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()
@invocation_output("string_generator_output")
class StringGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
strings: list[str] = OutputField(description="The generated strings")
class StringGeneratorField(BaseModel):
pass
@invocation(
"string_generator",
title="String Generator",
tags=["primitives", "string", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringGenerator(BaseInvocation):
"""Generated a range of strings for use in a batched generation"""
generator: StringGeneratorField = InputField(
description="The string generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> StringGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()
@invocation_output("integer_generator_output")
class IntegerGeneratorOutput(BaseInvocationOutput):
integers: list[int] = OutputField(description="The generated integers")
class IntegerGeneratorField(BaseModel):
pass
@invocation(
"integer_generator",
title="Integer Generator",
tags=["primitives", "int", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerGenerator(BaseInvocation):
"""Generated a range of integers for use in a batched generation"""
generator: IntegerGeneratorField = InputField(
description="The integer generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)
def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()
@invocation_output("float_generator_output")
class FloatGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
floats: list[float] = OutputField(description="The generated floats")
class FloatGeneratorField(BaseModel):
pass
@invocation(
"float_generator",
title="Float Generator",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatGenerator(BaseInvocation):
"""Generated a range of floats for use in a batched generation"""
generator: FloatGeneratorField = InputField(
description="The float generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
raise NotExecutableNodeError()

View File

@@ -19,9 +19,9 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@@ -63,28 +63,30 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
text_encoder_info = context.models.load(self.clip.text_encoder)
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
context.models.load(self.clip.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
tokenizer_info as tokenizer,
LoRAPatcher.apply_smart_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=text_encoder.dtype,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -103,7 +105,6 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -138,7 +139,9 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.model
@@ -160,11 +163,11 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, ModelPatchRaw)
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
del lora_info
return
@@ -176,12 +179,12 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
context.models.load(clip_field.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
tokenizer_info as tokenizer,
LoRAPatcher.apply_smart_lora_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=text_encoder.dtype,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -204,7 +207,6 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -222,6 +224,7 @@ class SDXLPromptInvocationBase:
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
c = c.detach().to("cpu")

View File

@@ -1,5 +1,7 @@
from typing import Literal
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
"""
HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
@@ -10,3 +12,5 @@ The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

View File

@@ -6,6 +6,7 @@ from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
@@ -28,7 +29,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=4)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":

View File

@@ -7,6 +7,7 @@ from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -75,7 +76,11 @@ class CreateGradientMaskInvocation(BaseInvocation):
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=9)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:

View File

@@ -10,12 +10,10 @@ from pathlib import Path
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
loaded_packs: list[str] = []
failed_packs: list[str] = []
loaded_count = 0
custom_nodes_dir = Path(__file__).parent
for d in custom_nodes_dir.iterdir():
for d in Path(__file__).parent.iterdir():
# skip files
if not d.is_dir():
continue
@@ -49,16 +47,12 @@ for d in custom_nodes_dir.iterdir():
sys.modules[spec.name] = module
spec.loader.exec_module(module)
loaded_packs.append(module_name)
loaded_count += 1
except Exception:
failed_packs.append(module_name)
full_error = traceback.format_exc()
logger.error(f"Failed to load node pack {module_name} (may have partially loaded):\n{full_error}")
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
del init, module_name
loaded_count = len(loaded_packs)
if loaded_count > 0:
logger.info(
f"Loaded {loaded_count} node pack{'s' if loaded_count != 1 else ''} from {custom_nodes_dir}: {', '.join(loaded_packs)}"
)
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")

View File

@@ -10,9 +10,7 @@ import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
@@ -39,11 +37,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -86,14 +83,12 @@ def get_scheduler(
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
unet_config: AnyModelConfig,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -105,17 +100,10 @@ def get_scheduler(
"_backup": scheduler_config,
}
if hasattr(unet_config, "prediction_type"):
scheduler_config["prediction_type"] = unet_config.prediction_type
# make dpmpp_sde reproducable(seed can be passed only in initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
if scheduler_class is DPMSolverMultistepScheduler or scheduler_class is DPMSolverSinglestepScheduler:
if scheduler_config["_class_name"] == "DEISMultistepScheduler" and scheduler_config["algorithm_type"] == "deis":
scheduler_config["algorithm_type"] = "dpmsolver++"
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@@ -423,7 +411,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
device: torch.device,
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
@@ -465,7 +452,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=device,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
@@ -560,6 +547,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -568,7 +556,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images = [
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
]
with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
@@ -618,7 +606,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
device: torch.device,
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
@@ -634,6 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -649,7 +637,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
@@ -669,7 +657,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
width=control_width_resize,
height=control_height_resize,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=device,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
@@ -834,9 +822,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
@@ -856,7 +841,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
@@ -868,6 +852,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
@@ -898,7 +885,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
# NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
# prevalent, we will have to revisit how we initialize the inpainting extensions.
@@ -939,8 +926,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
@@ -961,7 +950,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
device = TorchDevice.choose_torch_device()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -976,7 +964,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context,
self.t2i_adapter,
latents.shape,
device=device,
do_classifier_free_guidance=True,
)
@@ -1000,21 +987,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LayerPatcher.apply_smart_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
@@ -1023,20 +1012,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=device, dtype=unet.dtype)
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=device, dtype=unet.dtype)
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=device, dtype=unet.dtype)
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet, scheduler)
@@ -1046,7 +1034,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=device,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
@@ -1059,7 +1047,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=latents.shape,
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -1077,7 +1064,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=device,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,

View File

@@ -56,7 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
# endregion
# region Misc Field Types
@@ -144,7 +144,7 @@ class FieldDescriptions:
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
control_lora_model = "Control LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
@@ -300,13 +300,6 @@ class BoundingBoxField(BaseModel):
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> Tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class MetadataField(RootModel[dict[str, Any]]):
"""

View File

@@ -1,49 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_control_lora_loader_output")
class FluxControlLoRALoaderOutput(BaseInvocationOutput):
"""Flux Control LoRA Loader Output"""
control_lora: ControlLoRAField = OutputField(
title="Flux Control LoRA", description="Control LoRAs to apply on model loading", default=None
)
@invocation(
"flux_control_lora_loader",
title="Flux Control LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
)
image: ImageField = InputField(description="The image to encode.")
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
if not context.models.exists(self.lora.key):
raise ValueError(f"Unknown lora: {self.lora.key}!")
return FluxControlLoRALoaderOutput(
control_lora=ControlLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
)

View File

@@ -1,15 +1,15 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple, Union
import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -23,9 +23,8 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -46,11 +45,13 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -92,9 +93,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
control_lora: Optional[ControlLoRAField] = InputField(
description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
)
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
@@ -199,8 +197,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else None
)
transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -240,12 +238,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if len(timesteps) <= 1:
return x
if is_schnell and self.control_lora:
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
img_cond = self._prep_structural_control_img_cond(context)
inpaint_mask = self._prep_inpaint_mask(context, x)
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
@@ -253,7 +245,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
img_cond = pack(img_cond) if img_cond is not None else None
noise = pack(noise)
x = pack(x)
@@ -276,7 +267,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
ip_adapter_fields, context, device=x.device
ip_adapter_fields, context
)
cfg_scale = self.prep_cfg_scale(
@@ -297,42 +288,54 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(
context.models.load(self.transformer.transformer).model_on_device()
)
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_config
config = transformer_info.config
assert config is not None
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if config.format in [ModelFormat.Checkpoint]:
model_is_quantized = False
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_smart_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
# Prepare IP-Adapter extensions.
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
@@ -357,7 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond=img_cond
)
x = unpack(x.float(), self.height, self.width)
@@ -514,18 +517,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet in controlnets:
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
# HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
# shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
# before we load it into VRAM and thus we'll have to load it again (context:
# https://github.com/invoke-ai/InvokeAI/issues/7513).
controlnet_model = context.models.load(controlnet.control_model)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -555,8 +555,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
@@ -589,29 +591,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return controlnet_extensions
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
if self.control_lora is None:
return None
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.control_lora.img.image_name)
cond_img = cond_img.convert("RGB")
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
cond_img = np.array(cond_img)
# Normalize the conditioning image to the range [-1, 1].
# This normalization is based on the original implementations here:
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
img_cond = torch.from_numpy(cond_img).float() / 127.5 - 1.0
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:
return []
@@ -626,7 +605,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
device: torch.device,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
clip_image_processor = CLIPImageProcessor()
@@ -666,11 +644,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
@@ -719,15 +697,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
# applied last.
loras.append(self.control_lora)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@@ -21,9 +21,6 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
)
@invocation(
@@ -31,7 +28,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.2.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -53,12 +50,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -71,8 +62,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
output = FluxLoRALoaderOutput()
@@ -93,14 +82,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@@ -110,14 +91,14 @@ class FluxLoRALoaderInvocation(BaseInvocation):
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.3.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a FLUX transformer."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
transformer: Optional[TransformerField] = InputField(
@@ -132,30 +113,13 @@ class FLUXLoRACollectionLoader(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -166,13 +130,14 @@ class FLUXLoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.transformer is not None and output.transformer is not None:
if self.transformer is not None:
if output.transformer is None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.t5_encoder is not None and output.t5_encoder is not None:
output.t5_encoder.loras.append(lora)
return output

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
@@ -40,7 +36,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.5",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -78,16 +74,16 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -0,0 +1,70 @@
from typing import Optional, Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
"""Flux Structural LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_structural_lora_loader",
title="Flux Structural LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
image: ImageField = InputField(
description="The image to encode.",
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@@ -2,7 +2,7 @@ from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -17,11 +17,12 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -69,46 +70,17 @@ class FluxTextEncoderInvocation(BaseInvocation):
)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_encoder_config = t5_encoder_info.config
assert t5_encoder_config is not None
with (
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
ExitStack() as exit_stack,
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
model_is_quantized = False
elif t5_encoder_config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
# Apply LoRA models to the T5 encoder.
# Note: We apply the LoRA after the encoder has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=t5_text_encoder,
patches=self._t5_lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=t5_text_encoder.dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
@@ -119,30 +91,32 @@ class FluxTextEncoderInvocation(BaseInvocation):
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
context.models.load(self.clip.tokenizer) as clip_tokenizer,
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)
@@ -158,16 +132,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.t5_encoder.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -3,7 +3,6 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -25,7 +24,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.1",
version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -39,18 +38,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

View File

@@ -21,7 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.4",
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
@@ -41,16 +41,11 @@ class IdealSizeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
if unet_config.base == BaseModelType.StableDiffusion1:
dimension = 512
elif unet_config.base == BaseModelType.StableDiffusion2:
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
else:
raise ValueError(f"Unsupported model type: {unet_config.base}")
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images

View File

@@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
FieldDescriptions,
ImageField,
@@ -24,7 +23,6 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -163,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
crop: bool = InputField(default=False, description="Crop to base image dimensions")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name, mode="RGBA")
base_image = context.images.get_pil(self.base_image.image_name)
image = context.images.get_pil(self.image.image_name)
mask = None
if self.mask is not None:
mask = context.images.get_pil(self.mask.image_name, mode="L")
mask = ImageOps.invert(mask)
mask = context.images.get_pil(self.mask.image_name)
mask = ImageOps.invert(mask.convert("L"))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
@@ -178,11 +176,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
new_image.paste(base_image, (abs(min_x), abs(min_y)))
# Create a temporary image to paste the image with transparency
temp_image = Image.new("RGBA", new_image.size)
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
new_image = Image.alpha_composite(new_image, temp_image)
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
if self.crop:
base_w, base_h = base_image.size
@@ -307,44 +301,14 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name)
# Split the image into RGBA channels
r, g, b, a = image.split()
# Premultiply RGB channels by alpha
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
premultiplied_image.putalpha(a)
# Apply the blur
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
)
blurred_image = premultiplied_image.filter(blur)
blur_image = image.filter(blur)
# Split the blurred image into RGBA channels
r, g, b, a_orig = blurred_image.split()
# Convert to float using NumPy. float 32/64 division are much faster than float 16
r = numpy.array(r, dtype=numpy.float32)
g = numpy.array(g, dtype=numpy.float32)
b = numpy.array(b, dtype=numpy.float32)
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
# Unpremultiply RGB channels by alpha
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
g /= a + 1e-6
b /= a + 1e-6
# Convert back to PIL images
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
# Merge back into a single image
result_image = Image.merge("RGBA", (r, g, b, a_orig))
image_dto = context.images.save(image=result_image)
image_dto = context.images.save(image=blur_image)
return ImageOutput.build(image_dto)
@@ -843,7 +807,7 @@ CHANNEL_FORMATS = {
"value",
],
category="image",
version="1.2.3",
version="1.2.2",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add or subtract a value from a specific color channel of an image."""
@@ -853,22 +817,18 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGBA")
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
mode = CHANNEL_FORMATS[self.channel][0]
channel_number = CHANNEL_FORMATS[self.channel][1]
# Convert PIL image to new format
converted_image = numpy.array(image.convert(mode)).astype(int)
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
image_channel = converted_image[:, :, channel_number]
if self.channel == "Hue (HSV)":
# loop around the values because hue is special
image_channel = (image_channel + self.offset) % 256
else:
# Adjust the value, clipping to 0..255
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
# Adjust the value, clipping to 0..255
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
# Put the channel back into the image
converted_image[:, :, channel_number] = image_channel
@@ -876,10 +836,6 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
# Convert back to RGBA format and output
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
# restore the alpha channel
if self.channel != "Alpha (RGBA)":
pil_image.putalpha(image.getchannel("A"))
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@@ -907,7 +863,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"value",
],
category="image",
version="1.2.3",
version="1.2.2",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scale a specific color channel of an image."""
@@ -918,14 +874,14 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGBA")
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
mode = CHANNEL_FORMATS[self.channel][0]
channel_number = CHANNEL_FORMATS[self.channel][1]
# Convert PIL image to new format
converted_image = numpy.array(image.convert(mode)).astype(float)
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
image_channel = converted_image[:, :, channel_number]
# Adjust the value, clipping to 0..255
@@ -941,10 +897,6 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
# Convert back to RGBA format and output
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
# restore the alpha channel
if self.channel != "Alpha (RGBA)":
pil_image.putalpha(image.getchannel("A"))
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@@ -1010,10 +962,10 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"mask_from_id",
title="Mask from Segmented Image",
title="Mask from ID",
tags=["image", "mask", "id"],
category="image",
version="1.0.1",
version="1.0.0",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate a mask for a particular color in an ID Map"""
@@ -1023,24 +975,40 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
threshold: int = InputField(default=100, description="Threshold for color detection")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
r, g, b, a = rgba_color
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
return hex_code
np_color = numpy.array(self.color.tuple())
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
if id_mask.mode != "RGB":
id_mask = id_mask.convert("RGB")
# Can directly just use the tuple but I'll leave this rgba_to_hex here
# incase anyone prefers using hex codes directly instead of the color picker
hex_color_str = self.rgba_to_hex(color)
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
color_distance = numpy.linalg.norm(image - np_color, axis=-1)
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
# Create a mask based on the threshold and the distance calculated above
binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
# Convert the mask back to PIL
binary_mask_pil = Image.fromarray(binary_mask)
if self.invert:
binary_mask_pil = ImageOps.invert(binary_mask_pil)
return binary_mask_pil
image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
if self.invert:
mask = ImageOps.invert(mask)
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@@ -1087,123 +1055,3 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=generated_image)
return ImageOutput.build(image_dto)
@invocation(
"img_noise",
title="Add Image Noise",
tags=["image", "noise"],
category="image",
version="1.0.1",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add noise to an image"""
image: ImageField = InputField(description="The image to add noise to")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
default="gaussian",
description="The type of noise to add",
)
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
size: int = InputField(default=1, ge=1, description="The size of the noise points")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
# Save out the alpha channel
alpha = image.getchannel("A")
# Set the seed for numpy random
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
if self.noise_type == "gaussian":
if self.noise_color:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
else:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
noise = numpy.stack([noise] * 3, axis=-1)
elif self.noise_type == "salt_and_pepper":
if self.noise_color:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
)
else:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
)
noise = numpy.stack([noise] * 3, axis=-1)
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
(image.width, image.height), Image.Resampling.NEAREST
)
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
# Paste back the alpha channel
noisy_image.putalpha(alpha)
image_dto = context.images.save(image=noisy_image)
return ImageOutput.build(image_dto)
@invocation(
"crop_image_to_bounding_box",
title="Crop Image to Bounding Box",
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
image: ImageField = InputField(description="The image to crop")
bounding_box: BoundingBoxField | None = InputField(
default=None, description="The bounding box to crop the image to"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
bounding_box = self.bounding_box.tuple() if self.bounding_box is not None else image.getbbox()
cropped_image = image.crop(bounding_box)
image_dto = context.images.save(image=cropped_image)
return ImageOutput.build(image_dto)
@invocation(
"paste_image_into_bounding_box",
title="Paste Image into Bounding Box",
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Paste the source image into the target image at the given bounding box.
The source image must be the same size as the bounding box, and the bounding box must fit within the target image."""
source_image: ImageField = InputField(description="The image to paste")
target_image: ImageField = InputField(description="The image to paste into")
bounding_box: BoundingBoxField = InputField(description="The bounding box to paste the image into")
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name, mode="RGBA")
target_image = context.images.get_pil(self.target_image.image_name, mode="RGBA")
bounding_box = self.bounding_box.tuple()
target_image.paste(source_image, bounding_box, source_image)
image_dto = context.images.save(image=target_image)
return ImageOutput.build(image_dto)

View File

@@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -26,7 +26,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -50,7 +49,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(
@@ -99,7 +98,7 @@ class ImageToLatentsInvocation(BaseInvocation):
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

View File

@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.1",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -51,58 +51,18 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 2200 # Determined experimentally.
if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
working_memory = out_h * out_w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant
if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
return int(working_memory)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
use_tiling = self.tiled or context.config.get().force_tiled_decode
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(TorchDevice.choose_torch_device())
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@@ -128,7 +88,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if use_tiling:
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()

View File

@@ -1,40 +0,0 @@
import shutil
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
def load_custom_nodes(custom_nodes_path: Path):
"""
Loads all custom nodes from the custom_nodes_path directory.
This function copies a custom __init__.py file to the custom_nodes_path directory, effectively turning it into a
python module.
The custom __init__.py file itself imports all the custom node packs as python modules from the custom_nodes_path
directory.
Then,the custom __init__.py file is programmatically imported using importlib. As it executes, it imports all the
custom node packs as python modules.
"""
custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
# copy our custom nodes __init__.py to the custom nodes directory
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
# set the same permissions as the destination directory, in case our source is read-only,
# so that the files are user-writable
for p in custom_nodes_path.glob("**/*"):
p.chmod(custom_nodes_path.stat().st_mode)
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

View File

@@ -2,22 +2,9 @@ import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
)
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ImageField,
InputField,
TensorField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import BoundingBoxOutput, ImageOutput, MaskOutput
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
from invokeai.backend.image_util.util import pil_to_np
@@ -86,7 +73,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
version="1.1.0",
version="1.0.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
@@ -96,15 +83,6 @@ class InvertTensorMaskInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Verify dtype and shape.
assert mask.dtype == torch.bool
assert mask.dim() in [2, 3]
# Unsqueeze the channel dimension if it is missing. The MaskOutput type expects a single channel.
if mask.dim() == 2:
mask = mask.unsqueeze(0)
inverted = ~mask
return MaskOutput(
@@ -223,48 +201,3 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=masked_image)
return ImageOutput.build(image_dto)
WHITE = ColorField(r=255, g=255, b=255, a=255)
@invocation(
"get_image_mask_bounding_box",
title="Get Image Mask Bounding Box",
tags=["mask"],
category="mask",
version="1.0.0",
classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
"""Gets the bounding box of the given mask image."""
mask: ImageField = InputField(description="The mask to crop.")
margin: int = InputField(default=0, description="Margin to add to the bounding box.")
mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
mask_np = np.array(mask)
# Convert mask_color to RGBA tuple
mask_color_rgb = self.mask_color.tuple()
# Find the bounding box of the mask color
y, x = np.where(np.all(mask_np == mask_color_rgb, axis=-1))
if len(x) == 0 or len(y) == 0:
# No pixels found with the given color
return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))
left, upper, right, lower = x.min(), y.min(), x.max(), y.max()
# Add the margin
left = max(0, left - self.margin)
upper = max(0, upper - self.margin)
right = min(mask_np.shape[1], right + self.margin)
lower = min(mask_np.shape[0], lower + self.margin)
bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)
return BoundingBoxOutput(bounding_box=bounding_box)

View File

@@ -18,7 +18,6 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from invokeai.version.invokeai_version import __version__
@@ -276,34 +275,3 @@ class CoreMetadataInvocation(BaseInvocation):
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
model_config = ConfigDict(extra="allow")
@invocation(
"metadata_field_extractor",
title="Metadata Field Extractor",
tags=["metadata"],
category="metadata",
version="1.0.0",
classification=Classification.Deprecated,
)
class MetadataFieldExtractorInvocation(BaseInvocation):
"""Extracts the text value from an image's metadata given a key.
Raises an error if the image has no metadata or if the value is not a string (nesting not permitted)."""
image: ImageField = InputField(description="The image to extract metadata from")
key: str = InputField(description="The key in the image's metadata to extract the value from")
def invoke(self, context: InvocationContext) -> StringOutput:
image_name = self.image.image_name
metadata = context.images.get_metadata(image_name=image_name)
if not metadata:
raise ValueError(f"No metadata found on image {image_name}")
try:
val = metadata.root[self.key]
if not isinstance(val, str):
raise ValueError(f"Metadata at key '{self.key}' must be a string")
return StringOutput(value=val)
except KeyError as e:
raise ValueError(f"No key '{self.key}' found in the metadata for {image_name}") from e

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -68,22 +68,19 @@ class CLIPField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class ControlLoRAField(LoRAField):
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
@@ -206,7 +203,7 @@ class LoRALoaderInvocation(BaseInvocation):
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
@@ -257,12 +254,12 @@ class LoRASelectorInvocation(BaseInvocation):
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.1.0")
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
@@ -282,14 +279,7 @@ class LoRACollectionLoader(BaseInvocation):
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -300,10 +290,14 @@ class LoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.unet is not None and output.unet is not None:
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@@ -403,13 +397,13 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.1.0",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
@@ -435,18 +429,7 @@ class SDXLLoRACollectionLoader(BaseInvocation):
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -457,13 +440,19 @@ class SDXLLoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.unet is not None and output.unet is not None:
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None and output.clip2 is not None:
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@@ -481,7 +470,7 @@ class VAELoaderInvocation(BaseInvocation):
key = self.vae_model.key
if not context.models.exists(key):
raise Exception(f"Unknown vae: {key}!")
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VAEField(vae=self.vae_model))

View File

@@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -265,9 +266,13 @@ class ImageInvocation(BaseInvocation):
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.get_dto(self.image.image_name)
image = context.images.get_pil(self.image.image_name)
return ImageOutput.build(image_dto=image_dto)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@invocation(
@@ -412,7 +417,6 @@ class ColorInvocation(BaseInvocation):
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor."""
# shape: [1, H, W], dtype: bool
mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@@ -535,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@@ -16,7 +16,6 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -40,7 +39,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.

View File

@@ -6,7 +6,6 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -27,7 +26,7 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.1",
version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -41,29 +40,16 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 2200 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(TorchDevice.choose_torch_device())
latents = latents.to(vae.device)
vae.disable_tiling()

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType
@@ -92,13 +88,21 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -16,10 +16,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -87,11 +87,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
@@ -118,7 +121,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -126,12 +129,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
context.models.load(clip_model.tokenizer) as clip_tokenizer,
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
@@ -146,11 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=clip_text_encoder.dtype,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)
@@ -181,7 +186,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
@@ -190,9 +195,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[ModelPatchRaw, float]]:
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -49,7 +49,7 @@ class SAMPointsField(BaseModel):
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.2.0",
version="1.1.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
@@ -96,10 +96,8 @@ class SegmentAnythingInvocation(BaseInvocation):
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
# Unsqueeze the channel dimension.
combined_mask = combined_mask.unsqueeze(0)
mask_tensor_name = context.tensors.save(combined_mask)
_, height, width = combined_mask.shape
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod

View File

@@ -22,7 +22,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
from invokeai.backend.util.devices import TorchDevice
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
@@ -103,7 +102,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
@@ -117,7 +116,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
@@ -150,12 +151,15 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return pil_image
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing tile {step}/{total_steps}",
@@ -163,7 +167,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
)
# Do the upscaling.
with context.models.load(self.image_to_image_model) as spandrel_model:
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
@@ -196,12 +200,15 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
@@ -214,7 +221,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
)
# Do the upscaling.
with context.models.load(self.image_to_image_model) as spandrel_model:
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
iteration = 1

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -194,31 +194,32 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
device = TorchDevice.choose_torch_device()
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet) as unet,
LayerPatcher.apply_smart_model_patches(
unet_info as unet,
LoRAPatcher.apply_smart_lora_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=device, dtype=unet.dtype)
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
@@ -227,7 +228,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=device,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
@@ -240,7 +241,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -266,7 +266,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=device,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,

View File

@@ -9,6 +9,6 @@ def validate_weights(weights: Union[float, list[float]]) -> None:
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than or equal to end_step_percent"""
if begin_step_percent > end_step_percent:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")

View File

@@ -1,82 +1,12 @@
import uvicorn
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
"""
from invokeai.app.api_app import app, loop
return app, loop
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
def run_app() -> None:
"""The main entrypoint for the app."""
# Parse the CLI arguments.
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
InvokeAIArgs.parse_args()
# Load config.
app_config = get_config()
from invokeai.app.api_app import invoke_api
logger = InvokeAILogger.get_logger(config=app_config)
# Configure the torch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
if app_config.pytorch_cuda_alloc_conf:
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
# Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
from invokeai.app.util.startup_utils import (
apply_monkeypatches,
check_cudnn,
enable_dev_reload,
find_open_port,
register_mime_types,
)
# Find an open port, and modify the config accordingly.
orig_config_port = app_config.port
app_config.port = find_open_port(app_config.port)
if orig_config_port != app_config.port:
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)
# Initialize the app and event loop.
app, loop = get_app()
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
# Start the server.
config = uvicorn.Config(
app=app,
host=app_config.host,
port=app_config.port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)
loop.run_until_complete(server.serve())
invoke_api()

View File

@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.image_records.image_records_common import ImageCategory
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@@ -28,8 +26,6 @@ class BoardImageRecordStorageBase(ABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
"""Gets all board images for a board, as a list of the image names."""
pass

View File

@@ -1,20 +1,23 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
deserialize_image_record,
)
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def add_image_to_board(
self,
@@ -22,8 +25,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -35,14 +38,16 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
@@ -53,6 +58,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
@@ -61,108 +68,96 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT image_name
FROM board_images
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_board_for_image(
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.image_records.image_records_common import ImageCategory
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@@ -28,8 +26,6 @@ class BoardImagesServiceABC(ABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
"""Gets all board images for a board, as a list of the image names."""
pass

View File

@@ -1,7 +1,6 @@
from typing import Optional
from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.invoker import Invoker
@@ -27,14 +26,8 @@ class BoardImagesService(BoardImagesServiceABC):
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories,
is_intermediate,
)
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
def get_board_for_image(
self,

View File

@@ -57,7 +57,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300)
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@@ -18,14 +19,20 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
@@ -33,9 +40,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -43,8 +55,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
@@ -55,6 +67,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
@@ -62,8 +76,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
@@ -72,9 +86,12 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -85,10 +102,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
@@ -99,7 +117,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
@@ -110,7 +128,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
@@ -123,6 +141,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
@@ -133,10 +153,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
try:
self._lock.acquire()
# Build base query
base_query = """
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
@@ -144,67 +165,81 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
# Execute count query
self._cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
try:
self._lock.acquire()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
self._cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@@ -63,11 +63,7 @@ class BulkDownloadService(BulkDownloadBase):
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._image_handler(image_names)
def generate_item_id(self, board_id: Optional[str]) -> str:

View File

@@ -13,6 +13,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional
import psutil
import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
@@ -24,6 +25,8 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
@@ -33,6 +36,24 @@ LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
def get_default_ram_cache_size() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
# limits are set slightly lower than than what we expect the actual RAM to be.
GB = 1024**3
max_ram = psutil.virtual_memory().total / GB
if max_ram >= 60:
return 15.0
if max_ram >= 30:
return 7.5
if max_ram >= 14:
return 4.0
return 2.1 # 2.1 is just large enough for sd 1.5 ;-)
class URLRegexTokenPair(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")
@@ -76,22 +97,15 @@ class InvokeAIAppConfig(BaseSettings):
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
log_level_network: Log level for network-related messages. 'info' and 'debug' are very verbose.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
use_memory_db: Use in-memory database. Useful for development.
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
profile_graphs: Enable graph profiling using `cProfile`.
profile_prefix: An optional prefix for profile output files.
profiles_dir: Path to profiles output directory.
max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -149,7 +163,6 @@ class InvokeAIAppConfig(BaseSettings):
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.')
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.")
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.")
log_level_network: LOG_LEVEL = Field(default='warning', description="Log level for network-related messages. 'info' and 'debug' are very verbose.")
# Development
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.")
@@ -159,19 +172,10 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
# PyTorch Memory Allocator
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")

View File

@@ -8,7 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
@@ -28,13 +28,11 @@ from invokeai.app.services.download.download_base import (
ServiceInactiveException,
UnknownJobIDException,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000

View File

@@ -0,0 +1 @@
from .events_base import EventServiceBase # noqa F401

View File

@@ -28,7 +28,6 @@ from invokeai.app.services.events.events_common import (
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemsRetriedEvent,
QueueItemStatusChangedEvent,
)
@@ -40,7 +39,6 @@ if TYPE_CHECKING:
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
RetryItemsResult,
SessionQueueItem,
SessionQueueStatus,
)
@@ -101,10 +99,6 @@ class EventServiceBase:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
"""Emitted when a list of queue items are retried"""
self.dispatch(QueueItemsRetriedEvent.build(retry_result))
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))

View File

@@ -4,13 +4,11 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
RetryItemsResult,
SessionQueueItem,
SessionQueueStatus,
)
@@ -20,7 +18,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
@@ -291,22 +289,6 @@ class BatchEnqueuedEvent(QueueEventBase):
)
@payload_schema.register
class QueueItemsRetriedEvent(QueueEventBase):
"""Event model for queue_items_retried"""
__event_name__ = "queue_items_retried"
retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried")
@classmethod
def build(cls, retry_result: RetryItemsResult) -> "QueueItemsRetriedEvent":
return cls(
queue_id=retry_result.queue_id,
retried_item_ids=retry_result.retried_item_ids,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
@@ -440,7 +422,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
__event_name__ = "model_install_download_started"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
@@ -461,7 +443,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
]
return cls(
id=job.id,
source=job.source,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
@@ -476,7 +458,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
@@ -497,7 +479,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
]
return cls(
id=job.id,
source=job.source,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
@@ -512,11 +494,11 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=job.source)
return cls(id=job.id, source=str(job.source))
@payload_schema.register
@@ -526,11 +508,11 @@ class ModelInstallStartedEvent(ModelEventBase):
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=job.source)
return cls(id=job.id, source=str(job.source))
@payload_schema.register
@@ -540,14 +522,14 @@ class ModelInstallCompleteEvent(ModelEventBase):
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=job.source, key=(job.config_out.key), total_bytes=job.total_bytes)
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
@@ -557,11 +539,11 @@ class ModelInstallCancelledEvent(ModelEventBase):
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=job.source)
return cls(id=job.id, source=str(job.source))
@payload_schema.register
@@ -571,7 +553,7 @@ class ModelInstallErrorEvent(ModelEventBase):
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@@ -579,7 +561,7 @@ class ModelInstallErrorEvent(ModelEventBase):
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=job.source, error_type=job.error_type, error=job.error)
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from datetime import datetime
from typing import Optional, Union, cast
@@ -21,14 +22,21 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
@@ -36,9 +44,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
@@ -47,8 +58,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
@@ -56,7 +68,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result:
raise ImageRecordNotFoundException
@@ -65,7 +77,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update(
self,
@@ -73,10 +88,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
@@ -87,7 +102,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
@@ -98,7 +113,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
@@ -109,7 +124,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE images
SET starred = ?
@@ -122,6 +137,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
@@ -135,104 +152,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
@@ -243,48 +266,58 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
placeholders = ",".join("?" for _ in image_names)
self._lock.acquire()
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
return count
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, self._cursor.fetchone()[0])
self._conn.commit()
return count
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
@@ -295,6 +328,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -311,8 +346,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
@@ -345,7 +380,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
self._conn.commit()
cursor.execute(
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
@@ -354,30 +389,34 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None

View File

@@ -265,11 +265,7 @@ class ImageService(ImageServiceABC):
def delete_images_on_board(self, board_id: str):
try:
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
for image_name in image_names:
self.__invoker.services.image_files.delete(image_name)
self.__invoker.services.image_records.delete_many(image_names)
@@ -282,7 +278,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Failed to delete image files")
raise
except Exception as e:
self.__invoker.services.logger.error(f"Problem deleting image records and files: {str(e)}")
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e
def delete_intermediates(self) -> int:

View File

@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache import CacheStats
# Size of 1GB in bytes.
GB = 2**30

View File

@@ -1,4 +1,4 @@
# TODO: Should these exceptions subclass existing python exceptions?
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""

View File

@@ -3,20 +3,18 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""

View File

@@ -9,7 +9,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
@@ -20,6 +20,7 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_install.model_install_common import (
@@ -56,10 +57,6 @@ from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
TMPDIR_PREFIX = "tmpinstall_"
@@ -441,10 +438,9 @@ class ModelInstallService(ModelInstallServiceBase):
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
source_stripped = source.strip('"')
if Path(source_stripped).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source_stripped))
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),

View File

@@ -7,7 +7,7 @@ from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
class ModelLoadServiceBase(ABC):
@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
@property
@abstractmethod
def ram_cache(self) -> ModelCache:
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@abstractmethod

View File

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCache,
ram_cache: ModelCacheBase[AnyModel],
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker = invoker
@property
def ram_cache(self) -> ModelCache:
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
return self._ram_cache
@@ -78,8 +78,9 @@ class ModelLoadService(ModelLoadServiceBase):
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
@@ -108,5 +109,5 @@ class ModelLoadService(ModelLoadServiceBase):
)
assert loader is not None
raw_model = loader(model_path)
self._ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))

View File

@@ -16,8 +16,7 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
from invokeai.app.services.model_load.model_load_default import ModelLoadService
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -82,13 +81,11 @@ class ModelManagerService(ModelManagerServiceBase):
logger.setLevel(app_config.log_level.upper())
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
loader = ModelLoadService(
app_config=app_config,

View File

@@ -78,6 +78,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger
@property
@@ -95,38 +96,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(config.key)
@@ -138,21 +139,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
@@ -163,23 +164,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -191,33 +192,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
@@ -226,15 +227,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -282,18 +284,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
with self._db.lock:
self._cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = self._cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -312,28 +313,34 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def list_models(
@@ -349,32 +356,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
# query2: fetch key fields
self._cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@@ -439,9 +439,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Run the graph

View File

@@ -1,11 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, Coroutine, Optional
from typing import Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
CancelByQueueIDResult,
@@ -14,7 +13,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
IsEmptyResult,
IsFullResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
@@ -33,7 +31,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
"""Enqueues all permutations of a batch for execution."""
pass
@@ -114,11 +112,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching queue ID"""
pass
@abstractmethod
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
@@ -140,8 +133,3 @@ class SessionQueueBase(ABC):
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass
@abstractmethod
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
pass

View File

@@ -1,7 +1,7 @@
import datetime
import json
from itertools import chain, product
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import (
AliasChoices,
@@ -108,16 +108,8 @@ class Batch(BaseModel):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue
# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue
# Get the type of the first item in the list
first_item_type = type(datum.items[0])
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
@@ -234,9 +226,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
retried_from_item_id: Optional[int] = Field(
default=None, description="The item_id of the queue item that this item was retried from"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
@@ -347,11 +336,6 @@ class EnqueueBatchResult(BaseModel):
priority: int = Field(description="The priority of the enqueued batch")
class RetryItemsResult(BaseModel):
queue_id: str = Field(description="The ID of the queue")
retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried")
class ClearResult(BaseModel):
"""Result of clearing the session queue"""
@@ -382,12 +366,6 @@ class CancelByQueueIDResult(CancelByBatchIDsResult):
pass
class CancelAllExceptCurrentResult(CancelByBatchIDsResult):
"""Result of canceling all except current"""
pass
class IsEmptyResult(BaseModel):
"""Result of checking if the session queue is empty"""
@@ -406,143 +384,61 @@ class IsFullResult(BaseModel):
# region Util
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
"""
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
field_values_json for each session.
Populates the given graph with the given batch data items.
"""
graph_clone = graph.model_copy(deep=True)
for item in node_field_values:
node = graph_clone.get_node(item.node_path)
if node is None:
continue
setattr(node, item.field_name, item.value)
graph_clone.update_node(item.node_path, node)
return graph_clone
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
the field.
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
Each inner list now represents the substitution values for a single permutation (session).
- For each permutation, substitute the values into the graph
This function is optimized for performance, as it is used to generate a large number of sessions at once.
Args:
batch: The batch to generate sessions from
maximum: The maximum number of sessions to generate
Returns:
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
generator will stop early if the maximum number of sessions is reached.
def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
that was applied to the graph.
"""
# TODO: Should this be a class method on Batch?
data: list[list[tuple[dict]]] = []
data: list[list[tuple[NodeFieldValue]]] = []
batch_data_collection = batch.data if batch.data is not None else []
for batch_datum_list in batch_data_collection:
node_field_values_to_zip: list[list[dict]] = []
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
node_field_values_to_zip: list[list[NodeFieldValue]] = []
for batch_datum in batch_datum_list:
node_field_values = [
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
# Zip the dicts together to create a list of dicts for each permutation
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# We serialize the graph and session once, then mutate the graph dict in place for each session.
#
# This sounds scary, but it's actually fine.
#
# The batch prep logic injects field values into the same fields for each generated session.
#
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
# [
# (
# {"node_path": "1", "field_name": "a", "value": 1},
# {"node_path": "2", "field_name": "b", "value": 2},
# {"node_path": "3", "field_name": "c", "value": 3},
# ),
# (
# {"node_path": "1", "field_name": "a", "value": 4},
# {"node_path": "2", "field_name": "b", "value": 5},
# {"node_path": "3", "field_name": "c", "value": 6},
# )
# ]
#
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
# No matter the complexity of the batch, this property holds true.
#
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
#
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
# but this was also slow.
#
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
# objects for each session.
#
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
# dict as the session's graph.
# Dump the batch's graph to a dict once
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
# overwritten for each session by the mutated graph_as_dict.
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
# create generator to yield session,nfv tuples
count = 0
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
# still limited by the maximum number of sessions.
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
# We've reached the maximum number of sessions we may generate
return
# Flatten the list of lists of dicts into a single list of dicts
# TODO(psyche): Is the a more efficient way to do this?
flat_node_field_values = list(chain.from_iterable(d))
# Need a fresh ID for each session
session_id = uuid_string()
# Mutate the session dict in place
session_dict["id"] = session_id
# Substitute the values into the graph
for nfv in flat_node_field_values:
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
# Mutate the session dict in place
session_dict["graph"] = graph_as_dict
# Serialize the session and field values
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
session_json = json.dumps(session_dict, default=to_jsonable_python)
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
# Yield the session_id, session_json, and field_values_json
yield (session_id, session_json, field_values_json)
# Increment the count so we know when to stop
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
Calculates the number of sessions that would be created by the batch, without incurring
the overhead of actually generating them. Adapted from `create_sessions().
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
@@ -558,75 +454,41 @@ def calc_session_count(batch: Batch) -> int:
return len(data_product) * batch.runs
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)
str, # session_id
str, # batch_id
str | None, # field_values (optional, as stringified JSON)
int, # priority
str | None, # workflow (optional, as stringified JSON)
str | None, # origin (optional)
str | None, # destination (optional)
int | None, # retried_from_item_id (optional, this is always None for new items)
]
"""A type alias for the tuple of values to insert into the session queue table."""
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""
# Careful with the ordering of this - it must match the insert statement
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
destination: str | None
def prepare_values_to_insert(
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
) -> list[ValueToInsertTuple]:
"""
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
`executemany` statement to insert multiple rows at once.
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
Args:
queue_id: The ID of the queue to insert the items into
batch: The batch to prepare the values for
priority: The priority of the queue items
max_new_queue_items: The maximum number of queue items to insert
Returns:
A list of tuples to insert into the session queue table. Each tuple contains the following values:
- queue_id
- session (as stringified JSON)
- session_id
- batch_id
- field_values (optional, as stringified JSON)
- priority
- workflow (optional, as stringified JSON)
- origin (optional)
- destination (optional)
- retried_from_item_id (optional, this is always None for new items)
"""
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), the
# this difference becomes noticeable.
#
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
values_to_insert: list[ValueToInsertTuple] = []
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
# not support by default. Apparently there are sets somewhere in the graph.
# The same workflow is used for all sessions in the batch - serialize it once
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
(
queue_id,
session_json,
session_id,
batch.batch_id,
field_values_json,
priority,
workflow_json,
batch.origin,
batch.destination,
None,
SessionQueueValueToInsert(
queue_id, # queue_id
session.model_dump_json(warnings=False, exclude_none=True), # session (json)
session.id, # session_id
batch.batch_id, # batch_id
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
batch.destination, # destination
)
)
return values_to_insert

View File

@@ -1,10 +1,7 @@
import asyncio
import json
import sqlite3
import threading
from typing import Optional, Union, cast
from pydantic_core import to_jsonable_python
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@@ -12,7 +9,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
CancelByQueueIDResult,
@@ -21,7 +17,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
IsEmptyResult,
IsFullResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
@@ -37,6 +32,9 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.RLock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
@@ -52,7 +50,9 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -60,8 +60,8 @@ class SqliteSessionQueue(SessionQueueBase):
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -69,13 +69,14 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -85,12 +86,11 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
return cast(int, self.__cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
self.__cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
@@ -100,14 +100,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
self.__lock.acquire()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -129,17 +127,19 @@ class SqliteSessionQueue(SessionQueueBase):
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
cursor.executemany(
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -151,19 +151,25 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -171,40 +177,52 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -218,8 +236,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
@@ -227,10 +245,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -238,36 +258,48 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -275,8 +307,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
"""--sql
DELETE
FROM session_queue
@@ -284,16 +316,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
@@ -303,7 +336,8 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'canceled'
)
"""
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -311,8 +345,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
DELETE
FROM session_queue
@@ -320,10 +354,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -352,8 +388,8 @@ class SqliteSessionQueue(SessionQueueBase):
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
@@ -364,7 +400,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id] + batch_ids
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -372,8 +408,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -381,18 +417,20 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
@@ -402,7 +440,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = (queue_id, destination)
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -410,8 +448,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -419,18 +457,20 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByDestinationResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
@@ -439,7 +479,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id]
cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -447,8 +487,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = cursor.fetchone()[0]
cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -456,7 +496,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
@@ -464,64 +504,41 @@ class SqliteSessionQueue(SessionQueueBase):
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
@@ -529,10 +546,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items(
@@ -543,71 +562,83 @@ class SqliteSessionQueue(SessionQueueBase):
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
try:
item_id = cursor
self.__lock.acquire()
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
@@ -626,23 +657,29 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return BatchStatus(
batch_id=batch_id,
@@ -658,18 +695,24 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -684,68 +727,3 @@ class SqliteSessionQueue(SessionQueueBase):
canceled=counts.get("canceled", 0),
total=total,
)
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
values_to_insert: list[tuple] = []
retried_item_ids: list[int] = []
for item_id in item_ids:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ("failed", "canceled"):
continue
retried_item_ids.append(item_id)
field_values_json = (
json.dumps(queue_item.field_values, default=to_jsonable_python) if queue_item.field_values else None
)
workflow_json = (
json.dumps(queue_item.workflow, default=to_jsonable_python) if queue_item.workflow else None
)
cloned_session = GraphExecutionState(graph=queue_item.session.graph)
cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)
retried_from_item_id = (
queue_item.retried_from_item_id
if queue_item.retried_from_item_id is not None
else queue_item.item_id
)
value_to_insert = (
queue_item.queue_id,
queue_item.batch_id,
queue_item.destination,
field_values_json,
queue_item.origin,
queue_item.priority,
workflow_json,
cloned_session_json,
cloned_session.id,
retried_from_item_id,
)
values_to_insert.append(value_to_insert)
# TODO(psyche): Handle max queue size?
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,
)
self.__invoker.services.events.emit_queue_items_retried(retry_result)
return retry_result

View File

@@ -51,18 +51,15 @@ class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def __str__(self):
return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
def get_input_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
@@ -96,10 +93,6 @@ def is_list_or_contains_list(t):
return False
def is_any(t: Any) -> bool:
return t == Any or Any in get_args(t)
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
@@ -109,7 +102,13 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
if from_type == to_type or is_any(from_type) or is_any(to_type):
if (
from_type == to_type
or from_type == Any
or to_type == Any
or Any in get_args(from_type)
or Any in get_args(to_type)
):
return True
if from_type in get_args(to_type):
@@ -141,10 +140,10 @@ def are_connections_compatible(
"""Determines if a connection between fields of two nodes is compatible."""
# TODO: handle iterators and collectors
from_type = get_output_field_type(from_node, from_field)
to_type = get_input_field_type(to_node, to_field)
from_node_field = get_output_field(from_node, from_field)
to_node_field = get_input_field(to_node, to_field)
return are_connection_types_compatible(from_type, to_type)
return are_connection_types_compatible(from_node_field, to_node_field)
T = TypeVar("T")
@@ -441,19 +440,17 @@ class Graph(BaseModel):
self.get_node(edge.destination.node_id),
edge.destination.field,
):
raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")
raise InvalidEdgeError(
f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for node in self.nodes.values():
if isinstance(node, IterateInvocation):
err = self._is_iterator_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
if isinstance(node, CollectInvocation):
err = self._is_collector_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")
if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid iterator node {node.id}")
if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid collector node {node.id}")
return None
@@ -480,11 +477,11 @@ class Graph(BaseModel):
def _is_destination_field_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
@@ -494,40 +491,55 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(f"Edge already exists ({edge})")
raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")
raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
raise InvalidEdgeError(f"Field types are incompatible ({edge})")
raise InvalidEdgeError(
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate that we are not connecting collector to iterator (currently unsupported)
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
raise InvalidEdgeError(
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
@@ -536,9 +548,10 @@ class Graph(BaseModel):
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -621,7 +634,7 @@ class Graph(BaseModel):
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
@@ -632,47 +645,29 @@ class Graph(BaseModel):
# Only one input is allowed for iterators
if len(inputs) > 1:
return "Iterator may only have one input edge"
input_node = self.get_node(inputs[0].node_id)
return False
# Get input and output fields (the fields linked to the iterator's input/output)
input_field_type = get_output_field_type(input_node, inputs[0].field)
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field_type) is not list:
return "Iterator input must be a collection"
if get_origin(input_field) is not list:
return False
# Validate that all outputs match the input type
input_field_item_type = get_args(input_field_type)[0]
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
return "Iterator outputs must connect to an input with a matching type"
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"
return None
return True
def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
@@ -682,42 +677,38 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
input_field_types = {
resolved_type
for input_field_type in input_field_types
for resolved_type in (
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
)
if resolved_type != NoneType
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return "Collector input collection items must be of a single type"
return False # There is more than one root type
# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
return "Collector output must connect to a collection input"
if not all(is_list_or_contains_list(f) for f in output_fields):
return False
# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
for f in output_fields
):
return "Collector outputs must connect to a collection input with a matching type"
return False
return None
return True
def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""

View File

@@ -9,7 +9,6 @@ from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.board_records.board_records_common import BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
@@ -17,7 +16,6 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
@@ -104,9 +102,7 @@ class BoardsInterface(InvocationContextInterface):
Returns:
A list of all boards.
"""
return self._services.boards.get_all(
order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
)
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""Adds an image to a board.
@@ -126,11 +122,7 @@ class BoardsInterface(InvocationContextInterface):
Returns:
A list of all image names for the board.
"""
return self._services.board_images.get_all_board_image_names_for_board(
board_id,
categories=None,
is_intermediate=None,
)
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
@@ -291,7 +283,7 @@ class ImagesInterface(InvocationContextInterface):
Returns:
The local path of the image or thumbnail.
"""
return Path(self._services.images.get_path(image_name, thumbnail))
return self._services.images.get_path(image_name, thumbnail)
class TensorsInterface(InvocationContextInterface):

View File

@@ -1,4 +1,5 @@
import sqlite3
import threading
from logging import Logger
from pathlib import Path
@@ -37,20 +38,14 @@ class SqliteDatabase:
self.logger.info(f"Initializing database at {self.db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
@@ -58,14 +53,15 @@ class SqliteDatabase:
# No need to clean in-memory database
if not self.db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise
with self.lock:
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise

View File

@@ -18,7 +18,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import build_migration_16
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -54,7 +53,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.register_migration(build_migration_16())
migrator.run_migrations()
return db

View File

@@ -1,31 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration16Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_retried_from_item_id_col(cursor)
def _add_retried_from_item_id_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `retried_from_item_id` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN retried_from_item_id INTEGER;")
def build_migration_16() -> Migration:
"""
Build the migration from database version 15 to 16.
This migration does the following:
- Adds `retried_from_item_id` column to the session queue table.
"""
migration_16 = Migration(
from_version=15,
to_version=16,
callback=Migration16Callback(),
)
return migration_16

View File

@@ -43,45 +43,46 @@ class SqliteMigrator:
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
with self._db.lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
@@ -107,26 +108,27 @@ class SqliteMigrator:
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
with self._db.lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:

View File

@@ -17,7 +17,9 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,25 +27,31 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -64,16 +72,18 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
cursor.execute(
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -94,15 +104,17 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
@@ -113,7 +125,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
# Change the preset data for a style preset
if changes.preset_data is not None:
cursor.execute(
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
@@ -126,12 +138,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
@@ -142,38 +156,46 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
@@ -183,8 +205,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)

View File

@@ -62,13 +62,9 @@ class WorkflowWithoutID(BaseModel):
notes: str = Field(description="The notes of the workflow.")
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
meta: WorkflowMeta = Field(description="The meta of the workflow.")
# TODO(psyche): nodes, edges and form are very loosely typed - they are strictly modeled and checked on the frontend.
# TODO: nodes and edges are very loosely typed
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
# TODO(psyche): We have a crapload of workflows that have no form, bc it was added after we introduced workflows.
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
# it is None.
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
model_config = ConfigDict(extra="ignore")

View File

@@ -23,7 +23,9 @@ from invokeai.app.util.misc import uuid_string
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -31,36 +33,42 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
try:
# Only user workflows may be created by this method
assert workflow.meta.category is WorkflowCategory.User
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
workflow_id,
@@ -74,12 +82,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
@@ -91,12 +101,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
@@ -107,6 +119,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(
@@ -118,60 +132,66 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
per_page: Optional[int] = None,
query: Optional[str] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
try:
self._lock.acquire()
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
main_query += f" ORDER BY {order_by.value} {direction.value}"
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
self._cursor.execute(main_query, main_params)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
self._cursor.execute(count_query, count_params)
total = self._cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -187,6 +207,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
"""
try:
self._lock.acquire()
workflows: list[Workflow] = []
workflows_dir = Path(__file__).parent / Path("default_workflows")
workflow_paths = workflows_dir.glob("*.json")
@@ -197,15 +218,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
workflows.append(workflow)
# Only default workflows may be managed by this method
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
cursor = self._conn.cursor()
cursor.execute(
self._cursor.execute(
"""--sql
DELETE FROM workflow_library
WHERE category = 'default';
"""
)
for w in workflows:
cursor.execute(
self._cursor.execute(
"""--sql
INSERT OR REPLACE INTO workflow_library (
workflow_id,
@@ -219,3 +239,5 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@@ -1,64 +0,0 @@
import logging
import mimetypes
import socket
import torch
def find_open_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_open_port(port=port + 1)
else:
return port
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def enable_dev_reload() -> None:
"""Enable hot reloading on python file changes during development."""
from invokeai.backend.util.logging import InvokeAILogger
try:
import jurigged
except ImportError as e:
raise RuntimeError(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
def apply_monkeypatches() -> None:
"""Apply monkeypatches to fix issues with third-party libraries."""
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
def register_mime_types() -> None:
"""Register additional mime types for windows."""
# Fix for windows mimetypes registry entries being borked.
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

View File

@@ -1,26 +0,0 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")

View File

@@ -1,52 +0,0 @@
import logging
import os
import sys
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger):
"""Configure the PyTorch CUDA memory allocator. See
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
configurations.
"""
if "torch" in sys.modules:
raise RuntimeError("configure_torch_cuda_allocator() must be called before importing torch.")
# Log a warning if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
if prev_cuda_alloc_conf is not None:
if prev_cuda_alloc_conf == pytorch_cuda_alloc_conf:
logger.info(
f"PYTORCH_CUDA_ALLOC_CONF is already set to '{pytorch_cuda_alloc_conf}'. Skipping configuration."
)
return
else:
logger.warning(
f"Attempted to configure the PyTorch CUDA memory allocator with '{pytorch_cuda_alloc_conf}', but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'. Skipping configuration."
)
return
# Configure the PyTorch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf
import torch
# Relevant docs: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
if not torch.cuda.is_available():
raise RuntimeError(
"Attempted to configure the PyTorch CUDA memory allocator, but no CUDA devices are available."
)
# Verify that the torch allocator was properly configured.
allocator_backend = torch.cuda.get_allocator_backend()
expected_backend = "cudaMallocAsync" if "cudaMallocAsync" in pytorch_cuda_alloc_conf else "native"
if allocator_backend != expected_backend:
raise RuntimeError(
f"Failed to configure the PyTorch CUDA memory allocator. Expected backend: '{expected_backend}', but got "
f"'{allocator_backend}'. Verify that 1) the pytorch_cuda_alloc_conf is set correctly, and 2) that torch is "
"not imported before calling configure_torch_cuda_allocator()."
)
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")

View File

@@ -31,7 +31,7 @@ def denoise(
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
img_cond: torch.Tensor | None,
img_cond: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1

Some files were not shown because too many files have changed in this diff Show More