Compare commits

...

142 Commits

Author SHA1 Message Date
psychedelicious
8a8f4c593f wip 2025-04-02 06:42:01 +10:00
psychedelicious
29c78f0e5e wip 2025-04-01 15:45:59 +10:00
psychedelicious
501534e2e1 chore(ui): typegen 2025-04-01 08:49:28 +10:00
psychedelicious
50c7318004 feat(app): add is_published to workflow models 2025-04-01 08:48:06 +10:00
psychedelicious
7f14597012 refactor(app): clean up compose_mode_from_fields util 2025-04-01 08:46:27 +10:00
psychedelicious
dbe68b364f feat(ui): publish toast links to project dashboard 2025-04-01 08:22:48 +10:00
psychedelicious
0c7aa85a5c feat(ui): add badge to queue indicating if run is validation run 2025-04-01 08:22:48 +10:00
psychedelicious
703e1c8001 feat(ui): publish toasts do not auto-close 2025-04-01 08:22:48 +10:00
psychedelicious
b056c93ea3 feat(ui): disable invoke button during publish operation 2025-04-01 08:22:48 +10:00
psychedelicious
4289241943 feat(ui): "isInDeployFlow" -> "isInPublishFlow" 2025-04-01 08:22:48 +10:00
psychedelicious
51f5abf5f9 feat(ui): wip publish flow 2025-04-01 08:22:48 +10:00
psychedelicious
e59fa59ad7 feat(ui): wip publish flow 2025-04-01 08:22:48 +10:00
psychedelicious
2407cb64b3 feat(app): truncate invalid model config warning to 64 chars
Previously it logged the whole config and flooded the terminal output.
2025-04-01 08:22:48 +10:00
psychedelicious
70f704ab44 feat(ui): publish button works 2025-04-01 08:22:48 +10:00
psychedelicious
b786032b89 feat(ui): make validation run logic conditional 2025-04-01 08:22:48 +10:00
psychedelicious
e8cc06cc92 feat(ui): disable all workflow editor interaction while in deploy flow 2025-04-01 08:22:48 +10:00
psychedelicious
8e6c56c93d wip 2025-04-01 08:22:48 +10:00
psychedelicious
69d4ee7f93 chore(ui): bump @xyflow/react to latest 2025-04-01 08:22:48 +10:00
psychedelicious
567fd3e0da refactor(ui): standardize more workflow editor hooks to use Safe and OrThrow suffixes for clarity 2025-04-01 08:22:47 +10:00
psychedelicious
0b8f88e554 wip 2025-04-01 08:22:47 +10:00
psychedelicious
60f0c4bf99 refactor(ui): standardize more workflow editor hooks to use Safe and OrThrow suffixes for clarity 2025-04-01 08:22:47 +10:00
psychedelicious
900ec92ef1 tidy(ui): remove extraneous scrollable container 2025-04-01 08:22:47 +10:00
psychedelicious
2594768479 revert(ui): remove api_fields from zod workflow schema 2025-04-01 08:22:47 +10:00
psychedelicious
91ab81eca9 chore(ui): typegen 2025-04-01 08:22:47 +10:00
psychedelicious
b20c745c6e revert(app): remove api_fields from workflow pydantic model 2025-04-01 08:22:47 +10:00
psychedelicious
e41a37bca0 refactor(ui): generalize node field dnd to drag node fields vs node field form elements 2025-04-01 08:22:47 +10:00
psychedelicious
9ca44f27a5 feat(ui): rough out state mgmt for workflow api fields 2025-04-01 08:22:47 +10:00
psychedelicious
b9ddf67853 refactor(ui): rejiggle enqueue actions to support api validation runs 2025-04-01 08:22:47 +10:00
psychedelicious
afe088045f chore(ui): rename type BatchConfig -> EnqueueBatchArg 2025-04-01 08:22:47 +10:00
psychedelicious
09ca61a962 chore(ui): typegen 2025-04-01 08:22:47 +10:00
psychedelicious
dd69a96c03 feat(queue): move session count calculation in to Batch class, cache it, add pydantic validator for validation runs 2025-04-01 08:22:46 +10:00
psychedelicious
4a54e594d0 tests(ui): update test for workflow types 2025-04-01 08:22:46 +10:00
psychedelicious
936ed1960a feat(ui): add api_fields to zod schemas 2025-04-01 08:22:46 +10:00
psychedelicious
9fac7986c7 chore(ui): typegen 2025-04-01 08:22:46 +10:00
psychedelicious
e4b603f44e feat(app): add api_fields to workflow pydantic schema 2025-04-01 08:22:46 +10:00
psychedelicious
7edfe6edcf chore(ui): bump tsafe dep 2025-04-01 08:22:46 +10:00
jazzhaiku
bfb117d0e0 Port LoRA to new classification API (#7849)
## Summary

- Port LoRA to new classification API
- Add 2 additional tests cases (ControlLora and Flux Diffusers LoRA)
- Moved `ModelOnDisk` to its own module

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-04-01 08:05:48 +11:00
jazzhaiku
b31c1022c3 Merge branch 'main' into lora-classification 2025-04-01 07:58:36 +11:00
Mary Hipp
a5851ca31c fix from leftover testing 2025-03-31 12:45:53 -04:00
Mary Hipp
77bf5c15bb GET presigned URLs directly instead of trying to use redirects 2025-03-31 12:45:53 -04:00
psychedelicious
595133463e feat(nodes): add methods to invalidate invocation typeadapters 2025-03-31 19:15:59 +11:00
psychedelicious
6155f9ff9e feat(nodes): move invocation/output registration to separate class 2025-03-31 19:15:59 +11:00
psychedelicious
7be87c8048 refactor(nodes): simpler logic for baseinvocation typeadapter handling 2025-03-31 19:15:59 +11:00
jazzhaiku
9868c3bfe3 Merge branch 'main' into lora-classification 2025-03-31 16:43:26 +11:00
psychedelicious
8b299d0bac chore: prep for v5.9.1 2025-03-31 13:40:07 +11:00
psychedelicious
a44bfb4658 fix(mm): handle FLUX models w/ diff in_channels keys
Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal".

To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key.

This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models.

Closes #7856
Closes #7859
2025-03-31 12:32:55 +11:00
psychedelicious
96fb5f6881 feat(ui): disable denoising strength when selected models flux fill 2025-03-31 11:31:02 +11:00
psychedelicious
4109ea5324 fix(nodes): expanded masks not 100% transparent outside the fade out region
The polynomial fit isn't perfect and we end up with alpha values of 1 instead of 0 when applying the mask. This in turn causes issues on canvas where outputs aren't 100% transparent and individual layer bbox calculations are incorrect.
2025-03-31 11:17:00 +11:00
jazzhaiku
f6c2ee5040 Merge branch 'main' into lora-classification 2025-03-31 09:01:16 +11:00
Billy
965753bf8b Ruff formatting 2025-03-31 08:18:00 +11:00
Billy
40c53ab95c Guard 2025-03-29 09:58:02 +11:00
psychedelicious
aaa6211625 chore(backend): ruff C420 2025-03-28 18:28:32 -04:00
psychedelicious
f6d770eac9 ci: add python 3.12 to test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
47cb61cd62 ci: remove python 3.10 from test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
b0fdc8ae1c ci: bump linux-cpu test runner to ubuntu 24.04 2025-03-28 18:28:32 -04:00
psychedelicious
ed9b30efda ci: bump uv to 0.6.10 2025-03-28 18:28:32 -04:00
psychedelicious
168e5eeff0 ci: use uv in typegen-checks
ci: use uv in typegen-checks to generate types

experiment: simulate typegen-checks failure

Revert "experiment: simulate typegen-checks failure"

This reverts commit f53c6876fe8311de236d974194abce93ed84930c.
2025-03-28 18:28:32 -04:00
psychedelicious
7acaa86bdf ci: get ci working with uv instead of pip
Lots of squashed experimentation heh:

ci: manually specify python version in tests

ci: whoops typo in ruff cmds

ci: specify python versions for uv python install

ci: install python verbosely

ci: try forcing python preference?

ci: try forcing python preference a different way?

ci: try in a venv?

ci: it works, but try without venv

ci: oh maybe we need --preview?

ci: poking it with a stick

ci: it works, add summary to pytest output

ci: fix pytest output

experiment: simulate test failure

Revert "experiment: simulate test failure"

This reverts commit b99ca512f6e61a2a04a1c0636d44018c11019954.

ci: just use default pytest output

cI: attempt again to use uv to install python

cI: attempt again again to use uv to install python

Revert "cI: attempt again again to use uv to install python"

This reverts commit 3cba861c90738081caeeb3eca97b60656ab63929.

Revert "cI: attempt again to use uv to install python"

This reverts commit b30f2277041dc999ed514f6c594c6d6a78f5c810.
2025-03-28 18:28:32 -04:00
psychedelicious
96c0393fe7 ci: bump ruff to 0.11.2
Need to bump both CI and pyproject.toml at the same time
2025-03-28 18:28:32 -04:00
psychedelicious
403f795c5e ci: remove linux-cuda-11_7 & linux-rocm-5_2 from test matrix
We only have CPU runners, so these tests are not doing anything useful.
2025-03-28 18:28:32 -04:00
psychedelicious
c0f88a083e ci: use uv for python-tests 2025-03-28 18:28:32 -04:00
psychedelicious
542b182899 ci: use uv for python-checks 2025-03-28 18:28:32 -04:00
Mary Hipp
3f58c68c09 fix tag invalidation 2025-03-28 10:52:27 -04:00
Mary Hipp
e50c7e5947 restore multiple key 2025-03-28 10:52:27 -04:00
Mary Hipp
4a83700fe4 if clientSideUploading is enabled, handle bulk uploads using that flow 2025-03-28 10:52:27 -04:00
jazzhaiku
c25f6d1f84 Merge branch 'main' into lora-classification 2025-03-28 12:32:22 +11:00
jazzhaiku
a53e1ccf08 Small improvements (#7842)
## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-28 12:21:41 +11:00
jazzhaiku
1af9930951 Merge branch 'main' into small-improvements 2025-03-28 12:11:09 +11:00
Billy
c276c1cbee Comment 2025-03-28 10:57:46 +11:00
Billy
c619348f29 Extract ModelOnDisk to its own module 2025-03-28 10:35:13 +11:00
psychedelicious
c6f96613fc chore(ui): typegen 2025-03-28 08:14:06 +11:00
psychedelicious
258bf736da fix(nodes): handle zero fade size (e.g. mask blur 0)
Closes #7850
2025-03-28 08:14:06 +11:00
Billy
0d75c99476 Caching 2025-03-27 17:55:09 +11:00
Billy
323d409fb6 Make ruff happy 2025-03-27 17:47:57 +11:00
Billy
f251722f56 LoRA classification API 2025-03-27 17:47:01 +11:00
psychedelicious
7004fde41b fix(mm): vllm model calculates its own size 2025-03-27 09:36:14 +11:00
jazzhaiku
c9dc27afbb Merge branch 'main' into small-improvements 2025-03-27 08:14:48 +11:00
Billy
efd14ec0e4 Make ruff happy 2025-03-27 08:11:39 +11:00
Billy
21ee2b6251 Merge branch 'small-improvements' of github.com:invoke-ai/InvokeAI into small-improvements 2025-03-27 08:10:38 +11:00
Billy
82dd2d508f Deprecate checkpoint as file, diffusers as directory terminology 2025-03-27 08:10:12 +11:00
psychedelicious
ffb5f6c6a6 chore: bump version to v5.9.0 2025-03-27 08:08:44 +11:00
psychedelicious
5c5fff9ecb chore(ui): update whatsnew 2025-03-27 08:08:44 +11:00
psychedelicious
9ca071819b chore(nodes): remove beta/prototype flag from a lot of stable nodes 2025-03-27 08:08:44 +11:00
psychedelicious
b14d8e8192 chore(nodes): mark llava_onevision_vllm as beta 2025-03-27 08:08:44 +11:00
jazzhaiku
5a59f6e3b8 Merge branch 'main' into small-improvements 2025-03-27 07:38:13 +11:00
Billy
60b5aef16a Log error -> warning 2025-03-27 06:56:22 +11:00
jazzhaiku
35222a8835 Taxonomy (#7833)
## Summary

This PR moves type definitions out of `config.py` into a new
`taxonomy.py` module.
The goal is to reduce clutter in `config.py`, and to resolve circular
import issues by isolating these types in a dedicated module with
(almost) no internal dependencies.
Because so many places import these definitions, these changes touch 73
files.

Additional changes:
- Removed star imports using "removestar" tool
- Added the commit to `.git-blame-ignore-revs` to avoid noise in git
blame history


## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-26 22:44:41 +11:00
Billy
0e8b5484d5 Error handling 2025-03-26 19:31:57 +11:00
Billy
454506c83e Type hints 2025-03-26 19:12:49 +11:00
Billy
8f6ab67376 Logs 2025-03-26 16:34:32 +11:00
Billy
5afcc7778f Redundant 2025-03-26 16:32:19 +11:00
Billy
325e07d330 Error handling 2025-03-26 16:30:45 +11:00
Billy
a016bdc159 Add todo 2025-03-26 16:17:26 +11:00
Billy
a14f0b2864 Fail early on invalid config 2025-03-26 16:10:32 +11:00
Billy
721483318a Extend ModelOnDisk 2025-03-26 16:10:00 +11:00
jazzhaiku
be04743649 Merge branch 'main' into taxonomy 2025-03-26 15:09:26 +11:00
psychedelicious
92f0c28d6c fix(ui): correctly render whitespace in strings in string generator previews
This is a visual issue - the underlying strings are not trimmed.

Closes #7830
2025-03-26 13:52:31 +11:00
Billy
a6b94e8ca4 Revert some files 2025-03-26 13:18:50 +11:00
Billy
00b11ef795 Git blame ignore revs 2025-03-26 12:56:04 +11:00
Billy
182580ff69 Imports 2025-03-26 12:55:10 +11:00
Billy
8e9d5c1187 Ruff formatting 2025-03-26 12:30:31 +11:00
Billy
99aac5870e Remove star imports 2025-03-26 12:27:00 +11:00
psychedelicious
c1b475c585 feat(ui): add getRuntimeConfig query and show it all in the about modal 2025-03-26 11:39:21 +11:00
psychedelicious
ec44e68cbf chore(ui): typegen 2025-03-26 11:39:21 +11:00
psychedelicious
73dbebbcc3 feat(api): add route to get app config and set config fields 2025-03-26 11:39:21 +11:00
psychedelicious
09f971467d feat(app): do not set port unless necessary 2025-03-26 11:39:21 +11:00
psychedelicious
2c71b0e873 fix(ui): long node titles overflow 2025-03-26 10:24:46 +11:00
Kevin Turner
92f69ac463 fix: make source location discovery more robust
The top-level `invokeai` package may have an obscured origin due to the way editible installs work, but it's much more likely that this module is from a specific file.
2025-03-26 10:12:36 +11:00
jazzhaiku
3b154df71a Import Smoke Test (#7835)
## Summary

This test imports all modules in the invokeai package and fails if there
are any exceptions.
Existing issues are excluded to avoid blocking main.

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-26 08:40:07 +11:00
Billy
64aa965160 Set ordering 2025-03-25 19:21:14 +11:00
Billy
d715c27d07 Add more known failures 2025-03-25 17:59:28 +11:00
Billy
515084577c Test all imports work 2025-03-25 17:45:22 +11:00
psychedelicious
7596c07a64 chore: prep for v5.9.0rc2 2025-03-25 10:21:23 +11:00
Kevin Turner
98fd1d949b fix: make dev_reload work for files in nodes/ 2025-03-25 10:04:17 +11:00
Linos
6312e6aa8f translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1832 of 1832 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-03-25 08:00:45 +11:00
Riccardo Giovanetti
6435f11bae translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1815 of 1838 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1809 of 1832 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-03-25 08:00:45 +11:00
psychedelicious
1c69b9b1fa fix(ui): restore display: flex to image viewer and node editor
This was inadventently removed in #7786 and caused some minor layout overflow.
2025-03-25 07:44:07 +11:00
psychedelicious
731970ff88 fix(ui): use expanded mask for paste-back when inpainting 2025-03-25 00:03:13 +11:00
psychedelicious
038bac1614 feat(ui): make it clearer that we are doing scale before processing in graph builders 2025-03-25 00:03:13 +11:00
jazzhaiku
ed9efe7740 Port LLaVA to new API (#7817)
## Summary

- Port LLaVA model config to new classification API
- Add 2 test cases (stripped LLaVA models variants to git-lfs)

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-24 22:50:54 +11:00
jazzhaiku
ffa0beba7a Merge branch 'main' into llava 2025-03-24 15:17:33 +11:00
psychedelicious
75d793f1c4 fix(ui): siglip model translation key 2025-03-24 13:26:38 +11:00
psychedelicious
2b086917e0 chore(ui): lint 2025-03-24 13:24:13 +11:00
psychedelicious
a9f2738086 feat(ui): layout improvements for string field collection input 2025-03-24 13:24:13 +11:00
psychedelicious
3a56799ea5 tidy(ui): remove unused code 2025-03-24 13:24:13 +11:00
psychedelicious
3162ce94dc tidy(ui): use settings for node field settings instead of config
Non-functional naming change to clarify the logic
2025-03-24 13:24:13 +11:00
psychedelicious
c0dc6ac4e1 fix(ui): issue where string drop-down options are not removed when changing component to a different type 2025-03-24 13:24:13 +11:00
psychedelicious
fed1995525 chore(ui): lint 2025-03-24 13:24:13 +11:00
psychedelicious
5006e23456 feat(ui): added reset options button 2025-03-24 13:24:13 +11:00
psychedelicious
2f063bddda fix(ui): restore field-node overlay
Accidentally removed it
2025-03-24 13:24:13 +11:00
psychedelicious
23a26422fd feat(ui): support for custom string field dropdowns in builder 2025-03-24 13:24:13 +11:00
psychedelicious
434f195a96 feat(ui): add empty string placeholder to string fields 2025-03-24 13:24:13 +11:00
psychedelicious
6a4c2d692c chore(ui): typegen 2025-03-24 12:45:46 +11:00
psychedelicious
5127a07cf9 feat(nodes): clean up lora node names
I had named them wonkily and caused some user confusion.
2025-03-24 12:45:46 +11:00
psychedelicious
0b4c6f0ab4 fix(mm): flux model variant probing
In #7780 we added FLUX Fill support, and needed the probe to be able to distinguish between "normal" FLUX models and FLUX Fill models.

Logic was added to the probe to check a particular state dict key (input channels), which should be 384 for FLUX Fill and 64 for other FLUX models.

The new logic was stricter and instead of falling back on the "normal" variant, it raised when an unexpected value for input channels was detected.

This caused failures to probe for BNB-NF4 quantized FLUX Dev/Schnell, which apparently only have 1 input channel.

After checking a variety of FLUX models, I loosened the strictness of the variant probing logic to only special-case the new FLUX Fill model, and otherwise fall back to returning the "normal" variant. This better matches the old behaviour and fixes the import errors.

Closes #7822
2025-03-24 12:36:18 +11:00
Billy
d8450033ea Fix 2025-03-21 17:46:18 +11:00
Billy
3938736bd8 Ruff formatting 2025-03-21 17:35:12 +11:00
Billy
fb2c7b9566 Defaults 2025-03-21 17:35:04 +11:00
Billy
29449ec27d Implement new api for LLaVA 2025-03-21 17:17:56 +11:00
Billy
e38f778d28 Extend ModelOnDisk 2025-03-21 17:17:15 +11:00
Billy
f5e78436a8 Update regression test 2025-03-21 17:14:02 +11:00
Billy
6a15b5d9be Add stripped models for testing llava 2025-03-21 15:34:20 +11:00
279 changed files with 4428 additions and 1403 deletions

View File

@@ -1,2 +1,5 @@
b3dccfaeb636599c02effc377cdd8a87d658256c
218b6d0546b990fc449c876fb99f44b50c4daa35
182580ff6970caed400be178c5b888514b75d7f2
8e9d5c1187b0d36da80571ce4c8ba9b3a37b6c46
99aac5870e1092b182e6c5f21abcaab6936a4ad1

View File

@@ -34,6 +34,9 @@ on:
jobs:
python-checks:
env:
# uv requires a venv by default - but for this, we can simply use the system python
UV_SYSTEM_PYTHON: 1
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
@@ -57,25 +60,19 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
uses: astral-sh/setup-uv@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.9.9
shell: bash
version: '0.6.10'
enable-cache: true
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff check --output-format=github .
run: uv tool run ruff@0.11.2 check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff format --check .
run: uv tool run ruff@0.11.2 format --check .
shell: bash

View File

@@ -39,24 +39,15 @@ jobs:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
- '3.12'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
os: ubuntu-24.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
@@ -70,6 +61,8 @@ jobs:
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
UV_SYSTEM_PYTHON: 1
steps:
- name: checkout
# https://github.com/nschloe/action-cached-lfs-checkout
@@ -92,20 +85,25 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable ".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}

View File

@@ -54,17 +54,25 @@ jobs:
- 'pyproject.toml'
- 'invokeai/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: '3.11'
- name: setup python
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
python-version: '3.11'
- name: install python dependencies
- name: install dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: pip3 install --use-pep517 --editable="."
env:
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable .
- name: install frontend dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
@@ -77,7 +85,7 @@ jobs:
- name: generate schema
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: make frontend-typegen
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
shell: bash
- name: compare files

View File

@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.util.logging import logging
@@ -99,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
async def get_config_() -> AppConfig:
infill_methods = ["lama", "tile", "cv2", "color"] # TODO: add mosaic back
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")
@@ -121,6 +122,21 @@ async def get_config() -> AppConfig:
)
class InvokeAIAppConfigWithSetFields(BaseModel):
"""InvokeAI App Config with model fields set"""
set_fields: set[str] = Field(description="The set fields")
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
@app_router.get(
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
)
async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
config = get_config()
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.get(
"/logging",
operation_id="get_log_level",

View File

@@ -96,6 +96,22 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
class ImageUploadEntry(BaseModel):
image_dto: ImageDTO = Body(description="The image DTO")
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
@images_router.post("/", operation_id="create_image_upload_entry")
async def create_image_upload_entry(
width: int = Body(description="The width of the image"),
height: int = Body(description="The height of the image"),
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
) -> ImageUploadEntry:
"""Uploads an image from a URL, not implemented"""
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),

View File

@@ -28,12 +28,10 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
MainCheckpointConfig,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch

View File

@@ -1,10 +1,13 @@
from typing import Optional
import json
from typing import Any, Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.fields import BoardField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -15,6 +18,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
ClearResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -22,6 +26,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.compose_pydantic_model import compose_model_from_fields
from invokeai.app.services.shared.pagination import CursorPaginatedResults
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -34,6 +39,17 @@ class SessionQueueAndProcessorStatus(BaseModel):
processor: SessionProcessorStatus
class SimpleModelIdentifer(BaseModel):
id: str = Field(description="The model id")
model_field_overrides = {ModelIdentifierField: (SimpleModelIdentifer, Field(description="The model identifier"))}
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
@session_queue_router.post(
"/{queue_id}/enqueue_batch",
operation_id="enqueue_batch",
@@ -45,9 +61,52 @@ async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
is_api_validation_run: bool = Body(
default=False,
description="Whether or not this is a validation run.",
),
api_input_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as input to the API"
),
api_output_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as output from the API"
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
if is_api_validation_run:
session_count = batch.get_session_count()
assert session_count == 1, "API validation run only supports single session batches"
if api_input_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_input_fields,
composed_model_class_name="APIInputModel",
model_field_overrides=model_field_overrides,
model_field_filter=model_field_filter,
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Input Model")
print(json.dumps(json_schema))
if api_output_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_output_fields,
composed_model_class_name="APIOutputModel",
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Output Model")
print(json.dumps(json_schema))
print("graph")
print(batch.graph.model_dump_json())
if batch.workflow is not None:
print("workflow")
print(batch.workflow.model_dump_json())
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)

View File

@@ -1,4 +1,5 @@
import io
import random
import traceback
from typing import Optional
@@ -24,6 +25,37 @@ from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import
IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
ids = {
"6614752a-0420-4d81-98fc-e110069d4f38": random.choice([True, False]),
"default_5e8b008d-c697-45d0-8883-085a954c6ace": random.choice([True, False]),
"4b2b297a-0d47-4f43-8113-ebbf3f403089": random.choice([True, False]),
"d0ce602a-049e-4368-97ae-977b49eed042": random.choice([True, False]),
"f170a187-fd74-40b8-ba9c-00de173ea4b9": random.choice([True, False]),
"default_f96e794f-eb3e-4d01-a960-9b4e43402bcf": random.choice([True, False]),
"default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88": random.choice([True, False]),
"default_dec5a2e9-f59c-40d9-8869-a056751d79b8": random.choice([True, False]),
"default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd": random.choice([True, False]),
"default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f": random.choice([True, False]),
"default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1": random.choice([True, False]),
"default_7dde3e36-d78f-4152-9eea-00ef9c8124ed": random.choice([True, False]),
"default_444fe292-896b-44fd-bfc6-c0b5d220fffc": random.choice([True, False]),
"default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2": random.choice([True, False]),
"acae7e87-070b-4999-9074-c5b593c86618": random.choice([True, False]),
"3008fc77-1521-49c7-ba95-94c5a4508d1d": random.choice([True, False]),
"default_686bb1d0-d086-4c70-9fa3-2f600b922023": random.choice([True, False]),
"36905c46-e768-4dc3-8ecd-e55fe69bf03c": random.choice([True, False]),
"7c3e4951-183b-40ef-a890-28eef4d50097": random.choice([True, False]),
"7a053b2f-64e4-4152-80e9-296006e77131": random.choice([True, False]),
"27d4f1be-4156-46e9-8d22-d0508cd72d4f": random.choice([True, False]),
"e881dc06-70d2-438f-b007-6f3e0c3c0e78": random.choice([True, False]),
"265d2244-a1d7-495c-a2eb-88217f5eae37": random.choice([True, False]),
"caebcbc7-2bf0-41c4-b553-106b585fddda": random.choice([True, False]),
"a7998705-474e-417d-bd37-a2a9480beedf": random.choice([True, False]),
"554d94b5-94b3-4d8e-8aed-51ebfc9deea5": random.choice([True, False]),
"e6898540-c1bc-408b-b944-c1e242cddbcd": random.choice([True, False]),
"363b0960-ab2c-4902-8df3-f592d6194bb3": random.choice([True, False]),
}
@workflows_router.get(
"/i/{workflow_id}",
@@ -39,6 +71,8 @@ async def get_workflow(
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
workflow.is_published = ids.get(workflow_id, False)
workflow.workflow.is_published = ids.get(workflow_id, False)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@@ -106,10 +140,11 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
workflow_record_list_items = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
direction=direction,
page=page,
@@ -118,20 +153,23 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
is_published=is_published,
)
for workflow in workflows.items:
for item in workflow_record_list_items.items:
data = item.model_dump()
data["is_published"] = ids.get(item.workflow_id, False)
workflows_with_thumbnails.append(
WorkflowRecordListItemWithThumbnailDTO(
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
**workflow.model_dump(),
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(item.workflow_id),
**data,
)
)
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
items=workflows_with_thumbnails,
total=workflows.total,
page=workflows.page,
pages=workflows.pages,
per_page=workflows.per_page,
total=workflow_record_list_items.total,
page=workflow_record_list_items.page,
pages=workflow_record_list_items.pages,
per_page=workflow_record_list_items.per_page,
)

View File

@@ -8,6 +8,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -27,7 +28,6 @@ import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
All invocations must use the `@invocation` decorator to provide their unique type.
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def invalidate_typeadapter(cls) -> None:
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
the updated allowlist and denylist."""
cls._typeadapter_needs_update = True
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
class InvocationRegistry:
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This is used to parse serialized invocations into the correct invocation class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
@classmethod
def invalidate_invocation_typeadapter(cls) -> None:
"""Invalidates the cached invocation type adapter."""
cls.get_invocation_typeadapter.cache_clear()
@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[type[BaseInvocation]] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in cls.get_invocation_classes()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in cls.get_invocation_classes())
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This is used to parse serialized invocation outputs into the correct invocation output class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def invalidate_output_typeadapter(cls) -> None:
"""Invalidates the cached invocation output type adapter."""
cls.get_output_typeadapter.cache_clear()
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
@@ -453,8 +461,8 @@ def invocation(
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
@@ -539,8 +547,7 @@ def invocation(
)
cls.__doc__ = docstring
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
BaseInvocation.register_invocation(cls) # type: ignore
InvocationRegistry.register_invocation(cls)
return cls
@@ -565,7 +572,7 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in BaseInvocationOutput.get_output_types():
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
@@ -586,7 +593,7 @@ def invocation_output(
)
cls.__doc__ = docstring
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
InvocationRegistry.register_output(cls)
return cls

View File

@@ -19,7 +19,8 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_manager.config import MainConfigBase
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

View File

@@ -39,8 +39,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

View File

@@ -1,7 +1,6 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -25,7 +24,6 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
tags=["lora", "model", "flux"],
category="model",
version="1.1.1",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""

View File

@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -52,7 +51,6 @@ class FluxControlNetOutput(BaseInvocationOutput):
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""

View File

@@ -10,7 +10,7 @@ from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -49,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat, ModelVariantType
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -64,7 +64,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "flux"],
category="image",
version="3.3.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""

View File

@@ -31,7 +31,7 @@ class FluxFillOutput(BaseInvocationOutput):
tags=["inpaint"],
category="inpaint",
version="1.0.0",
classification=Classification.Prototype,
classification=Classification.Beta,
)
class FluxFillInvocation(BaseInvocation):
"""Prepare the FLUX Fill conditioning data."""

View File

@@ -4,7 +4,7 @@ from typing import List, Literal, Union
from pydantic import field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.ip_adapter import (
CLIP_VISION_MODEL_MAP,
@@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxIPAdapterInvocation(BaseInvocation):
"""Collects FLUX IP-Adapter info to pass to other nodes."""

View File

@@ -3,14 +3,13 @@ from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
@invocation_output("flux_lora_loader_output")
@@ -28,11 +27,10 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
title="Apply LoRA - FLUX",
tags=["lora", "model", "flux"],
category="model",
version="1.2.0",
classification=Classification.Prototype,
version="1.2.1",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
@@ -107,11 +105,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
@invocation(
"flux_lora_collection_loader",
title="FLUX LoRA Collection Loader",
title="Apply LoRA Collection - FLUX",
tags=["lora", "model", "flux"],
category="model",
version="1.3.0",
classification=Classification.Prototype,
version="1.3.1",
)
class FLUXLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a FLUX transformer."""

View File

@@ -3,7 +3,6 @@ from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -17,8 +16,8 @@ from invokeai.app.util.t5_model_identifier import (
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("flux_model_loader_output")
@@ -41,7 +40,6 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
tags=["model", "flux"],
category="model",
version="1.0.6",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""

View File

@@ -23,7 +23,8 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.starter_models import siglip
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
@@ -44,7 +45,7 @@ class FluxReduxOutput(BaseInvocationOutput):
tags=["ip_adapter", "control"],
category="ip_adapter",
version="2.0.0",
classification=Classification.Prototype,
classification=Classification.Beta,
)
class FluxReduxInvocation(BaseInvocation):
"""Runs a FLUX Redux model to generate a conditioning tensor."""

View File

@@ -4,7 +4,7 @@ from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
@@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.model_manager import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -30,7 +30,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.1.2",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""

View File

@@ -6,7 +6,7 @@ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
@invocation_output("ideal_size_output")

View File

@@ -355,7 +355,6 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
tags=["image", "unsharp_mask"],
category="image",
version="1.2.2",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies an unsharp mask filter to an image"""
@@ -1090,12 +1089,13 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
If the fade size is 0, the mask is returned as-is.
"""
mask: ImageField = InputField(description="The mask to expand")
@@ -1105,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.fade_size_px == 0:
# If the fade size is 0, just return the mask as-is.
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white
@@ -1142,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
coeffs = numpy.polyfit(x_control, y_control, 3)
poly = numpy.poly1d(coeffs)
# Evaluate and clip the smooth mapping
feather = numpy.clip(poly(d_norm), 0, 1)
# Evaluate the polynomial
feather = poly(d_norm)
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
# polynomial fit, which is a best approximation of the control points but not an exact match.
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
# Force pixels at or beyond the fade distance to exactly 1.0
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
# Clip any other values to ensure they're in the valid range [0,1]
feather = numpy.clip(feather, 0, 1)
# Build final image.
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
@@ -1265,7 +1283,6 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
@@ -1292,7 +1309,6 @@ class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Paste the source image into the target image at the given bounding box.

View File

@@ -13,10 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -24,6 +22,7 @@ from invokeai.backend.model_manager.starter_models import (
ip_adapter_sd_image_encoder,
ip_adapter_sdxl_image_encoder,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class IPAdapterField(BaseModel):

View File

@@ -4,7 +4,7 @@ import torch
from PIL.Image import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
@@ -13,7 +13,14 @@ from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice
@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
@invocation(
"llava_onevision_vllm",
title="LLaVA OneVision VLLM",
tags=["vllm"],
category="vllm",
version="1.0.0",
classification=Classification.Beta,
)
class LlavaOnevisionVllmInvocation(BaseInvocation):
"""Run a LLaVA OneVision VLLM model."""

View File

@@ -4,7 +4,6 @@ from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
)
@@ -58,7 +57,6 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class AlphaMaskToTensorInvocation(BaseInvocation):
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
@@ -87,7 +85,6 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
tags=["conditioning"],
category="conditioning",
version="1.1.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
"""Inverts a tensor mask."""
@@ -234,7 +231,6 @@ WHITE = ColorField(r=255, g=255, b=255, a=255)
tags=["mask"],
category="mask",
version="1.0.0",
classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
"""Gets the bounding box of the given mask image."""

View File

@@ -43,7 +43,7 @@ from invokeai.app.invocations.primitives import BooleanOutput, FloatOutput, Inte
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.version import __version__

View File

@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -15,10 +14,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
class ModelIdentifierField(BaseModel):
@@ -126,7 +123,6 @@ class ModelIdentifierOutput(BaseInvocationOutput):
tags=["model"],
category="model",
version="1.0.1",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
@@ -181,7 +177,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
@invocation("lora_loader", title="Apply LoRA - SD1.5", tags=["model"], category="model", version="1.0.4")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -244,7 +240,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Model - SD1.5", tags=["model"], category="model", version="1.0.2")
@invocation("lora_selector", title="Select LoRA", tags=["model"], category="model", version="1.0.3")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
@@ -258,7 +254,7 @@ class LoRASelectorInvocation(BaseInvocation):
@invocation(
"lora_collection_loader", title="LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.1"
"lora_collection_loader", title="Apply LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.2"
)
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
@@ -322,10 +318,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"sdxl_lora_loader",
title="LoRA Model - SDXL",
title="Apply LoRA - SDXL",
tags=["lora", "model"],
category="model",
version="1.0.4",
version="1.0.5",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -402,10 +398,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
@invocation(
"sdxl_lora_collection_loader",
title="LoRA Collection - SDXL",
title="Apply LoRA Collection - SDXL",
tags=["model"],
category="model",
version="1.1.1",
version="1.1.2",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""

View File

@@ -6,7 +6,7 @@ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
@@ -36,7 +36,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "sd3"],
category="image",
version="1.1.1",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""

View File

@@ -2,7 +2,7 @@ import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -25,7 +25,6 @@ from invokeai.backend.util.devices import TorchDevice
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.1",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""

View File

@@ -3,7 +3,6 @@ from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -14,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("sd3_model_loader_output")
@@ -34,7 +33,6 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
tags=["model", "sd3"],
category="model",
version="1.0.1",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""

View File

@@ -11,12 +11,12 @@ from transformers import (
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.model_manager.taxonomy import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -33,7 +33,6 @@ SD3_T5_MAX_SEQ_LEN = 256
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.1",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""

View File

@@ -2,7 +2,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("sdxl_model_loader_output")

View File

@@ -7,7 +7,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
@@ -56,7 +56,6 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
title="Tiled Multi-Diffusion Denoise - SD1.5, SDXL",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.1",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -40,7 +39,6 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@@ -74,7 +72,6 @@ class CalculateImageTilesInvocation(BaseInvocation):
tags=["tiles"],
category="tiles",
version="1.1.1",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@@ -117,7 +114,6 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@@ -168,7 +164,6 @@ class TileToPropertiesOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@@ -201,7 +196,6 @@ class PairTileImageOutput(BaseInvocationOutput):
tags=["tiles"],
category="tiles",
version="1.0.1",
classification=Classification.Beta,
)
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@@ -230,7 +224,6 @@ BLEND_MODES = Literal["Linear", "Seam"]
tags=["tiles"],
category="tiles",
version="1.1.1",
classification=Classification.Beta,
)
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Merge multiple tile images into a single image."""

View File

@@ -41,16 +41,15 @@ def run_app() -> None:
)
# Find an open port, and modify the config accordingly.
orig_config_port = app_config.port
app_config.port = find_open_port(app_config.port)
if orig_config_port != app_config.port:
first_open_port = find_open_port(app_config.port)
if app_config.port != first_open_port:
orig_config_port = app_config.port
app_config.port = first_open_port
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)
# Initialize the app and event loop.
@@ -61,6 +60,11 @@ def run_app() -> None:
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
if app_config.dev_reload:
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
# imported.
enable_dev_reload(custom_nodes_path=app_config.custom_nodes_path)
# Start the server.
config = uvicorn.Config(
app=app,

View File

@@ -44,7 +44,8 @@ if TYPE_CHECKING:
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
class EventServiceBase:

View File

@@ -16,7 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob

View File

@@ -10,9 +10,9 @@ from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
class InstallStatus(str, Enum):

View File

@@ -39,8 +39,6 @@ from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelRepoVariant,
ModelSourceType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
@@ -52,6 +50,7 @@ from invokeai.backend.model_manager.metadata import (
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice

View File

@@ -5,9 +5,10 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
class ModelLoadServiceBase(ABC):

View File

@@ -11,7 +11,7 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.load import (
)
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

View File

@@ -1,16 +1,12 @@
"""Initialization file for model manager service."""
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]

View File

@@ -14,10 +14,12 @@ from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ClipVariantType,
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
)
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
ModelFormat,
ModelSourceType,
ModelType,

View File

@@ -60,11 +60,9 @@ from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
class ModelRecordServiceSQL(ModelRecordServiceBase):
@@ -304,7 +302,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
self._logger.warning(
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
)
else:
results.append(model_config)

View File

@@ -33,7 +33,12 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
def enqueue_batch(
self,
queue_id: str,
batch: Batch,
prepend: bool,
) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution."""
pass

View File

@@ -157,6 +157,28 @@ class Batch(BaseModel):
v.validate_self()
return v
def get_session_count(self) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
If the session count has already been calculated, return the cached value.
"""
if not self.data:
return self.runs
data = []
for batch_datum_list in self.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * self.runs
model_config = ConfigDict(
json_schema_extra={
"required": [
@@ -201,6 +223,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
return None
class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
@@ -237,6 +265,16 @@ class SessionQueueItemWithoutGraph(BaseModel):
retried_from_item_id: Optional[int] = Field(
default=None, description="The item_id of the queue item that this item was retried from"
)
is_api_validation_run: bool = Field(
default=False,
description="Whether this queue item is an API validation run.",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
@@ -536,28 +574,6 @@ def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)

View File

@@ -28,7 +28,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -118,7 +117,8 @@ class SqliteSessionQueue(SessionQueueBase):
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
requested_count = batch.get_session_count()
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,

View File

@@ -0,0 +1,204 @@
from copy import deepcopy
from typing import Any, Callable, TypeAlias, get_args
from pydantic import BaseModel, ConfigDict, create_model
from pydantic.fields import FieldInfo
from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier
from invokeai.app.services.shared.graph import Graph
DictOfFieldsMetadata: TypeAlias = dict[str, tuple[type[Any], FieldInfo]]
class ComposedFieldMetadata(BaseModel):
node_id: str
field_name: str
field_type_class_name: str
def dedupe_field_name(field_metadata: DictOfFieldsMetadata, field_name: str) -> str:
"""Given a field name, return a name that is not already in the field metadata.
If the field name is not in the field metadata, return the field name.
If the field name is in the field metadata, generate a new name by appending an underscore and integer to the field name, starting with 2.
"""
if field_name not in field_metadata:
return field_name
i = 2
while True:
new_field_name = f"{field_name}_{i}"
if new_field_name not in field_metadata:
return new_field_name
i += 1
def compose_model_from_fields(
g: Graph,
field_identifiers: list[FieldIdentifier],
composed_model_class_name: str = "ComposedModel",
model_field_overrides: dict[type[Any], tuple[type[Any], FieldInfo]] | None = None,
model_field_filter: Callable[[type[Any]], bool] | None = None,
) -> type[BaseModel]:
"""Given a graph and a list of field identifiers, create a new pydantic model composed of the fields of the nodes in the graph.
The resultant model can be used to validate a JSON payload that contains the fields of the nodes in the graph, or generate an
OpenAPI schema for the model.
Args:
g: The graph containing the nodes whose fields will be composed into the new model.
field_identifiers: A list of FieldIdentifier instances, each representing a field on a node in the graph.
model_name: The name of the composed model.
kind: The kind of model to create. Must be "input" or "output". Defaults to "input".
model_field_overrides: A dictionary mapping type annotations to tuples of (new_type_annotation, new_field_info).
This can be used to override the type annotation and field info of a field in the composed model. For example,
if `ModelIdentifierField` should be replaced by a string, the dictionary would look like this:
```python
{ModelIdentifierField: (str, Field(description="The model id."))}
```
model_field_filter: A function that takes a type annotation and returns True if the field should be included in the composed model.
If None, all fields will be included. For example, to omit `BoardField` fields, the filter would look like this:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField}
```
Optional fields - or any other complex field types like unions - must be explicitly included in the filter. For example,
to omit `BoardField` _and_ `Optional[BoardField]`:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
```
Note that the filter is applied to the type annotation of the field, not the field itself.
Example usage:
```python
# Create some nodes.
add_node = AddInvocation()
sub_node = SubtractInvocation()
color_node = ColorInvocation()
# Create a graph with the nodes.
g = Graph(
nodes={
add_node.id: add_node,
sub_node.id: sub_node,
color_node.id: color_node,
}
)
# Select the fields to compose.
fields_to_compose = [
FieldIdentifier(node_id=add_node.id, field_name="a"),
FieldIdentifier(node_id=sub_node.id, field_name="a"), # this will be deduped to "a_2"
FieldIdentifier(node_id=add_node.id, field_name="b"),
FieldIdentifier(node_id=color_node.id, field_name="color"),
]
# Compose the model from the fields.
composed_model = compose_model_from_fields(g, fields_to_compose, model_name="ComposedModel")
# Generate the OpenAPI schema for the model.
json_schema = composed_model.model_json_schema(mode="validation")
```
"""
# Temp storage for the composed fields. Pydantic needs a type annotation and instance of FieldInfo to create a model.
field_metadata: DictOfFieldsMetadata = {}
model_field_overrides = model_field_overrides or {}
# The list of required fields. This is used to ensure the composed model's fields retain their required state.
required: list[str] = []
for field_identifier in field_identifiers:
node_id = field_identifier.node_id
field_name = field_identifier.field_name
# Pull the node instance from the graph so we can introspect it.
node_instance = g.nodes[node_id]
if field_identifier.kind == "input":
# Get the class of the node. This will be a BaseInvocation subclass, e.g. AddInvocation, DenoiseLatentsInvocation, etc.
pydantic_model = type(node_instance)
else:
# Otherwise the the type of the node's output class. This will be a BaseInvocationOutput subclass, e.g. IntegerOutput, ImageOutput, etc.
pydantic_model = type(node_instance).get_output_annotation()
# Get the FieldInfo instance for the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^^^ The return value of this Field call is the FieldInfo instance (Field is a function).
og_field_info = pydantic_model.model_fields[field_name]
# Get the type annotation of the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^ this is the type annotation
og_field_type = og_field_info.annotation
# Apparently pydantic allows fields without type annotations. We don't support that.
assert og_field_type is not None, (
f"{field_identifier.kind.capitalize()} field {field_name} on node {node_id} has no type annotation."
)
# Now that we have the type annotation, we can apply the filter to see if we should include the field in the composed model.
if model_field_filter and not model_field_filter(og_field_type):
continue
# Ok, we want this type of field. Retrieve any overrides for the field type. This is a dictionary mapping
# type annotations to tuples of (override_type_annotation, override_field_info).
(override_field_type, override_field_info) = model_field_overrides.get(og_field_type, (None, None))
# The override tuple's first element is the new type annotation, if it exists.
composed_field_type = override_field_type if override_field_type is not None else og_field_type
# Create a deep copy of the FieldInfo instance (or override it if it exists) so we can modify it without
# affecting the original. This is important because we are going to modify the FieldInfo instance and
# don't want to affect the original model's schema.
composed_field_info = deepcopy(override_field_info if override_field_info is not None else og_field_info)
json_schema_extra = og_field_info.json_schema_extra if isinstance(og_field_info.json_schema_extra, dict) else {}
# The field's original required state is stored in the json_schema_extra dict. For more information about why,
# see the definition of `InputField` in invokeai/app/invocations/fields.py.
#
# Add the field to the required list if it is required, which we will use when creating the composed model.
if json_schema_extra.get("orig_required", False):
required.append(field_name)
# Invocation fields have some extra metadata, used by the UI to render the field in the frontend. This data is
# included in the OpenAPI schema for each field. For example, we add a "ui_order" field, which the UI uses to
# sort fields when rendering them.
#
# The composed model's OpenAPI schema should not have this information. It should only have a standard OpenAPI
# schema for the field. We need to strip out the UI-specific metadata from the FieldInfo instance before adding
# it to the composed model.
#
# We will replace this metadata with some custom metadata:
# - node_id: The id of the node that this field belongs to.
# - field_name: The name of the field on the node.
# - original_data_type: The original data type of the field.
field_type_class = get_args(og_field_type)[0] if hasattr(og_field_type, "__args__") else og_field_type
field_type_class_name = field_type_class.__name__
composed_field_metadata = ComposedFieldMetadata(
node_id=node_id,
field_name=field_name,
field_type_class_name=field_type_class_name,
)
composed_field_info.json_schema_extra = {
"composed_field_extra": composed_field_metadata.model_dump(),
}
# Override the name, title and description if overrides are provided. Dedupe the field name if necessary.
final_field_name = dedupe_field_name(field_metadata, field_name)
# Store the field metadata.
field_metadata.update({final_field_name: (composed_field_type, composed_field_info)})
# Splat in the composed fields to create the new model. There are type errors here because create_model's kwargs are not typed,
# and for some reason pydantic's ConfigDict doesn't like lists in `json_schema_extra`. Anyways, the inputs here are correct.
return create_model(
composed_model_class_name,
**field_metadata,
__config__=ConfigDict(json_schema_extra={"required": required}),
)

View File

@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationRegistry,
invocation,
invocation_output,
)
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return InvocationRegistry.get_output_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}

View File

@@ -20,14 +20,10 @@ from invokeai.app.services.session_processor.session_processor_common import Pro
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

View File

@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
is_published: Optional[bool],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass

View File

@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
# it is None.
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
model_config = ConfigDict(extra="ignore")
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):

View File

@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}

View File

@@ -4,7 +4,10 @@ from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.baseinvocation import (
InvocationRegistry,
UIConfigBase,
)
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
@@ -56,14 +59,14 @@ def get_openapi_func(
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
for invocation in InvocationRegistry.get_invocation_classes():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)

View File

@@ -1,6 +1,7 @@
import logging
import mimetypes
import socket
from pathlib import Path
import torch
@@ -33,7 +34,16 @@ def check_cudnn(logger: logging.Logger) -> None:
)
def enable_dev_reload() -> None:
def invokeai_source_dir() -> Path:
# `invokeai.__file__` doesn't always work for editable installs
this_module_path = Path(__file__).resolve()
# https://youtrack.jetbrains.com/issue/PY-38382/Unresolved-reference-spec-but-this-is-standard-builtin
# noinspection PyUnresolvedReferences
depth = len(__spec__.parent.split("."))
return this_module_path.parents[depth - 1]
def enable_dev_reload(custom_nodes_path=None) -> None:
"""Enable hot reloading on python file changes during development."""
from invokeai.backend.util.logging import InvokeAILogger
@@ -44,7 +54,10 @@ def enable_dev_reload() -> None:
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
paths = [str(invokeai_source_dir() / "*.py")]
if custom_nodes_path:
paths.append(str(custom_nodes_path / "*.py"))
jurigged.watch(pattern=paths, logger=InvokeAILogger.get_logger(name="jurigged").info)
def apply_monkeypatches() -> None:

View File

@@ -5,7 +5,7 @@ import torch
from PIL import Image
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
# fast latents preview matrix for sdxl

View File

@@ -1,5 +1,5 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, SubModelType
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:

View File

@@ -4,7 +4,7 @@ from typing import List, Tuple
import invokeai.backend.util.logging as logger
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.textual_inversion import TextualInversionModelRaw

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from invokeai.backend.model_manager.legacy_probe import CkptType
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
"""Gets the in channels from the state dict."""
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
for key in keys:
val = state_dict.get(key)
if val is not None:
return val.shape[1]
return None

View File

@@ -6,8 +6,8 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.model_manager.taxonomy import AnyModel
def norm_img(np_img):

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import *
from .config import is_exportable, is_scriptable
# From PyTorch internals

View File

@@ -5,8 +5,8 @@ Copyright 2020 Ross Wightman
import re
from copy import deepcopy
from .conv2d_layers import *
from geffnet.activations import *
from .conv2d_layers import CondConv2d, get_condconv_initializer, math, partial, select_conv2d
from geffnet.activations import F, get_act_layer, nn, sigmoid, torch
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',

View File

@@ -32,7 +32,9 @@ import torch.nn.functional as F
from .config import layer_config_kwargs, is_scriptable
from .conv2d_layers import select_conv2d
from .helpers import load_pretrained
from .efficientnet_builder import *
from .efficientnet_builder import (BN_EPS_TF_DEFAULT, EfficientNetBuilder, decode_arch_def,
initialize_weight_default, initialize_weight_goog,
resolve_act_layer, resolve_bn_args, round_channels)
__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',

View File

@@ -13,7 +13,9 @@ from .activations import get_act_fn, get_act_layer, HardSwish
from .config import layer_config_kwargs
from .conv2d_layers import select_conv2d
from .helpers import load_pretrained
from .efficientnet_builder import *
from .efficientnet_builder import (BN_EPS_TF_DEFAULT, EfficientNetBuilder, decode_arch_def,
initialize_weight_default, initialize_weight_goog,
resolve_act_layer, resolve_bn_args, round_channels)
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',

View File

@@ -10,7 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.taxonomy import AnyModel
from invokeai.backend.util.devices import TorchDevice
"""

View File

@@ -47,3 +47,10 @@ class LlavaOnevisionModel(RawModel):
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._vllm_model)

View File

@@ -1,37 +1,45 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyVariant,
BaseModelType,
ClipVariantType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.search import ModelSearch
__all__ = [
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",
"ModelSearch",
"ModelConfigBase",
"AnyModel",
"AnyVariant",
"BaseModelType",
"ClipVariantType",
"ModelFormat",
"ModelRepoVariant",
"ModelSourceType",
"ModelType",
"ModelVariantType",
"SchedulerPredictionType",
"SubModelType",
"ModelConfigBase",
]

View File

@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
@@ -29,153 +30,41 @@ from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import diffusers
import onnxruntime as ort
import safetensors.torch
import torch
from diffusers.models.modeling_utils import ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.raw_model import RawModel
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings
logger = logging.getLogger(__name__)
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
]
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognize this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
VAE = "vae"
LoRA = "lora"
ControlLoRa = "control_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
LyCORIS = "lycoris"
ONNX = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OpenVINO = "openvino"
Flax = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
pass
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -206,51 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
return state_dict
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
@@ -325,16 +169,18 @@ class ModelConfigBase(ABC, BaseModel):
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@@ -359,21 +205,43 @@ class ModelConfigBase(ABC, BaseModel):
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])
if "format" in overrides:
overrides["format"] = ModelFormat(overrides["format"])
if "base" in overrides:
overrides["base"] = BaseModelType(overrides["base"])
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default
fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = mod.name
fields["name"] = name = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
fields.update(overrides)
return cls(**fields)
@@ -414,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value
@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
@@ -429,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if type(key) is int:
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -457,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
@@ -625,12 +569,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return False
config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False
architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": BaseModelType.Any,
"variant": ModelVariantType.Normal,
}
def get_model_discriminator_value(v: Any) -> str:
"""

View File

@@ -14,27 +14,30 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
AnyVariant,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
MainModelDefaultSettings,
ModelConfigFactory,
SubmodelDefinition,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import (
get_clip_variant_type,
lora_token_vector_length,
@@ -562,15 +565,28 @@ class CheckpointProbeBase(ProbeBase):
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if base_type == BaseModelType.Flux:
in_channels = state_dict["img_in.weight"].shape[1]
if in_channels == 64:
in_channels = get_flux_in_channels_from_state_dict(state_dict)
if in_channels is None:
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
logger.warning(
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
)
return ModelVariantType.Normal
elif in_channels == 384:
# FLUX Model variant types are distinguished by input channels:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
if in_channels == 384:
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
# type is used to determine whether to use the fill model or the base model.
return ModelVariantType.Inpaint
else:
raise InvalidModelConfigException(
f"Unexpected in_channels (in_channels={in_channels}) for FLUX model at {self.model_path}."
)
# Fall back on "normal" variant type for all other FLUX models.
return ModelVariantType.Normal
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:

View File

@@ -13,12 +13,11 @@ import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
class LoadedModelWithoutConfig:

View File

@@ -6,18 +6,16 @@ from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
InvalidModelConfigException,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
SubModelType,
)
from invokeai.backend.util.devices import TorchDevice

View File

@@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, Optional
import psutil
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
@@ -23,6 +22,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter

View File

@@ -20,13 +20,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import ModelLoaderBase
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
class ModelLoaderRegistryBase(ABC):

View File

@@ -4,16 +4,12 @@ from typing import Optional
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
DiffusersConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

View File

@@ -5,19 +5,19 @@ from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
BaseModelType,
AnyModelConfig,
ControlNetCheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(

View File

@@ -27,15 +27,8 @@ from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
@@ -51,6 +44,13 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)

View File

@@ -8,18 +8,16 @@ from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)

View File

@@ -7,8 +7,9 @@ from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.raw_model import RawModel

View File

@@ -3,15 +3,11 @@ from typing import Optional
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)

View File

@@ -9,17 +9,17 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,

View File

@@ -5,16 +5,16 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)

View File

@@ -2,15 +2,11 @@ from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline

View File

@@ -4,15 +4,11 @@ from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

View File

@@ -11,16 +11,8 @@ from diffusers import (
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
@@ -28,6 +20,14 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.util.silence_warnings import SilenceWarnings
VARIANT_TO_IN_CHANNEL_MAP = {

View File

@@ -4,16 +4,16 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.textual_inversion import TextualInversionModelRaw

View File

@@ -5,15 +5,16 @@ from typing import Optional
from diffusers import AutoencoderKL
from invokeai.backend.model_manager import (
AnyModelConfig,
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)

View File

@@ -15,7 +15,8 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.taxonomy import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
@@ -50,6 +51,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
SegmentAnythingPipeline,
DepthAnythingPipeline,
SigLipPipeline,
LlavaOnevisionModel,
),
):
return model.calc_size()

View File

@@ -17,12 +17,12 @@ from typing import Optional
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.metadata.metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
BaseMetadata,
)
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
class ModelMetadataFetchBase(ABC):

View File

@@ -24,7 +24,6 @@ from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundErro
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata.fetch.fetch_base import ModelMetadataFetchBase
from invokeai.backend.model_manager.metadata.metadata_base import (
AnyModelRepoMetadata,
@@ -32,6 +31,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import (
RemoteModelFile,
UnknownMetadataException,
)
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"

View File

@@ -23,7 +23,7 @@ from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files

View File

@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Any, Optional, TypeAlias
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self) -> int:
if self.path.is_file():
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]
if path in sd_cache:
return sd_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
class StarterModelWithoutDependencies(BaseModel):

View File

@@ -0,0 +1,138 @@
from enum import Enum
from typing import Dict, TypeAlias, Union
import diffusers
import onnxruntime as ort
import torch
from diffusers import ModelMixin
from invokeai.backend.raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
]
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
VAE = "vae"
LoRA = "lora"
ControlLoRa = "control_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
LyCORIS = "lycoris"
ONNX = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OpenVINO = "openvino"
Flax = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
class FluxLoRAFormat(str, Enum):
"""Flux LoRA formats."""
Diffusers = "flux.diffusers"
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -8,7 +8,7 @@ import picklescan.scanner as pscan
import safetensors
import torch
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.model_manager.taxonomy import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
def filter_files(

View File

@@ -0,0 +1,24 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
def flux_format_from_state_dict(state_dict):
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
return FluxLoRAFormat.OneTrainer
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
else:
return None

View File

@@ -8,7 +8,7 @@ from diffusers import T2IAdapter
from PIL.Image import Image
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

View File

@@ -62,7 +62,7 @@
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.4.2",
"@xyflow/react": "^12.5.1",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
@@ -150,7 +150,7 @@
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.3.4",
"tsafe": "^1.7.5",
"tsafe": "^1.8.5",
"type-fest": "^4.26.1",
"typescript": "^5.6.2",
"vite": "^6.1.0",

View File

@@ -36,8 +36,8 @@ dependencies:
specifier: ^1.3.0
version: 1.3.0
'@xyflow/react':
specifier: ^12.4.2
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
specifier: ^12.5.1
version: 12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
async-mutex:
specifier: ^0.5.0
version: 0.5.0
@@ -284,8 +284,8 @@ devDependencies:
specifier: ^8.3.4
version: 8.3.4
tsafe:
specifier: ^1.7.5
version: 1.7.5
specifier: ^1.8.5
version: 1.8.5
type-fest:
specifier: ^4.26.1
version: 4.26.1
@@ -3323,7 +3323,7 @@ packages:
/@types/d3-drag@3.0.7:
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-interpolate@3.0.4:
@@ -3332,21 +3332,21 @@ packages:
'@types/d3-color': 3.1.3
dev: false
/@types/d3-selection@3.0.10:
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
/@types/d3-selection@3.0.11:
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
dev: false
/@types/d3-transition@3.0.8:
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
/@types/d3-transition@3.0.9:
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-zoom@3.0.8:
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
dependencies:
'@types/d3-interpolate': 3.0.4
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/diff-match-patch@1.0.36:
@@ -3951,28 +3951,28 @@ packages:
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
dev: false
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
/@xyflow/react@12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-jMKQVqGwCz0x6pUyvxTIuCMbyehfua7CfEEWDj29zQSHigQpCy0/5d8aOmZrqK4cwur/pVHLQomT6Rm10gXfHg==}
peerDependencies:
react: '>=17'
react-dom: '>=17'
dependencies:
'@xyflow/system': 0.0.50
'@xyflow/system': 0.0.53
classcat: 5.0.5
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
transitivePeerDependencies:
- '@types/react'
- immer
dev: false
/@xyflow/system@0.0.50:
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
/@xyflow/system@0.0.53:
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
dependencies:
'@types/d3-drag': 3.0.7
'@types/d3-selection': 3.0.10
'@types/d3-transition': 3.0.8
'@types/d3-selection': 3.0.11
'@types/d3-transition': 3.0.9
'@types/d3-zoom': 3.0.8
d3-drag: 3.0.0
d3-selection: 3.0.0
@@ -8791,8 +8791,8 @@ packages:
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
dev: false
/tsafe@1.7.5:
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
/tsafe@1.8.5:
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
dev: true
/tsconfck@3.1.5(typescript@5.6.2):
@@ -9123,6 +9123,14 @@ packages:
react: 18.3.1
dev: false
/use-sync-external-store@1.4.0(react@18.3.1):
resolution: {integrity: sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
dependencies:
react: 18.3.1
dev: false
/util-deprecate@1.0.2:
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
dev: true
@@ -9567,8 +9575,8 @@ packages:
/zod@3.23.8:
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
engines: {node: '>=12.7.0'}
peerDependencies:
'@types/react': '>=16.8'
@@ -9584,5 +9592,5 @@ packages:
dependencies:
'@types/react': 18.3.11
react: 18.3.1
use-sync-external-store: 1.2.2(react@18.3.1)
use-sync-external-store: 1.4.0(react@18.3.1)
dev: false

Some files were not shown because too many files have changed in this diff Show More