Compare commits

...

85 Commits

Author SHA1 Message Date
psychedelicious
d06cc71cd9 build: downgrade python to 3.11 in pins 2025-04-03 10:31:56 +10:00
psychedelicious
ecbc1cf85d build: restore prev setuptools config to fix wheel build 2025-04-03 09:43:31 +10:00
psychedelicious
dcad306aee ci: use py3.12 to build installer 2025-04-03 09:30:17 +10:00
psychedelicious
bea9d037bc experiment: add pins.json to repo
The launcher will query this file to get the pins needed for installation
2025-04-03 08:55:45 +10:00
psychedelicious
eed4260975 chore: bump version to v5.10.0dev1
Doing a dev build so I can test the launcher.
2025-04-03 08:55:45 +10:00
psychedelicious
6a79b1c64c chore: update uv.lock for latest pydantic
Ran `uv lock --upgrade-package pydantic`
2025-04-03 08:55:45 +10:00
psychedelicious
4cdfe3e30d fix(ui): handle updated schema structure during invocation parsing
In https://github.com/pydantic/pydantic/pull/10029, pydantic made an improvement to its generated JSON schemas (OpenAPI schemas). The previous and new generated schemas both meet the schema spec.

When we parse the OpenAPI schema to generate node templates, we use some typeguard to narrow schema components from generic OpenAPI schema objects to a node field schema objects. The narrower node field schema objects contain extra data.

For example, they contain a `field_kind` attribute that indicates it the field is an input field or output field. These extra attributes are not part of the OpenAPI spec (but the spec allows does allow for this extra data).

This typeguard relied on a pydantic implementation detail. This was changed in the linked pydantic PR, which released with v2.9.0. With the change, our typeguard rejects input field schema objects, causing parsing to fail with errors/warnings like `Unhandled input property` in the JS console.

In the UI, this causes many fields - mostly model fields - to not show up in the workflow editor.

The fix for this is very simple - instead of relying on an implementation detail for the typeguard, we can check if the incoming schema object has any of our invoke-specific extra attributes. Specifically, we now look for the presence of the `field_kind` attribute on the incoming schema object. If it is present, we know we are dealing with an invocation input field and can parse it appropriately.
2025-04-03 08:55:45 +10:00
psychedelicious
df294db236 chore: typegen 2025-04-03 08:55:45 +10:00
psychedelicious
52e247cfe0 chore: remove pydantic pin 2025-04-03 08:55:45 +10:00
psychedelicious
0a13640bf3 chore(ui): typegen 2025-04-03 08:55:45 +10:00
psychedelicious
643b71f56c tests: update tests/test_object_serializer_disk.py 2025-04-03 08:55:45 +10:00
psychedelicious
b745411866 fix(app): add trusted classes to torch safe globals to prevent errors when loading them
In `ObjectSerializerDisk`, we use `torch.load` to load serialized objects from disk. With torch 2.6.0, torch defaults to `weights_only=True`. As a result, torch will raise when attempting to deserialize anything with an unrecognized class.

For example, our `ConditioningFieldData` class is untrusted. When we load conditioning from disk, we will get a runtime error.

Torch provides a method to add trusted classes to an allowlist. This change adds an arg to `ObjectSerializerDisk` to add a list of safe globals to the allowlist and uses it for both `ObjectSerializerDisk` instances.

Note: My first attempt inferred the class from the generic type arg that `ObjectSerializerDisk` accepts, and added that to the allowlist. Unfortunately, this doesn't work.

For example, `ConditioningFieldData` has a `conditionings` attribute that may be one some other untrusted classes representing model-specific conditioning data. So, even if we allowlist `ConditioningFieldData`, loading will fail when torch deserializes the `conditionings` attribute.
2025-04-03 08:55:45 +10:00
Eugene Brodsky
c3ffb0feed resolve conflict between timm version needed by LLaVA and controlnet-aux 2025-04-03 08:55:45 +10:00
Eugene Brodsky
f53ff5fa3c reintroduce GPU_DRIVER build arg in CI container build, as it has apparently been removed 2025-04-03 08:55:45 +10:00
Eugene Brodsky
b2337b56bd remove obsoleted depenencies that were used by the CLI 2025-04-03 08:55:45 +10:00
Eugene Brodsky
fb777b4502 modify docs for python 3.12 2025-04-03 08:55:45 +10:00
Eugene Brodsky
4c12f5a011 update nodes schema / typegen 2025-04-03 08:55:44 +10:00
Eugene Brodsky
d4655ea21a update uv.lock 2025-04-03 08:55:44 +10:00
Eugene Brodsky
752b62d0b5 refactor Dockerfile; get rid of multi-stage build; upgrade to python 3.12 2025-04-03 08:55:44 +10:00
Eugene Brodsky
9a0efb308d use uv.lock to pin dependencies 2025-04-03 08:55:44 +10:00
Eugene Brodsky
b4c276b50f upgrade pytorch and unpin some of the strict dependency pins to facilitate upgrading co-dependencies.
we will use uv.lock to ensure reproducibility
2025-04-03 08:55:44 +10:00
Riku
db03c196a1 translationBot(ui): update translation (German)
Currently translated at 66.8% (1230 of 1840 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
Riccardo Giovanetti
6bc36b697d translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1818 of 1840 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1816 of 1840 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1816 of 1839 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
Linos
b7d71d3028 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1840 of 1840 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1838 of 1838 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-04-03 07:42:43 +11:00
psychedelicious
fa1ebd9d2f fix(ui): do not switch between images when focused on a tab element
Arrow keys should only navigate between tabs, not gallery images.
2025-04-03 07:40:10 +11:00
psychedelicious
eed5d02069 fix(ui): handling for invalid edges when loading workflows
Previously, reactflow appears to have handled an edge case when using its `applyChanges` utility. If a change was provided without an item, it would skip that change. For example, an "add edge" change that somehow passed `null` as the edge, instead of a valid edge.

In our workflow loading and validation logic, invalid edges were removed from the array using `delete edges[i]`. This left "holes" in the array of edges. We then asked `reactflow` to add these edges to state. When it encountered one of the "holes", it skipped over it.

In a recent release (unsure which, somewhere between the latest v11 and ~v12.4) this seems to have changed. It no longer skips over the "holes" and instead trusts the data. This can cause a couple issues:
- Error when loading the workflow if `reactflow` attempt to do anything with the nonexistent edge.
- If somehow the workflow makes it into state with "holes" in the array of edges, all sorts of other stuff breaks when our code does anything with the nonexistent edge.

Two-part fix:
- Update the invalid edge handling to not use `delete edges[i]`. Instead, as we check each edge, we add invalid ones to a set. Then, after all the checks are finished, filter out the invalid edges. The resultant edges array has no holes.
- Simplify the logic around setting nodes and edges in redux. Previously we were using `reactflow`'s `applyChanges` utils, but this does literally nothing except take extra CPU cycles. We can simply set the loaded nodes and edges directly in redux. Perhaps we were using `applyChanges` because it addressed the "holes" issue? Not sure. But we don't need it now.

Closes #7868
2025-04-03 07:37:49 +11:00
psychedelicious
3650d91045 chore(ui): bump @xyflow/react to latest 2025-04-03 07:37:49 +11:00
Eugene Brodsky
6c7d08cacb Change timm and controlnet-aux pins to fix LLaVA model support (#7846)
## Summary

`timm` below 1.0.0 prevents llava models from working (broken in
transformers). but `controlnet-aux` pins `timm` to an earlier version
because otherwise it was breaking the ZoeDepth controlnet.

we don't use ZoeDepth (replaced by depthAnything), and downgrading
controlnet-aux seems to be acceptable.

more context here:

https://github.com/huggingface/controlnet_aux/issues/106
https://github.com/huggingface/controlnet_aux/pull/101


Note that this results in some warnings on startup, stemming from
controlnet-aux:

![image](https://github.com/user-attachments/assets/fa908837-6154-42a2-a93b-eb5e363f5783)

we can probably silence the warnings as a separate enhancement

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-04-01 21:16:40 -04:00
Eugene Brodsky
bb1c40f222 Merge branch 'main' into pin-timm-for-llava 2025-04-01 21:10:30 -04:00
jazzhaiku
bfb117d0e0 Port LoRA to new classification API (#7849)
## Summary

- Port LoRA to new classification API
- Add 2 additional tests cases (ControlLora and Flux Diffusers LoRA)
- Moved `ModelOnDisk` to its own module

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-04-01 08:05:48 +11:00
jazzhaiku
b31c1022c3 Merge branch 'main' into lora-classification 2025-04-01 07:58:36 +11:00
Mary Hipp
a5851ca31c fix from leftover testing 2025-03-31 12:45:53 -04:00
Mary Hipp
77bf5c15bb GET presigned URLs directly instead of trying to use redirects 2025-03-31 12:45:53 -04:00
Eugene Brodsky
d26b7a1a12 Merge branch 'main' into pin-timm-for-llava 2025-03-31 11:37:29 -04:00
psychedelicious
595133463e feat(nodes): add methods to invalidate invocation typeadapters 2025-03-31 19:15:59 +11:00
psychedelicious
6155f9ff9e feat(nodes): move invocation/output registration to separate class 2025-03-31 19:15:59 +11:00
psychedelicious
7be87c8048 refactor(nodes): simpler logic for baseinvocation typeadapter handling 2025-03-31 19:15:59 +11:00
jazzhaiku
9868c3bfe3 Merge branch 'main' into lora-classification 2025-03-31 16:43:26 +11:00
psychedelicious
8b299d0bac chore: prep for v5.9.1 2025-03-31 13:40:07 +11:00
psychedelicious
a44bfb4658 fix(mm): handle FLUX models w/ diff in_channels keys
Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal".

To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key.

This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models.

Closes #7856
Closes #7859
2025-03-31 12:32:55 +11:00
psychedelicious
96fb5f6881 feat(ui): disable denoising strength when selected models flux fill 2025-03-31 11:31:02 +11:00
psychedelicious
4109ea5324 fix(nodes): expanded masks not 100% transparent outside the fade out region
The polynomial fit isn't perfect and we end up with alpha values of 1 instead of 0 when applying the mask. This in turn causes issues on canvas where outputs aren't 100% transparent and individual layer bbox calculations are incorrect.
2025-03-31 11:17:00 +11:00
jazzhaiku
f6c2ee5040 Merge branch 'main' into lora-classification 2025-03-31 09:01:16 +11:00
Billy
965753bf8b Ruff formatting 2025-03-31 08:18:00 +11:00
Billy
40c53ab95c Guard 2025-03-29 09:58:02 +11:00
psychedelicious
aaa6211625 chore(backend): ruff C420 2025-03-28 18:28:32 -04:00
psychedelicious
f6d770eac9 ci: add python 3.12 to test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
47cb61cd62 ci: remove python 3.10 from test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
b0fdc8ae1c ci: bump linux-cpu test runner to ubuntu 24.04 2025-03-28 18:28:32 -04:00
psychedelicious
ed9b30efda ci: bump uv to 0.6.10 2025-03-28 18:28:32 -04:00
psychedelicious
168e5eeff0 ci: use uv in typegen-checks
ci: use uv in typegen-checks to generate types

experiment: simulate typegen-checks failure

Revert "experiment: simulate typegen-checks failure"

This reverts commit f53c6876fe8311de236d974194abce93ed84930c.
2025-03-28 18:28:32 -04:00
psychedelicious
7acaa86bdf ci: get ci working with uv instead of pip
Lots of squashed experimentation heh:

ci: manually specify python version in tests

ci: whoops typo in ruff cmds

ci: specify python versions for uv python install

ci: install python verbosely

ci: try forcing python preference?

ci: try forcing python preference a different way?

ci: try in a venv?

ci: it works, but try without venv

ci: oh maybe we need --preview?

ci: poking it with a stick

ci: it works, add summary to pytest output

ci: fix pytest output

experiment: simulate test failure

Revert "experiment: simulate test failure"

This reverts commit b99ca512f6e61a2a04a1c0636d44018c11019954.

ci: just use default pytest output

cI: attempt again to use uv to install python

cI: attempt again again to use uv to install python

Revert "cI: attempt again again to use uv to install python"

This reverts commit 3cba861c90738081caeeb3eca97b60656ab63929.

Revert "cI: attempt again to use uv to install python"

This reverts commit b30f2277041dc999ed514f6c594c6d6a78f5c810.
2025-03-28 18:28:32 -04:00
psychedelicious
96c0393fe7 ci: bump ruff to 0.11.2
Need to bump both CI and pyproject.toml at the same time
2025-03-28 18:28:32 -04:00
psychedelicious
403f795c5e ci: remove linux-cuda-11_7 & linux-rocm-5_2 from test matrix
We only have CPU runners, so these tests are not doing anything useful.
2025-03-28 18:28:32 -04:00
psychedelicious
c0f88a083e ci: use uv for python-tests 2025-03-28 18:28:32 -04:00
psychedelicious
542b182899 ci: use uv for python-checks 2025-03-28 18:28:32 -04:00
Mary Hipp
3f58c68c09 fix tag invalidation 2025-03-28 10:52:27 -04:00
Mary Hipp
e50c7e5947 restore multiple key 2025-03-28 10:52:27 -04:00
Mary Hipp
4a83700fe4 if clientSideUploading is enabled, handle bulk uploads using that flow 2025-03-28 10:52:27 -04:00
Eugene Brodsky
c9992914d6 Merge branch 'main' into pin-timm-for-llava 2025-03-28 09:20:30 -04:00
jazzhaiku
c25f6d1f84 Merge branch 'main' into lora-classification 2025-03-28 12:32:22 +11:00
jazzhaiku
a53e1ccf08 Small improvements (#7842)
## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-28 12:21:41 +11:00
jazzhaiku
1af9930951 Merge branch 'main' into small-improvements 2025-03-28 12:11:09 +11:00
Billy
c276c1cbee Comment 2025-03-28 10:57:46 +11:00
Billy
c619348f29 Extract ModelOnDisk to its own module 2025-03-28 10:35:13 +11:00
psychedelicious
c6f96613fc chore(ui): typegen 2025-03-28 08:14:06 +11:00
psychedelicious
258bf736da fix(nodes): handle zero fade size (e.g. mask blur 0)
Closes #7850
2025-03-28 08:14:06 +11:00
Billy
0d75c99476 Caching 2025-03-27 17:55:09 +11:00
Billy
323d409fb6 Make ruff happy 2025-03-27 17:47:57 +11:00
Billy
f251722f56 LoRA classification API 2025-03-27 17:47:01 +11:00
jazzhaiku
c9dc27afbb Merge branch 'main' into small-improvements 2025-03-27 08:14:48 +11:00
Billy
efd14ec0e4 Make ruff happy 2025-03-27 08:11:39 +11:00
Billy
21ee2b6251 Merge branch 'small-improvements' of github.com:invoke-ai/InvokeAI into small-improvements 2025-03-27 08:10:38 +11:00
Billy
82dd2d508f Deprecate checkpoint as file, diffusers as directory terminology 2025-03-27 08:10:12 +11:00
Eugene Brodsky
3f12a43e75 remove pin for controlnet-aux and pin timm to a version that works with llava
timm < 1.0.0 prevents llava models from working (broken in transformers). but controlnet-aux pinned it to an earlier version because otherwise it was breaking the ZoeDepth controlnet.

we don't use ZoeDepth (replaced by depthAnything), and downgrading controlnet-aux seems to be acceptable.

more context here:

https://github.com/huggingface/controlnet_aux/issues/106
https://github.com/huggingface/controlnet_aux/pull/101
2025-03-26 16:58:18 -04:00
jazzhaiku
5a59f6e3b8 Merge branch 'main' into small-improvements 2025-03-27 07:38:13 +11:00
Billy
60b5aef16a Log error -> warning 2025-03-27 06:56:22 +11:00
Billy
0e8b5484d5 Error handling 2025-03-26 19:31:57 +11:00
Billy
454506c83e Type hints 2025-03-26 19:12:49 +11:00
Billy
8f6ab67376 Logs 2025-03-26 16:34:32 +11:00
Billy
5afcc7778f Redundant 2025-03-26 16:32:19 +11:00
Billy
325e07d330 Error handling 2025-03-26 16:30:45 +11:00
Billy
a016bdc159 Add todo 2025-03-26 16:17:26 +11:00
Billy
a14f0b2864 Fail early on invalid config 2025-03-26 16:10:32 +11:00
Billy
721483318a Extend ModelOnDisk 2025-03-26 16:10:00 +11:00
58 changed files with 5844 additions and 975 deletions

View File

@@ -1,9 +1,11 @@
*
!invokeai
!pyproject.toml
!uv.lock
!docker/docker-entrypoint.sh
!LICENSE
**/dist
**/node_modules
**/__pycache__
**/*.egg-info
**/*.egg-info

View File

@@ -97,6 +97,8 @@ jobs:
context: .
file: docker/Dockerfile
platforms: ${{ env.PLATFORMS }}
build-args: |
GPU_DRIVER=${{ matrix.gpu-driver }}
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'
cache: pip
cache-dependency-path: pyproject.toml

View File

@@ -34,6 +34,9 @@ on:
jobs:
python-checks:
env:
# uv requires a venv by default - but for this, we can simply use the system python
UV_SYSTEM_PYTHON: 1
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
@@ -57,25 +60,19 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
uses: astral-sh/setup-uv@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.9.9
shell: bash
version: '0.6.10'
enable-cache: true
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff check --output-format=github .
run: uv tool run ruff@0.11.2 check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff format --check .
run: uv tool run ruff@0.11.2 format --check .
shell: bash

View File

@@ -39,24 +39,15 @@ jobs:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
- '3.12'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
os: ubuntu-24.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
@@ -70,6 +61,8 @@ jobs:
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
UV_SYSTEM_PYTHON: 1
steps:
- name: checkout
# https://github.com/nschloe/action-cached-lfs-checkout
@@ -92,20 +85,25 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable ".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}

View File

@@ -54,17 +54,25 @@ jobs:
- 'pyproject.toml'
- 'invokeai/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: '3.11'
- name: setup python
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
python-version: '3.11'
- name: install python dependencies
- name: install dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: pip3 install --use-pep517 --editable="."
env:
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable .
- name: install frontend dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
@@ -77,7 +85,7 @@ jobs:
- name: generate schema
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: make frontend-typegen
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
shell: bash
- name: compare files

2
.nvmrc
View File

@@ -1 +1 @@
v22.12.0
v22.14.0

View File

@@ -1,77 +1,6 @@
# syntax=docker/dockerfile:1.4
## Builder stage
FROM library/ubuntu:24.04 AS builder
ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt-get install -y \
build-essential \
git
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
ARG GPU_DRIVER=cuda
# unused but available
ARG BUILDPLATFORM
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
mkdir -p ${INVOKEAI_SRC} && \
chmod -R a+w /opt && \
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
USER ubuntu
# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
uv python install ${PYTHON_VERSION}
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync --no-install-project
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync
#### Build the Web UI ------------------------------------
#### Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
@@ -85,69 +14,89 @@ RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
#### Runtime stage ---------------------------------------
## Backend ---------------------------------------
FROM library/ubuntu:24.04 AS runtime
FROM library/ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt update && apt install -y --no-install-recommends \
ca-certificates \
git \
gosu \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev
RUN apt update && apt install -y --no-install-recommends \
git \
curl \
vim \
tmux \
ncdu \
iotop \
bzip2 \
gosu \
magic-wormhole \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev &&\
apt-get clean && apt-get autoclean
ENV \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
VIRTUAL_ENV=/opt/venv \
INVOKEAI_SRC=/opt/invokeai \
PYTHON_VERSION=3.12 \
UV_PYTHON=3.12 \
UV_COMPILE_BYTECODE=1 \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
UV_INDEX="https://download.pytorch.org/whl/cu124" \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
PATH="/opt/venv/bin:$PATH" \
CONTAINER_UID=${CONTAINER_UID:-1000} \
CONTAINER_GID=${CONTAINER_GID:-1000}
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
ENV INVOKEAI_PORT=9090
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda
# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install ${PYTHON_VERSION} && \
# chmod --recursive a+rX /root/.local/share/uv/python
chmod 711 /root
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv sync --frozen
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.

View File

@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
3. Create a virtual environment in that directory:
```sh
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
```
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
4. Activate the virtual environment:
@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
8. Install the `invokeai` package. Substitute the package specifier and version.
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --
=== "Windows"
- Install python 3.11 with [an official installer].
- Install python with [an official installer].
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
- You may need to install [Microsoft Visual C++ Redistributable].
=== "macOS"
- Install python 3.11 with [an official installer].
- Install python with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
=== "Linux"
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
- Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
## Drivers

View File

@@ -37,7 +37,13 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -101,10 +107,25 @@ class ApiDependencies:
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = ObjectSerializerForwardCache(
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
ObjectSerializerDisk[torch.Tensor](
output_folder / "tensors",
safe_globals=[torch.Tensor],
ephemeral=True,
),
max_cache_size=0,
)
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
ObjectSerializerDisk[ConditioningFieldData](
output_folder / "conditioning",
safe_globals=[
ConditioningFieldData,
BasicConditioningInfo,
SDXLConditioningInfo,
FLUXConditioningInfo,
SD3ConditioningInfo,
],
ephemeral=True,
),
)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")

View File

@@ -96,6 +96,22 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
class ImageUploadEntry(BaseModel):
image_dto: ImageDTO = Body(description="The image DTO")
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
@images_router.post("/", operation_id="create_image_upload_entry")
async def create_image_upload_entry(
width: int = Body(description="The width of the image"),
height: int = Body(description="The height of the image"),
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
) -> ImageUploadEntry:
"""Uploads an image from a URL, not implemented"""
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),

View File

@@ -8,6 +8,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -27,7 +28,6 @@ import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
All invocations must use the `@invocation` decorator to provide their unique type.
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def invalidate_typeadapter(cls) -> None:
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
the updated allowlist and denylist."""
cls._typeadapter_needs_update = True
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
class InvocationRegistry:
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This is used to parse serialized invocations into the correct invocation class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
@classmethod
def invalidate_invocation_typeadapter(cls) -> None:
"""Invalidates the cached invocation type adapter."""
cls.get_invocation_typeadapter.cache_clear()
@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[type[BaseInvocation]] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in cls.get_invocation_classes()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in cls.get_invocation_classes())
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This is used to parse serialized invocation outputs into the correct invocation output class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def invalidate_output_typeadapter(cls) -> None:
"""Invalidates the cached invocation output type adapter."""
cls.get_output_typeadapter.cache_clear()
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
@@ -453,8 +461,8 @@ def invocation(
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
@@ -539,8 +547,7 @@ def invocation(
)
cls.__doc__ = docstring
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
BaseInvocation.register_invocation(cls) # type: ignore
InvocationRegistry.register_invocation(cls)
return cls
@@ -565,7 +572,7 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in BaseInvocationOutput.get_output_types():
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
@@ -586,7 +593,7 @@ def invocation_output(
)
cls.__doc__ = docstring
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
InvocationRegistry.register_output(cls)
return cls

View File

@@ -1089,12 +1089,13 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
If the fade size is 0, the mask is returned as-is.
"""
mask: ImageField = InputField(description="The mask to expand")
@@ -1104,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.fade_size_px == 0:
# If the fade size is 0, just return the mask as-is.
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white
@@ -1141,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
coeffs = numpy.polyfit(x_control, y_control, 3)
poly = numpy.poly1d(coeffs)
# Evaluate and clip the smooth mapping
feather = numpy.clip(poly(d_norm), 0, 1)
# Evaluate the polynomial
feather = poly(d_norm)
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
# polynomial fit, which is a best approximation of the control points but not an exact match.
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
# Force pixels at or beyond the fade distance to exactly 1.0
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
# Clip any other values to ensure they're in the valid range [0,1]
feather = numpy.clip(feather, 0, 1)
# Build final image.
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))

View File

@@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
:param output_dir: The folder where the serialized objects will be stored
:param safe_globals: A list of types to be added to the safe globals for torch serialization
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
"""
def __init__(self, output_dir: Path, ephemeral: bool = False):
def __init__(
self,
output_dir: Path,
safe_globals: list[type],
ephemeral: bool = False,
) -> None:
super().__init__()
self._ephemeral = ephemeral
self._base_output_dir = output_dir
@@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
self.__obj_class_name: Optional[str] = None
torch.serialization.add_safe_globals(safe_globals) if safe_globals else None
def load(self, name: str) -> T:
file_path = self._get_path(name)
try:

View File

@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationRegistry,
invocation,
invocation_output,
)
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return InvocationRegistry.get_output_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}

View File

@@ -4,7 +4,10 @@ from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.baseinvocation import (
InvocationRegistry,
UIConfigBase,
)
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
@@ -56,14 +59,14 @@ def get_openapi_func(
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
for invocation in InvocationRegistry.get_invocation_classes():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from invokeai.backend.model_manager.legacy_probe import CkptType
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
"""Gets the in channels from the state dict."""
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
for key in keys:
val = state_dict.get(key)
if val is not None:
return val.shape[1]
return None

View File

@@ -30,19 +30,18 @@ from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
@@ -51,9 +50,8 @@ from invokeai.backend.model_manager.taxonomy import (
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings
logger = logging.getLogger(__name__)
@@ -97,68 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
return state_dict
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
@@ -233,16 +169,18 @@ class ModelConfigBase(ABC, BaseModel):
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@@ -282,12 +220,12 @@ class ModelConfigBase(ABC, BaseModel):
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
@@ -344,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value
@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
@@ -359,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if type(key) is int:
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -387,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
@@ -563,7 +577,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.path.is_file():
return False
config_path = mod.path / "config.json"

View File

@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
@@ -564,7 +565,14 @@ class CheckpointProbeBase(ProbeBase):
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if base_type == BaseModelType.Flux:
in_channels = state_dict["img_in.weight"].shape[1]
in_channels = get_flux_in_channels_from_state_dict(state_dict)
if in_channels is None:
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
logger.warning(
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
)
return ModelVariantType.Normal
# FLUX Model variant types are distinguished by input channels:
# - Unquantized Dev and Schnell have in_channels=64

View File

@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Any, Optional, TypeAlias
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self) -> int:
if self.path.is_file():
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]
if path in sd_cache:
return sd_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict

View File

@@ -126,4 +126,13 @@ class ModelSourceType(str, Enum):
HFRepoID = "hf_repo_id"
class FluxLoRAFormat(str, Enum):
"""Flux LoRA formats."""
Diffusers = "flux.diffusers"
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -0,0 +1,24 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
def flux_format_from_state_dict(state_dict):
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
return FluxLoRAFormat.OneTrainer
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
else:
return None

View File

@@ -69,6 +69,9 @@ class SD3ConditioningInfo:
@dataclass
class ConditioningFieldData:
# If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
# invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be
# unable to deserialize the object and will raise an error.
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]

View File

@@ -62,7 +62,7 @@
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.4.2",
"@xyflow/react": "^12.5.3",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
@@ -162,5 +162,6 @@
},
"engines": {
"pnpm": "8"
}
},
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}

View File

@@ -36,8 +36,8 @@ dependencies:
specifier: ^1.3.0
version: 1.3.0
'@xyflow/react':
specifier: ^12.4.2
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
specifier: ^12.5.3
version: 12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
async-mutex:
specifier: ^0.5.0
version: 0.5.0
@@ -3323,7 +3323,7 @@ packages:
/@types/d3-drag@3.0.7:
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-interpolate@3.0.4:
@@ -3332,21 +3332,21 @@ packages:
'@types/d3-color': 3.1.3
dev: false
/@types/d3-selection@3.0.10:
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
/@types/d3-selection@3.0.11:
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
dev: false
/@types/d3-transition@3.0.8:
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
/@types/d3-transition@3.0.9:
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-zoom@3.0.8:
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
dependencies:
'@types/d3-interpolate': 3.0.4
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/diff-match-patch@1.0.36:
@@ -3951,28 +3951,28 @@ packages:
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
dev: false
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
/@xyflow/react@12.5.3(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-saovy/aQRoW8qQoIqMFUtmC3F6oEV7n6+J1pVbhSG45NI/hOFvK0qozsIPKqX5Va6lGQnkl/o53NHLja3NiweQ==}
peerDependencies:
react: '>=17'
react-dom: '>=17'
dependencies:
'@xyflow/system': 0.0.50
'@xyflow/system': 0.0.53
classcat: 5.0.5
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
transitivePeerDependencies:
- '@types/react'
- immer
dev: false
/@xyflow/system@0.0.50:
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
/@xyflow/system@0.0.53:
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
dependencies:
'@types/d3-drag': 3.0.7
'@types/d3-selection': 3.0.10
'@types/d3-transition': 3.0.8
'@types/d3-selection': 3.0.11
'@types/d3-transition': 3.0.9
'@types/d3-zoom': 3.0.8
d3-drag: 3.0.0
d3-selection: 3.0.0
@@ -9123,6 +9123,14 @@ packages:
react: 18.3.1
dev: false
/use-sync-external-store@1.5.0(react@18.3.1):
resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
dependencies:
react: 18.3.1
dev: false
/util-deprecate@1.0.2:
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
dev: true
@@ -9567,8 +9575,8 @@ packages:
/zod@3.23.8:
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
engines: {node: '>=12.7.0'}
peerDependencies:
'@types/react': '>=16.8'
@@ -9584,5 +9592,5 @@ packages:
dependencies:
'@types/react': 18.3.11
react: 18.3.1
use-sync-external-store: 1.2.2(react@18.3.1)
use-sync-external-store: 1.5.0(react@18.3.1)
dev: false

View File

@@ -116,7 +116,10 @@
"combinatorial": "Kombinatorisch",
"saveChanges": "Änderungen speichern",
"error_withCount_one": "{{count}} Fehler",
"error_withCount_other": "{{count}} Fehler"
"error_withCount_other": "{{count}} Fehler",
"value": "Wert",
"label": "Label",
"systemInformation": "Systeminformationen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -695,7 +698,10 @@
"guidance": "Führung",
"coherenceMode": "Modus",
"recallMetadata": "Metadaten abrufen",
"gaussianBlur": "Gaußsche Unschärfe"
"gaussianBlur": "Gaußsche Unschärfe",
"sendToUpscale": "An Hochskalieren senden",
"useCpuNoise": "CPU-Rauschen verwenden",
"sendToCanvas": "An Leinwand senden"
},
"settings": {
"displayInProgress": "Zwischenbilder anzeigen",
@@ -1328,7 +1334,8 @@
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
"loadingTemplates": "Lade {{name}}",
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten",
"showEdgeLabelsHelp": "Beschriftungen an Kanten anzeigen, um die verknüpften Knoten zu kennzeichnen"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -115,7 +115,8 @@
"error_withCount_many": "{{count}} errori",
"error_withCount_other": "{{count}} errori",
"value": "Valore",
"label": "Etichetta"
"label": "Etichetta",
"systemInformation": "Informazioni di sistema"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -715,7 +716,8 @@
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
"collectionEmpty": "raccolta vuota"
"collectionEmpty": "raccolta vuota",
"batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -2365,8 +2367,9 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
"Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
"FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela.",
"LLaVA OneVision VLLM: supporto beta nei flussi di lavoro."
]
},
"system": {

View File

@@ -237,7 +237,10 @@
"row": "Hàng",
"board": "Bảng",
"saveChanges": "Lưu Thay Đổi",
"error_withCount_other": "{{count}} lỗi"
"error_withCount_other": "{{count}} lỗi",
"value": "Giá Trị",
"label": "Nhãn Tên",
"systemInformation": "Thông Tin Hệ Thống"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -2300,7 +2303,10 @@
"minimum": "Tối Thiểu",
"maximum": "Tối Đa",
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
"containerColumnLayout": "Hộp Chứa (bố cục cột)",
"resetOptions": "Tải Lại Lựa Chọn",
"addOption": "Thêm Lựa Chọn",
"dropdown": "Danh Sách Thả Xuống"
},
"yourWorkflows": "Workflow Của Bạn",
"browseWorkflows": "Khám Phá Workflow",
@@ -2316,7 +2322,8 @@
"view": "Xem",
"deselectAll": "Huỷ Chọn Tất Cả",
"noRecentWorkflows": "Không Có Workflows Gần Đây",
"recommended": "Có Thể Bạn Sẽ Cần"
"recommended": "Có Thể Bạn Sẽ Cần",
"emptyStringPlaceholder": "<xâu ký tự trống>"
},
"upscaling": {
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
@@ -2352,8 +2359,9 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
"Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
"FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas.",
"LLaVA OneVision VLLM: Hỗ trợ phiên bản Beta trong Workflow."
]
},
"upsell": {

View File

@@ -1,6 +1,8 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
@@ -8,7 +10,8 @@ import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
const log = logger('gallery');
/**
@@ -34,19 +37,56 @@ let lastUploadedToastTimeout: number | null = null;
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
matcher: isAnyOf(imagesApi.endpoints.uploadImage.matchFulfilled, imageUploadedClientSide),
effect: (action, { dispatch, getState }) => {
const imageDTO = action.payload;
let imageDTO: ImageDTO;
let silent;
let isFirstUploadOfBatch = true;
if (imageUploadedClientSide.match(action)) {
imageDTO = action.payload.imageDTO;
silent = action.payload.silent;
isFirstUploadOfBatch = action.payload.isFirstUploadOfBatch;
} else if (imagesApi.endpoints.uploadImage.matchFulfilled(action)) {
imageDTO = action.payload;
silent = action.meta.arg.originalArgs.silent;
isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
} else {
return;
}
if (silent || imageDTO.is_intermediate) {
// If the image is silent or intermediate, we don't want to show a toast
return;
}
if (imageUploadedClientSide.match(action)) {
const categories = getCategories(imageDTO);
const boardId = imageDTO.board_id ?? 'none';
dispatch(
imagesApi.util.invalidateTags([
{
type: 'ImageList',
id: getListImagesUrl({
board_id: boardId,
categories,
}),
},
{
type: 'Board',
id: boardId,
},
{
type: 'BoardImagesTotal',
id: boardId,
},
])
);
}
const state = getState();
log.debug({ imageDTO }, 'Image uploaded');
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}
const boardId = imageDTO.board_id ?? 'none';
const DEFAULT_UPLOADED_TOAST = {
@@ -80,7 +120,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));

View File

@@ -73,6 +73,7 @@ export type AppConfig = {
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
@@ -81,7 +82,6 @@ export type AppConfig = {
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
maxImageUploadCount?: number;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

View File

@@ -0,0 +1,121 @@
import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { useCallback } from 'react';
import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
type PresignedUrlResponse = {
fullUrl: string;
thumbnailUrl: string;
};
const isPresignedUrlResponse = (response: unknown): response is PresignedUrlResponse => {
return typeof response === 'object' && response !== null && 'fullUrl' in response && 'thumbnailUrl' in response;
};
export const useClientSideUpload = () => {
const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const authToken = useStore($authToken);
const [createImageUploadEntry] = useCreateImageUploadEntryMutation();
const clientSideUpload = useCallback(
async (file: File, i: number): Promise<ImageDTO> => {
const image = new Image();
const objectURL = URL.createObjectURL(file);
image.src = objectURL;
let width = 0;
let height = 0;
let thumbnail: Blob | undefined;
await new Promise<void>((resolve) => {
image.onload = () => {
width = image.naturalWidth;
height = image.naturalHeight;
// Calculate thumbnail dimensions maintaining aspect ratio
let thumbWidth = width;
let thumbHeight = height;
if (width > height && width > 256) {
thumbWidth = 256;
thumbHeight = Math.round((height * 256) / width);
} else if (height > 256) {
thumbHeight = 256;
thumbWidth = Math.round((width * 256) / height);
}
const canvas = document.createElement('canvas');
canvas.width = thumbWidth;
canvas.height = thumbHeight;
const ctx = canvas.getContext('2d');
ctx?.drawImage(image, 0, 0, thumbWidth, thumbHeight);
canvas.toBlob(
(blob) => {
if (blob) {
thumbnail = blob;
// Clean up resources
URL.revokeObjectURL(objectURL);
image.src = ''; // Clear image source
image.remove(); // Remove the image element
canvas.width = 0; // Clear canvas
canvas.height = 0;
resolve();
}
},
'image/webp',
0.8
);
};
// Handle load errors
image.onerror = () => {
URL.revokeObjectURL(objectURL);
image.remove();
resolve();
};
});
const { presigned_url, image_dto } = await createImageUploadEntry({
width,
height,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
const response = await fetch(presigned_url, {
method: 'GET',
...(authToken && {
headers: {
Authorization: `Bearer ${authToken}`,
},
}),
}).then((res) => res.json());
if (!isPresignedUrlResponse(response)) {
throw new Error('Invalid response');
}
const fullUrl = response.fullUrl;
const thumbnailUrl = response.thumbnailUrl;
await fetch(fullUrl, {
method: 'PUT',
body: file,
});
await fetch(thumbnailUrl, {
method: 'PUT',
body: thumbnail,
});
dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 }));
return image_dto;
},
[autoAddBoardId, authToken, createImageUploadEntry, dispatch]
);
return clientSideUpload;
};

View File

@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
@@ -15,6 +15,7 @@ import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
import { useClientSideUpload } from './useClientSideUpload';
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
@@ -50,51 +51,65 @@ const log = logger('gallery');
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const clientSideUpload = useClientSideUpload();
const { t } = useTranslation();
const onDropAccepted = useCallback(
async (files: File[]) => {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
const imageDTOs = await uploadImages(
files.map((file, i) => ({
try {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {
onUpload(imageDTOs);
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
} else {
imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
}
if (onUpload) {
onUpload(imageDTOs);
}
}
} catch (error) {
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
status: 'error',
});
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
);
const onDropRejected = useCallback(
@@ -105,10 +120,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -120,7 +132,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
return;
}
},
[maxImageUploadCount, t]
[t]
);
const {
@@ -137,8 +149,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onDropRejected,
disabled: isDisabled,
noDrag: true,
multiple: allowMultiple && (maxImageUploadCount === undefined || maxImageUploadCount > 1),
maxFiles: maxImageUploadCount,
multiple: allowMultiple,
});
return { getUploadButtonProps, getUploadInputProps, openUploader, request };

View File

@@ -14,8 +14,9 @@ import WavyLine from 'common/components/WavyLine';
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
@@ -26,6 +27,7 @@ export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
const selectedModelConfig = useSelectedModelConfig();
const onChange = useCallback(
(v: number) => {
@@ -39,8 +41,24 @@ export const ParamDenoisingStrength = memo(() => {
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
const isDisabled = useMemo(() => {
if (!hasRasterLayersWithContent) {
// Denoising strength does nothing if there are no raster layers w/ content
return true;
}
if (
selectedModelConfig?.type === 'main' &&
selectedModelConfig?.base === 'flux' &&
selectedModelConfig.variant === 'inpaint'
) {
// Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
return true;
}
return false;
}, [hasRasterLayersWithContent, selectedModelConfig]);
return (
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
<FormControl isDisabled={isDisabled} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
@@ -49,7 +67,7 @@ export const ParamDenoisingStrength = memo(() => {
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{hasRasterLayersWithContent ? (
{!isDisabled ? (
<>
<CompositeSlider
step={config.coarseStep}

View File

@@ -8,12 +8,13 @@ import { useStore } from '@nanostores/react';
import { getStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { $focusedRegion } from 'common/hooks/focus';
import { useClientSideUpload } from 'common/hooks/useClientSideUpload';
import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
@@ -53,13 +54,6 @@ const zUploadFile = z
(file) => ({ message: `File extension .${file.name.split('.').at(-1)} is not supported` })
);
const getFilesSchema = (max?: number) => {
if (max === undefined) {
return z.array(zUploadFile);
}
return z.array(zUploadFile).max(max);
};
const sx = {
position: 'absolute',
top: 2,
@@ -74,22 +68,19 @@ const sx = {
export const FullscreenDropzone = memo(() => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const activeTab = useAppSelector(selectActiveTab);
const isImageViewerOpen = useStore($imageViewer);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const clientSideUpload = useClientSideUpload();
const validateAndUploadFiles = useCallback(
(files: File[]) => {
async (files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
const parseResult = z.array(zUploadFile).safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -118,17 +109,23 @@ export const FullscreenDropzone = memo(() => {
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
if (isClientSideUploadEnabled) {
for (const [i, file] of files.entries()) {
await clientSideUpload(file, i);
}
} else {
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
uploadImages(uploadArgs);
uploadImages(uploadArgs);
}
},
[activeTab, isImageViewerOpen, maxImageUploadCount, t]
[activeTab, isImageViewerOpen, t, isClientSideUploadEnabled, clientSideUpload]
);
const onPaste = useCallback(

View File

@@ -1,31 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { t } from 'i18next';
import { useMemo } from 'react';
import { PiUploadBold } from 'react-icons/pi';
export const GalleryUploadButton = () => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const uploadOptions = useMemo(() => ({ allowMultiple: maxImageUploadCount !== 1 }), [maxImageUploadCount]);
const uploadApi = useImageUploadButton(uploadOptions);
const uploadApi = useImageUploadButton({ allowMultiple: true });
return (
<>
<IconButton
size="sm"
alignSelf="stretch"
variant="link"
aria-label={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
tooltip={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
aria-label={t('accessibility.uploadImages')}
tooltip={t('accessibility.uploadImages')}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>

View File

@@ -49,7 +49,11 @@ export const useGalleryHotkeys = () => {
useRegisteredHotkeys({
id: 'galleryNavLeft',
category: 'gallery',
callback: () => {
callback: (e) => {
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
return;
}
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
goPrev('arrow');
return;
@@ -71,7 +75,11 @@ export const useGalleryHotkeys = () => {
useRegisteredHotkeys({
id: 'galleryNavRight',
category: 'gallery',
callback: () => {
callback: (e) => {
// Skip the hotkey if the user is focused on a tab element - the arrow keys are used to navigate between tabs.
if (e.target instanceof HTMLElement && e.target.getAttribute('role') === 'tab') {
return;
}
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
goNext('arrow');
return;

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { ImageDTO } from 'services/api/types';
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
@@ -7,3 +8,9 @@ export const imageDownloaded = createAction('gallery/imageDownloaded');
export const imageCopiedToClipboard = createAction('gallery/imageCopiedToClipboard');
export const imageOpenedInNewTab = createAction('gallery/imageOpenedInNewTab');
export const imageUploadedClientSide = createAction<{
imageDTO: ImageDTO;
silent: boolean;
isFirstUploadOfBatch: boolean;
}>('gallery/imageUploadedClientSide');

View File

@@ -470,31 +470,8 @@ export const nodesSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges } = action.payload;
const changes: NodeChange<AnyNode>[] = [];
for (const node of nodes) {
if (node.type === 'notes') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
} else if (node.type === 'invocation') {
changes.push({
type: 'add',
item: {
...SHARED_NODE_PROPERTIES,
...node,
},
});
}
}
state.nodes = applyNodeChanges<AnyNode>(changes, []);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ type: 'add', item: edge })),
[]
);
state.nodes = nodes.map((node) => ({ ...SHARED_NODE_PROPERTIES, ...node }));
state.edges = edges;
});
},
});

View File

@@ -79,4 +79,4 @@ export const isInvocationOutputSchemaObject = (
export const isInvocationFieldSchema = (
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
): obj is InvocationFieldSchema => !('$ref' in obj);
): obj is InvocationFieldSchema => 'field_kind' in obj;

View File

@@ -148,7 +148,11 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
}
}
edges.forEach((edge, i) => {
// Stash invalid edges here to be deleted later
const edgesToDelete = new Set<string>();
for (const edge of edges) {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
const sourceNode = nodes.find(({ id }) => id === edge.source);
const targetNode = nodes.find(({ id }) => id === edge.target);
@@ -215,8 +219,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
if (issues.length) {
// This edge has some issues. Remove it.
delete edges[i];
edgesToDelete.add(edge.id);
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
warnings.push({
@@ -225,7 +228,10 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
data: edge,
});
}
});
}
// Remove invalid edges
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
// Migrated exposed fields to form elements if they exist and the form does not
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!

View File

@@ -20,6 +20,7 @@ const initialConfigState: AppConfig = {
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
@@ -218,6 +219,5 @@ export const selectWorkflowFetchDebounce = createConfigSelector((config) => conf
export const selectMetadataFetchDebounce = createConfigSelector((config) => config.metadataFetchDebounce ?? 300);
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectMaxImageUploadCount = createConfigSelector((config) => config.maxImageUploadCount);
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);

View File

@@ -7,6 +7,8 @@ import type {
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageDTO,
ImageUploadEntryRequest,
ImageUploadEntryResponse,
ListImagesArgs,
ListImagesResponse,
UploadImageArg,
@@ -287,6 +289,7 @@ export const imagesApi = api.injectEndpoints({
},
};
},
invalidatesTags: (result) => {
if (!result || result.is_intermediate) {
// Don't add it to anything
@@ -314,7 +317,13 @@ export const imagesApi = api.injectEndpoints({
];
},
}),
createImageUploadEntry: build.mutation<ImageUploadEntryResponse, ImageUploadEntryRequest>({
query: ({ width, height, board_id }) => ({
url: buildImagesUrl(),
method: 'POST',
body: { width, height, board_id },
}),
}),
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
invalidatesTags: () => [
@@ -549,6 +558,7 @@ export const {
useGetImageWorkflowQuery,
useLazyGetImageWorkflowQuery,
useUploadImageMutation,
useCreateImageUploadEntryMutation,
useClearIntermediatesMutation,
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,

File diff suppressed because it is too large Load Diff

View File

@@ -354,3 +354,6 @@ export type UploadImageArg = {
*/
isFirstUploadOfBatch?: boolean;
};
export type ImageUploadEntryResponse = S['ImageUploadEntry'];
export type ImageUploadEntryRequest = paths['/api/v1/images/']['post']['requestBody']['content']['application/json'];

View File

@@ -1 +1 @@
__version__ = "5.9.0"
__version__ = "5.10.0dev1"

14
pins.json Normal file
View File

@@ -0,0 +1,14 @@
{
"python": "3.11",
"torchIndexUrl": {
"win32": {
"cuda": "https://download.pytorch.org/whl/cu126"
},
"linux": {
"cpu": "https://download.pytorch.org/whl/cpu",
"rocm": "https://download.pytorch.org/whl/rocm6.2.4",
"cuda": "https://download.pytorch.org/whl/cu126"
},
"darwin": {}
}
}

View File

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.10, <3.12"
requires-python = ">=3.10, <3.13"
readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"]
dynamic = ["version"]
@@ -33,12 +33,12 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==1.0.1",
"bitsandbytes==0.45.0; sys_platform!='darwin'",
"accelerate",
"bitsandbytes; sys_platform!='darwin'",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.31.0",
"diffusers[torch]",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.14", # needed for "mediapipeface" controlnet model
@@ -46,26 +46,27 @@ dependencies = [
"onnx==1.16.1",
"onnxruntime==1.19.2",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",
"pytorch-lightning",
"safetensors",
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch<2.5.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"timm~=1.0.0",
"torch~=2.6.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchmetrics",
"torchsde",
"torchvision",
"transformers==4.46.3",
"timm~=1.0.0", # controlnet-aux depends on a version of timm that breaks LLaVA. explicitly pinning `timm` here, which results in a downgrade of `controlnet-aux`. https://github.com/huggingface/controlnet_aux/pull/101. If this poses a problem, we need to decide between newer `controlnet_aux` and LLaVA, OR we can fork `controlnet_aux` and update the pin.
"transformers",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.1",
"fastapi==0.111.0",
"huggingface-hub==0.26.1",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",
"uvicorn[standard]==0.28.0",
"fastapi-events",
"fastapi",
"huggingface-hub",
"pydantic-settings",
"pydantic",
"python-socketio",
"uvicorn[standard]",
# Auxiliary dependencies, pinned only if necessary.
"albumentations",
@@ -90,11 +91,8 @@ dependencies = [
"pyreadline3",
"python-multipart",
"requests",
"rich~=13.3",
"scikit-image",
"semver~=3.0.1",
"test-tube",
"windows-curses; sys_platform=='win32'",
"humanize==4.12.1",
]
@@ -117,7 +115,7 @@ dependencies = [
]
"dev" = ["jurigged", "pudb", "snakeviz", "gprof2dot"]
"test" = [
"ruff~=0.9.9",
"ruff~=0.11.2",
"ruff-lsp~=0.0.62",
"mypy",
"pre-commit",

View File

@@ -22,9 +22,8 @@ from pathlib import Path
import humanize
import torch
from invokeai.backend.model_manager.config import ModelOnDisk
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelFormat
def strip(v):
@@ -63,7 +62,7 @@ def load_stripped_model(path: Path, *args, **kwargs):
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
original = ModelOnDisk(original_model_path)
if original.format_type == ModelFormat.Checkpoint:
if original.path.is_file():
shutil.copy2(original.path, stripped_model_path)
else:
shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
@@ -71,7 +70,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
original_state_dict = ModelOnDisk.load_state_dict(component_path)
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)

View File

@@ -24,7 +24,7 @@ from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs
],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
sd = {k: None for k in sd_shapes}
sd = dict.fromkeys(sd_shapes)
assert is_state_dict_xlabs_controlnet(sd) == expected
@@ -37,7 +37,7 @@ def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expecte
],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
sd = {k: None for k in sd_keys}
sd = dict.fromkeys(sd_keys)
assert is_state_dict_instantx_controlnet(sd) == expected

View File

@@ -19,7 +19,7 @@ from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xl
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]):
# Construct a dummy state_dict.
sd = {k: None for k in sd_shapes}
sd = dict.fromkeys(sd_shapes)
assert is_state_dict_xlabs_ip_adapter(sd)

View File

@@ -5,7 +5,7 @@ from typing import Any
import pytest
from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import InvocationRegistry
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
@@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):
# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
# subsequent graph validations
BaseInvocation.invalidate_typeadapter()
InvocationRegistry.invalidate_invocation_typeadapter()
# confirm graph validation fails when using denied node
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
@@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir):
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}})
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
all_invocations = InvocationRegistry.get_invocation_classes()
has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
@@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):
# Reset the config so that it doesn't affect other tests
get_config.cache_clear()
BaseInvocation.invalidate_typeadapter()
InvocationRegistry.invalidate_invocation_typeadapter()

View File

@@ -17,7 +17,6 @@ from invokeai.backend.model_manager.config import (
MainDiffusersConfig,
ModelConfigBase,
ModelConfigFactory,
ModelOnDisk,
get_model_discriminator_value,
)
from invokeai.backend.model_manager.legacy_probe import (
@@ -27,6 +26,7 @@ from invokeai.backend.model_manager.legacy_probe import (
get_default_settings_control_adapters,
get_default_settings_main,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util.logging import InvokeAILogger

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5acefb3658338a4126736e2da02cfef5a9ce6e2469564a6c7994ae34e8ef2e8a
size 192447

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:55aafd0f9b4ac2863361573b070320e13b800b2359a81a73878008bdffc3edfa
size 201040

View File

@@ -21,16 +21,18 @@ def count_files(path: Path):
@pytest.fixture
def obj_serializer(tmp_path: Path):
return ObjectSerializerDisk[MockDataclass](tmp_path)
return ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
@pytest.fixture
def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
return ObjectSerializerForwardCache(
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass]), max_cache_size=2
)
def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
assert obj_serializer._output_dir == tmp_path
@@ -70,7 +72,7 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
assert obj_serializer._base_output_dir == tmp_path
assert obj_serializer._output_dir != tmp_path
@@ -78,21 +80,21 @@ def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
tempdir_path = obj_serializer._output_dir
del obj_serializer
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
tempdir_path = obj_serializer._output_dir
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
@@ -102,19 +104,19 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=True)
assert not tempdir.exists()
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
tempdir = tmp_path / "tmpdir"
tempdir.mkdir()
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass], ephemeral=False)
assert tempdir.exists()
def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path, safe_globals=[MockDataclass])
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer_1.save(obj_1)
obj_1_loaded = obj_serializer_1.load(obj_1_name)
@@ -123,19 +125,19 @@ def test_obj_serializer_disk_different_types(tmp_path: Path):
assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_")
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path, safe_globals=[int])
obj_2_name = obj_serializer_2.save(9001)
assert obj_serializer_2._obj_class_name == "int"
assert obj_serializer_2.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_")
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path, safe_globals=[str])
obj_3_name = obj_serializer_3.save("foo")
assert obj_serializer_3._obj_class_name == "str"
assert obj_serializer_3.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_")
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path, safe_globals=[torch.Tensor])
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer_4.load(obj_4_name)
assert obj_serializer_4._obj_class_name == "Tensor"

4704
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff