Compare commits

...

100 Commits

Author SHA1 Message Date
Millun Atluri
b2baead450 Corrected version 2023-10-11 21:11:03 +11:00
Millun Atluri
3be9b3ded8 Bump version to 3.3.0rc 2023-10-11 21:08:53 +11:00
Millun Atluri
47c7cff365 Updated frontend build 2023-10-11 21:07:12 +11:00
Ryan Dick
462c1d4c9b Improve model load times from disk: skip unnecessary weight init (#4840)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission
      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description

This PR optimizes the time to load models from disk.
In my local testing, SDXL text_encoder_2 models saw the greatest
improvement:
- Before change, load time (disk to cpu): 14 secs
- After change, load time (disk to cpu): 4 secs

See the in-code documentation for an explanation of how this speedup is
achieved.
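
A rough sketch of the idea (the exact set of patched layers and implementation details here are illustrative, not necessarily what this PR ships): torch layers run `reset_parameters()` during construction, which is wasted work when a checkpoint immediately overwrites the weights, so the context manager temporarily swaps it for a no-op and restores it afterwards.

```python
import contextlib

import torch


@contextlib.contextmanager
def skip_torch_weight_init():
    """Temporarily replace reset_parameters() on common torch layers with a no-op.

    Weight initialization is wasted work when the weights are overwritten by a
    checkpoint right after construction, so skipping it speeds up model loading.
    """
    targets = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Embedding]
    saved = [t.reset_parameters for t in targets]
    try:
        for t in targets:
            t.reset_parameters = lambda self, *args, **kwargs: None
        yield
    finally:
        # Restore the original initializers so later model construction behaves normally.
        for t, fn in zip(targets, saved):
            t.reset_parameters = fn
```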

## Related Tickets & Documents

This change was previously proposed on the HF transformers repo, but did
not get any traction:
https://github.com/huggingface/transformers/issues/18505#issue-1330728188

## QA Instructions, Screenshots, Recordings

I don't expect any adverse effects, but the new context manager is
applied while loading **all** models, so it would make sense to exercise
everything.

## Added/updated tests?

- [x] Yes
- [ ] No
2023-10-10 13:40:20 -04:00
Ryan Dick
0ed36158c8 Merge branch 'main' into ryan/optimize-model-load 2023-10-10 13:31:08 -04:00
Ryan Dick
f3c138a208 (minor) Fix Flake8. 2023-10-10 10:06:53 -04:00
Ryan Dick
61242bf86a Fix bug in skip_torch_weight_init() where the original behavior of torch.nn.Conv*d modules wasn't being restored correctly. 2023-10-10 10:05:50 -04:00
psychedelicious
d118d02df4 feat(ui): add mapping for sketch and scribble control adapter processors 2023-10-09 23:24:56 -04:00
Ryan Dick
58b56e9b1e Add a skip_torch_weight_init() context manager to improve model load times (from disk). 2023-10-09 14:12:56 -04:00
psychedelicious
1f751f8c21 fix(ui): remove extraneous cache update 2023-10-09 20:11:21 +11:00
psychedelicious
ca95a3bd0d fix(ui): fix canvas soft-lock if canceled before first generation
The canvas needs to be set to staging mode as soon as a canvas-destined batch is enqueued. If the batch is fully canceled before an image is generated, we need to remove that batch from the canvas `batchIds` watchlist, otherwise the canvas gets stuck in staging mode with no way to exit.

The changes here allow the batch status to be tracked, and if a batch has all its items completed, we can remove it from the `batchIds` watchlist. The `batchIds` watchlist now accurately represents *incomplete* canvas batches, fixing this cause of soft lock.
2023-10-09 20:11:21 +11:00
psychedelicious
55b40a9425 feat(events): add batch status and queue status to queue item status changed events
The UI will always re-fetch queue and batch status on receiving this event, so we may as well just include that data in the event and save the extra network roundtrips.
2023-10-09 20:11:21 +11:00
psychedelicious
90083cc88d fix(ui): fix use all hotkey 2023-10-09 20:03:14 +11:00
Lincoln Stein
ead754432a add a list of t2i adapters to startup set (#4828)
## What type of PR is this? (check all applicable)

- [X] Feature

## Have you discussed this change with the InvokeAI team?
- [X] No, because: Non-controversial

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] N/A


## Description

This adds a list of T2I adapters to the “starter models” offered by the
TUI installer. None of the models is selected by default; this can be
done easily if requested. The models offered to the user are:

```
TencentARC/t2iadapter_canny_sd15v2
TencentARC/t2iadapter_sketch_sd15v2
TencentARC/t2iadapter_depth_sd15v2
TencentARC/t2iadapter_zoedepth_sd15v1
TencentARC/t2i-adapter-canny-sdxl-1.0
TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
TencentARC/t2i-adapter-lineart-sdxl-1.0
TencentARC/t2i-adapter-sketch-sdxl-1.0
```

## Related Tickets & Documents

PR #4612 

## QA Instructions, Screenshots, Recordings

The revised installer has a new IP-ADAPTERS tab that looks like this:


![IMG_0255](https://github.com/invoke-ai/InvokeAI/assets/111189/0e01b1f6-7191-49a1-ac63-2c913826d299)

## Added/updated tests?

- [ ] Yes
- [X] No: It would be good to have a suite of model download tests, but one is not set up yet.
2023-10-08 19:49:43 -04:00
Lincoln Stein
fa9ea93477 add a list of t2i adapters to startup set 2023-10-08 18:53:21 -04:00
Lincoln Stein
fe0cf2c160 remove hardcoded subfolder name from model downloader 2023-10-08 17:45:39 -04:00
psychedelicious
a681fa4b03 fix(ui): invalidate query cache for all models on sync models
Also realised the tags were set up incorrectly, fixed that to get type safety with tags.
2023-10-07 22:30:15 +11:00
psychedelicious
1cc686734b feat(ui): on base model change, disable control adapters
Previously it deleted them entirely.
2023-10-07 22:30:15 +11:00
psychedelicious
82e8b92ba0 feat(ui): display toast when enabling t2i/controlnet and disabling the other 2023-10-07 22:30:15 +11:00
psychedelicious
e86658f864 feat(ui): disable invoke button if enabled control adapter model does not match base model 2023-10-07 22:30:15 +11:00
psychedelicious
ad136c2680 fix(ui): do not add control adapters with incompatible models to graph 2023-10-07 22:30:15 +11:00
psychedelicious
35374ec531 feat(ui): update graphs for multi ip adapter 2023-10-07 22:30:15 +11:00
psychedelicious
ed82bf6bb8 feat(ui): disable control adapter buttons if no models available 2023-10-07 22:30:15 +11:00
psychedelicious
078c9b6964 feat(nodes,ui): add t2i to linear UI
- Update backend metadata for t2i adapter
- Fix typo in `T2IAdapterInvocation`: `ip_adapter_model` -> `t2i_adapter_model`
- Update linear graphs to use t2i adapter
- Add client metadata recall for t2i adapter
- Fix bug with controlnet metadata recall - processor should be set to 'none' when recalling a control adapter
2023-10-07 22:30:15 +11:00
psychedelicious
1a9d2f1701 feat(ui): spruce up control adapter ui 2023-10-07 22:30:15 +11:00
psychedelicious
3e93159bce fix(ui): enable duplicated control adapter 2023-10-07 22:30:15 +11:00
psychedelicious
b57ebe52e4 chore(ui): "controlnet" -> "controladapters" 2023-10-07 22:30:15 +11:00
psychedelicious
ba4616ff89 feat(ui): add limits to enabled control adapters
- only 1 ip adapter at a time
- controlnet and t2i cannot both be active at once
2023-10-07 22:30:15 +11:00
psychedelicious
dcfbd49e1b fix(ui): fix control adapters recall 2023-10-07 22:30:15 +11:00
psychedelicious
913fc83cbf fix(ui): fix control adapter autoprocess 2023-10-07 22:30:15 +11:00
psychedelicious
6b8ce34eb3 fix(ui): fix excessive re-renders 2023-10-07 22:30:15 +11:00
psychedelicious
9508e0c9db feat(ui): refactor control adapters
Control adapters logic/state/ui is now generalized to hold controlnet, ip_adapter and t2i_adapter. In the future, other control adapter types can be added.

TODO:
- Limit IP adapter to 1
- Add T2I adapter to linear graphs
- Fix autoprocess
- T2I metadata saving & recall
- Improve on control adapters UI
2023-10-07 22:30:15 +11:00
Ryan Dick
9c720da021 Bump DenoiseLatentsInvocation version. 2023-10-06 20:43:43 -04:00
Ryan Dick
e1b576c72d yarn build 2023-10-06 20:43:43 -04:00
Ryan Dick
971ccfb081 Refactor multi-IP-Adapter to clean up the interface around changing scales. 2023-10-06 20:43:43 -04:00
Ryan Dick
43a3c3c7ea Fix typo in setting IP-Adapter scales. 2023-10-06 20:43:43 -04:00
Ryan Dick
4df1cdb34d Tidy _prepare_attention_processors(...) logic. 2023-10-06 20:43:43 -04:00
Ryan Dick
3f860c3523 Fixup IP-Adapter locale strings. 2023-10-06 20:43:43 -04:00
Ryan Dick
d8d0c9af09 Fix handling of scales with multiple IP-Adapters. 2023-10-06 20:43:43 -04:00
Ryan Dick
9403672ac0 Bugfix for multi-ip-adapter in DenoiseLatentsInvocation. 2023-10-06 20:43:43 -04:00
Ryan Dick
94591840a7 Frontend changes to enable multiple IP-Adapters in the workflow editor. 2023-10-06 20:43:43 -04:00
Ryan Dick
26b91a538a Fixes to get IP-Adapter tests working with new multi-IP-Adapter support. 2023-10-06 20:43:43 -04:00
Ryan Dick
7ca456d674 Update IP-Adapter model to enable running multiple IP-Adapters at once. (Not tested yet.) 2023-10-06 20:43:43 -04:00
Ryan Dick
78828b6b9c WIP - Accept a list of IPAdapterFields in DenoiseLatents. 2023-10-06 20:43:43 -04:00
Ryan Dick
166ff9d301 Proposal: Support slow tests that depend on models (#4813)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description

This PR adds support for slow unit tests that depend on models. It
includes:
- Documentation explaining the handling of fast vs. slow unit tests.
- Utilities to assist with writing tests that depend on models.
- A sample test that loads and runs an IP-Adapter model. This is far
from complete test coverage of IP-Adapter - it's just intended as a
first example of how to write tests with models.

**Suggestion for reviewers**: Start with docs/contributing/TESTS.md

## QA Instructions, Screenshots, Recordings

I've tested it all, but it would make sense for others to try running
both the fast tests and the slow tests.

## Added/updated tests?

- [x] Yes
- [ ] No
2023-10-06 19:55:38 -04:00
Ryan Dick
4f97bd4418 Merge branch 'main' into ryan/model-tests 2023-10-06 19:47:28 -04:00
Ryan Dick
e0e001758a Remove @slow decorator in favor of @pytest.mark.slow. 2023-10-06 18:26:06 -04:00
Ryan Dick
c1887135b3 Improve model cache debug logging (#4784)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description

This PR adds detailed debug logging to the model cache in order to give
more visibility into the model cache's memory utilization. **This PR
does not make any functional changes to the model cache.**

Every time a model is moved from disk to CPU, or between CPU/CUDA, a log
like this is emitted:
```bash
[2023-10-03 15:17:20,599]::[InvokeAI]::DEBUG --> Moved model '/home/ryan/invokeai/models/.cache/63742ed45b499e55620c402d6df26a20:sdxl:main:unet' from cpu to cuda in 1.23s.
Estimated model size: 4.782 GB.
Process RAM                    (-4.722): 6.987GB -> 2.265GB
libc mmap allocated            (-4.722): 6.030GB -> 1.308GB
libc arena used                (-0.061): 0.402GB -> 0.341GB
libc arena free                (+0.061): 0.006GB -> 0.067GB
libc total allocated           (-4.722): 6.439GB -> 1.717GB
libc total used                (-4.783): 6.433GB -> 1.649GB
VRAM                           (+4.881): 1.538GB -> 6.418GB
```
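
As a rough sketch of how such a measurement can be taken (this is not the actual ModelCache code; the helper name and `psutil` usage are illustrative assumptions):

```python
import time

import psutil
import torch


def log_model_move(model: torch.nn.Module, device: torch.device) -> None:
    # Illustrative helper: snapshot process RAM and CUDA VRAM around a .to(device)
    # call and log the deltas, similar in spirit to the logs shown above.
    process = psutil.Process()
    ram_before = process.memory_info().rss
    vram_before = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    start = time.time()
    model.to(device)
    elapsed = time.time() - start
    ram_delta = process.memory_info().rss - ram_before
    vram_delta = (torch.cuda.memory_allocated() if torch.cuda.is_available() else 0) - vram_before
    print(
        f"Moved model to {device} in {elapsed:.2f}s. "
        f"Process RAM {ram_delta / 1e9:+.3f}GB, VRAM {vram_delta / 1e9:+.3f}GB"
    )
```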

## Related Tickets & Documents

https://github.com/invoke-ai/InvokeAI/pull/4694 contains related fixes
to some known memory issues.

## QA Instructions, Screenshots, Recordings

Make sure debug logs are enabled and you should see the new logs.

We should test each of the following environments:
- [x] Linux
- [x] Mac OS + MPS
- [x] Windows

## Added/updated tests?

- [x] Yes
- [ ] No

Added unit tests for the new utilities. Test coverage is still low for
the ModelCache, but not worse than before.
2023-10-06 10:21:42 -04:00
Ryan Dick
096d195d6e Merge branch 'main' into ryan/model-cache-logging-only 2023-10-06 09:52:45 -04:00
Ryan Dick
7870b90717 Add TESTS.md documentation. 2023-10-05 15:38:25 -04:00
Ryan Dick
9854b244fd Fix Flake8 errors by using a pytest conftest.py file. 2023-10-05 15:36:15 -04:00
Ryan Dick
7d800e1ce3 Fix broken link in documentation to 'Frontend Documentation'. 2023-10-05 15:36:15 -04:00
Ryan Dick
1c8b1fbc53 POC of a test that depends on models. 2023-10-05 15:35:58 -04:00
Ryan Dick
594a3aef93 Set MALLOC_MMAP_THRESHOLD_=1048576 by default in invoke.sh. And add it to the manual installation docs. 2023-10-05 14:26:45 -04:00
Ryan Dick
78377469db Add support for T2I-Adapter in node workflows (#4612)
* Bump diffusers to 0.21.2.

* Add T2IAdapterInvocation boilerplate.

* Add T2I-Adapter model to model-management.

* (minor) Tidy prepare_control_image(...).

* Add logic to run the T2I-Adapter models at the start of the DenoiseLatentsInvocation.

* Add logic for applying T2I-Adapter weights and accumulating.

* Add T2IAdapter to MODEL_CLASSES map.

* yarn typegen

* Add model probes for T2I-Adapter models.

* Add all of the frontend boilerplate required to use T2I-Adapter in the nodes editor.

* Add T2IAdapterModel.convert_if_required(...).

* Fix errors in T2I-Adapter input image sizing logic.

* Fix bug with handling of multiple T2I-Adapters.

* black / flake8

* Fix typo

* yarn build

* Add num_channels param to prepare_control_image(...).

* Link to upstream diffusers bugfix PR that currently requires a workaround.

* feat: Add Color Map Preprocessor

Needed for the color T2I Adapter

* feat: Add Color Map Preprocessor to Linear UI

* Revert "feat: Add Color Map Preprocessor"

This reverts commit a1119a00bf.

* Revert "feat: Add Color Map Preprocessor to Linear UI"

This reverts commit bd8a9b82d8.

* Fix T2I-Adapter field rendering in workflow editor.

* yarn build, yarn typegen

---------

Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-10-05 16:29:16 +11:00
Ryan Dick
fbe6452c45 Add support for IPAdapterPlusXL based on 6219530507. 2023-10-04 22:35:17 -04:00
psychedelicious
3f4ea073d1 fix(ui): throw on fetch err when copying image 2023-10-05 10:43:59 +11:00
psychedelicious
8b7f8eaea2 chore: flake8 2023-10-05 09:32:29 +11:00
psychedelicious
88e16ce051 fix(nodes): mark session queue items failed on processor error
When the processor has an error and it has a queue item, mark that item failed.

This addresses processor errors resulting in `in_progress` queue items, which create a soft lock of the processor, requiring the user to cancel the `in_progress` item before anything else processes.
2023-10-05 09:32:29 +11:00
psychedelicious
421440cae0 feat(nodes): exhaustive graph validation
Makes graph validation logic more rigorous, validating graphs when they are created as part of a session or batch.

`validate_self()` method added to `Graph` model. It does all the validation that `is_valid()` did, plus a few extras:
- unique `node.id` values across graph
- node ids match their key in `Graph.nodes`
- recursively validate subgraphs
- validate all edges
- validate graph is acyclical

The new method is required because `is_valid()` just returned a boolean. That behaviour is retained, but `validate_self()` now raises appropriate exceptions for validation errors. These are then surfaced to the client.

The function is named `validate_self()` because pydantic reserves `validate()`.

There are two main places where graphs are created - in batches and in sessions.

Field validators are added to each of these for their `graph` fields, which call the new validation logic.

**Closes #4744**

In this issue, a batch is enqueued with an invalid graph. The output field is typed as optional while the input field is required. The field types themselves are not relevant - this change addresses the case where an invalid graph was created.

The mismatched types problem is not noticed until we attempt to invoke the graph, because the graph was never *fully* validated. An error is raised during the call to `graph_execution_state.next()` in `invoker.py`. This function prepares the edges and validates them, raising an exception due to the mismatched types.

This exception is caught by the session processor, but it doesn't handle this situation well - the graph is not marked as having an error and the queue item status is never changed. The queue item is therefore forever `in_progress`, so no new queue items are popped - the app won't do anything until the queue item is canceled manually.

This commit addresses this by preventing invalid graphs from being created in the first place, addressing a substantial number of fail cases.
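
A minimal sketch of the validator pattern described above (the `Batch` model and import path are assumed for illustration; the full diff appears further below):

```python
from pydantic import BaseModel, validator

from invokeai.app.services.graph import Graph  # import path assumed for illustration


class Batch(BaseModel):
    # Illustrative model; the real batch/session models carry additional fields.
    graph: Graph

    @validator("graph")
    def validate_graph(cls, v: Graph) -> Graph:
        # Raises DuplicateNodeIdError, CyclicalGraphError, etc. for invalid graphs,
        # rejecting them at creation time instead of soft-locking the processor later.
        v.validate_self()
        return v
```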
2023-10-05 09:32:29 +11:00
Jordan Hewitt
421021cede Add 'make 3d' plugin / community node (#4794)
* Add 'make 3d' plugin.

* Update communityNodes.md

Updated to Repo Link

---------

Co-authored-by: Jordan <srcrr-gitlab@ipriva.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Millun Atluri <Millu@users.noreply.github.com>
2023-10-04 21:41:21 +00:00
psychedelicious
020d4302d1 Change version bump from patch to minor
Because this adds a new field, it's a minor version bump
2023-10-05 08:24:52 +11:00
psychedelicious
8c59d2e5af chore: isort 2023-10-05 08:24:52 +11:00
psychedelicious
17d451eaa7 feat(images): add png_compress_level config
The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize
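
For reference, the underlying PIL call looks roughly like this (filenames are placeholders):

```python
from PIL import Image

img = Image.open("example.png")
# compress_level is 0-9 and always lossless: 0 = fastest/largest file, 9 = slowest/smallest file.
img.save("out.png", "PNG", compress_level=6)
```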

Closes #4786
2023-10-05 08:24:52 +11:00
psychedelicious
23a06fd06d feat(nodes): clear torch cache after upscaling
This can use many GB of VRAM, so we need to clean up after ourselves.
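
The cleanup amounts to roughly the following (a sketch; the actual node gates the MPS branch behind `choose_torch_device()`, as shown in the diff below):

```python
import torch

# Release cached allocations after a memory-heavy operation such as ESRGAN upscaling.
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
```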
2023-10-05 08:24:52 +11:00
Darren Ringer
010c8e8038 Roll back change to buildAdHocUpscaleGraph.ts
Undo the change made here which was causing automated tests to fail.
2023-10-05 08:24:52 +11:00
Darren Ringer
dfc635223c Update upscale.py with minor style correction 2023-10-05 08:24:52 +11:00
Darren Ringer
37121a3a24 Add tile_size parameter to ESRGAN node in buildAdHocUpscaleGraph.ts
Adds tile_size parameter to support the changed ESRGAN node in invokeai/app/invocations/upscale.py
2023-10-05 08:24:52 +11:00
Darren Ringer
51b5de799a Update upscale.py to support tile kwarg of RealESRGANer
Adds tile_size field to the ESRGAN Upscaler node, which sends the tile kwarg to RealESRGANer's constructor, enabling tiled upscaling (default=512)
2023-10-05 08:24:52 +11:00
Mary Hipp
eadbe6abf7 handle 0 images/assets 2023-10-05 08:11:52 +11:00
psychedelicious
16f48a816f fix(ui): add dnd validation logic for multi-select board move 2023-10-05 08:11:52 +11:00
psychedelicious
95838e5559 fix(ui): fix remove from board dnd validation
This is fired when the dnd image is moved over the 'none' board. We weren't defaulting to 'none' for the image's board_id, making it possible to drag a 'none' image onto 'none'.
2023-10-05 08:11:52 +11:00
psychedelicious
3e8d62b1d1 fix(ui): fix duplicate image selection
Selections were not being `uniqBy()`'d, or were `uniqBy()`'d without a proper iteratee. This resulted in duplicate images in selections in certain situations.

Add correct `uniqBy()` to the reducer to prevent this in the future.
2023-10-05 08:11:52 +11:00
psychedelicious
2acc93eb8e feat(ui): remove all calls to getBoardImagesTotals/getBoardAssetsTotals
This caused a crapload of network requests any time an image was generated.

The counts are necessary to handle the logic for inserting images into existing image list caches, so we have to keep track of them.

Replace tag invalidation with manual cache updates in all cases, except the initial request (which is necessary to get the initial image counts).

One subtle change is to make the counts an object instead of a number. This is required for `immer` to handle draft states. This should be raised as a bug with RTK Query, as no error is thrown when attempting to update a primitive immer draft.
2023-10-05 08:11:52 +11:00
Millun Atluri
fbb61f2334 Revert "Updated js files"
This reverts commit a0e936f3a7.
2023-10-04 22:32:00 +11:00
Millun Atluri
be85c7972b Updated js files 2023-10-04 22:32:00 +11:00
Millun Atluri
3a586fc9c4 Prevent caching to ensure updated UI is shown 2023-10-04 22:32:00 +11:00
Ryan Dick
2479a59e5e Re-enable garbage collection in model cache MemorySnapshots. 2023-10-03 15:18:47 -04:00
Ryan Dick
7d0ac2c36d (minor) clean up typos. 2023-10-03 15:00:03 -04:00
Ryan Dick
519b892f0c Add unit test for Struct_mallinfo2.__str__() 2023-10-03 14:25:34 -04:00
Ryan Dick
763dcacfd3 Add unit test for get_pretty_snapshot_diff(...). 2023-10-03 14:25:34 -04:00
Ryan Dick
3599d546e6 Add unit test for LibcUtil().mallinfo2(). 2023-10-03 14:25:34 -04:00
Ryan Dick
22a84930f6 Disable garbage collection in ModelCache calls to MemorySnapshot in order to minimize snapshot overhead. 2023-10-03 14:25:34 -04:00
Ryan Dick
d64e17e043 Add README with info about glibc memory fragmentation caused by the model cache. 2023-10-03 14:25:34 -04:00
Ryan Dick
ba54277011 Catch a more specific exception in environments that do not have a libc shared library. 2023-10-03 14:25:34 -04:00
Ryan Dick
5915a4a51c Minor fixes. 2023-10-03 14:25:34 -04:00
Ryan Dick
4580ba0d87 Remove logic to update model cache size estimates dynamically. 2023-10-03 14:25:34 -04:00
Ryan Dick
b9fd2e9e76 Improve get_pretty_snapshot_diff(...) message formatting. 2023-10-03 14:25:34 -04:00
Ryan Dick
75b65597af Add malloc info to MemorySnapshot. 2023-10-03 14:25:34 -04:00
Ryan Dick
2a3c0ab5d2 Move MemorySnapshot to its own file. 2023-10-03 14:25:34 -04:00
Ryan Dick
7d61373b82 Add LibcUtil class. 2023-10-03 14:25:34 -04:00
Ryan Dick
7d65555a5a Fix type error in torch device comparison. 2023-10-03 14:25:34 -04:00
Ryan Dick
123f2b2dbc Update cache model size estimates based on changes in VRAM when moving models to/from CUDA. 2023-10-03 14:25:34 -04:00
Ryan Dick
1e4e42556e Update model cache device comparison to treat 'cuda' and 'cuda:0' as the same device type. 2023-10-03 14:25:34 -04:00
Ryan Dick
1f6699ac43 Consolidate all model.to(...) calls in the model cache to use a utility function with better logging. 2023-10-03 14:25:34 -04:00
Ryan Dick
ace8665411 Add warning log if moving a model from cuda to cpu causes unexpected change in VRAM usage. 2023-10-03 14:25:34 -04:00
Ryan Dick
7fa5bae8fd Add warning log if moving model from RAM to VRAM causes an unexpected change in VRAM usage. 2023-10-03 14:25:34 -04:00
Ryan Dick
f9faca7c91 Add warning log if model mis-reports its required cache memory before load from disk. 2023-10-03 14:25:34 -04:00
Ryan Dick
594fd3ba6d Add debug logging of changes in RAM and VRAM for all model cache operations. 2023-10-03 14:25:34 -04:00
Ryan Dick
44d68f5ed5 Auto-format model_cache.py. 2023-10-03 14:25:34 -04:00
210 changed files with 6167 additions and 3143 deletions

View File

@@ -47,34 +47,9 @@ pip install ".[dev,test]"
These are optional groups of packages which are defined within the `pyproject.toml`
and will be required for testing the changes you make to the code.
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition a more
detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML one in particular can help identify missing statements
requiring tests to ensure coverage. This can be run by opening
`./coverage/html/index.html`.
For example.
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
### Tests
See the [tests documentation](./TESTS.md) for information about running and writing tests.
### Reloading Changes
Experimenting with changes to the Python source code is a drag if you have to re-start the server —

View File

@@ -0,0 +1,89 @@
# InvokeAI Backend Tests
We use `pytest` to run the backend python tests. (See [pyproject.toml](/pyproject.toml) for the default `pytest` options.)
## Fast vs. Slow
All tests are categorized as either 'fast' (no test annotation) or 'slow' (annotated with the `@pytest.mark.slow` decorator).
'Fast' tests are run to validate every PR, and are fast enough that they can be run routinely during development.
'Slow' tests are currently only run manually on an ad-hoc basis. In the future, they may be automated to run nightly. Most developers are only expected to run the 'slow' tests that directly relate to the feature(s) that they are working on.
As a rule of thumb, tests should be marked as 'slow' if there is a chance that they take >1s (e.g. on a CPU-only machine with slow internet connection). Common examples of slow tests are tests that depend on downloading a model, or running model inference.
## Running Tests
Below are some common test commands:
```bash
# Run the fast tests. (This implicitly uses the configured default option: `-m "not slow"`.)
pytest tests/
# Equivalent command to run the fast tests.
pytest tests/ -m "not slow"
# Run the slow tests.
pytest tests/ -m "slow"
# Run the slow tests from a specific file.
pytest tests/path/to/slow_test.py -m "slow"
# Run all tests (fast and slow).
pytest tests -m ""
```
## Test Organization
All backend tests are in the [`tests/`](/tests/) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
TODO: The above statement is aspirational. A re-organization of legacy tests is required to make it true.
## Tests that depend on models
There are a few things to keep in mind when adding tests that depend on models.
1. If a required model is not already present, it should automatically be downloaded as part of the test setup.
2. If a model is already downloaded, it should not be re-downloaded unnecessarily.
3. Take reasonable care to keep the total number of models required for the tests low. Whenever possible, re-use models that are already required for other tests. If you are adding a new model, consider including a comment to explain why it is required/unique.
There are several utilities to help with model setup for tests. Here is a sample test that depends on a model:
```python
import pytest
import torch
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.util.test_utils import install_and_load_model
@pytest.mark.slow
def test_model(model_installer, torch_device):
model_info = install_and_load_model(
model_installer=model_installer,
model_path_id_or_url="HF/dummy_model_id",
model_name="dummy_model",
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Dummy,
)
dummy_input = build_dummy_input(torch_device)
with torch.no_grad(), model_info as model:
model.to(torch_device, dtype=torch.float32)
output = model(dummy_input)
# Validate output...
```
## Test Coverage
To review test coverage, append `--cov` to your pytest command:
```bash
pytest tests/ --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a more detailed report is created in both XML and HTML format in the `./coverage` folder. The HTML output is particularly helpful in identifying untested statements where coverage should be improved. The HTML report can be viewed by opening `./coverage/html/index.html`.
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)

View File

@@ -12,7 +12,7 @@ To get started, take a look at our [new contributors checklist](newContributorCh
Once you're set up, for more information, you can review the documentation specific to your area of interest:
* #### [InvokeAI Architecture](../ARCHITECTURE.md)
* #### [Frontend Documentation](development_guides/contributingToFrontend.md)
* #### [Frontend Documentation](./contributingToFrontend.md)
* #### [Node Documentation](../INVOCATIONS.md)
* #### [Local Development](../LOCAL_DEVELOPMENT.md)

View File

@@ -256,6 +256,10 @@ manager, please follow these steps:
*highly recommended** if your virtual environment is located outside of
your runtime directory.
!!! tip
On linux, it is recommended to run invokeai with the following env var: `MALLOC_MMAP_THRESHOLD_=1048576`. For example: `MALLOC_MMAP_THRESHOLD_=1048576 invokeai --web`. This helps to prevent memory fragmentation that can lead to memory accumulation over time. This env var is set automatically when running via `invoke.sh`.
10. Render away!
Browse the [features](../features/index.md) section to learn about all the

View File

@@ -10,6 +10,20 @@ To use a community workflow, download the `.json` node graph file and load i
--------------------------------
--------------------------------
### Make 3D
**Description:** Create compelling 3D stereo images from 2D originals.
**Node Link:** [https://gitlab.com/srcrr/shift3d/-/raw/main/make3d.py](https://gitlab.com/srcrr/shift3d)
**Example Node Graph:** https://gitlab.com/srcrr/shift3d/-/raw/main/example-workflow.json?ref_type=heads&inline=false
**Output Examples**
![Painting of a cozy dilapidated house](https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png){: style="height:512px;width:512px"}
![Photo of cute puppies](https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png){: style="height:512px;width:512px"}
--------------------------------
### Ideal Size

View File

@@ -46,6 +46,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
# Avoid glibc memory fragmentation. See invokeai/backend/model_management/README.md for details.
export MALLOC_MMAP_THRESHOLD_=1048576
# Primary function for the case statement to determine user input
do_choice() {
case $1 in

View File

@@ -68,6 +68,7 @@ class FieldDescriptions:
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"

View File

@@ -10,7 +10,7 @@ import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -33,6 +33,7 @@ from invokeai.app.invocations.primitives import (
LatentsOutput,
build_latents_output,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@@ -47,6 +48,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
@@ -196,7 +198,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.1.0",
version="1.3.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -223,12 +225,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
)
@validator("cfg_scale")
@@ -404,52 +409,150 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if not isinstance(ip_adapter, list):
ip_adapter = [ip_adapter]
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
if len(ip_adapter) == 0:
return None
ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
image_encoder_model_info = context.services.model_manager.get_model(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
)
)
return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
return ip_adapter_data_list
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.get_model(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
context=context,
)
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
# this PR: https://github.com/huggingface/diffusers/pull/5134.
total_downscale_factor = total_downscale_factor // 2
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@@ -522,6 +625,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask, masked_latents = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
@@ -580,7 +689,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)
@@ -602,8 +710,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # IPAdapterData,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)

View File

@@ -15,6 +15,7 @@ from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
@@ -63,6 +64,7 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Optional[VAEModelField] = Field(
default=None,
@@ -139,6 +141,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,

View File

@@ -0,0 +1,83 @@
from typing import Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType
class T2IAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the T2I-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput):
t2i_adapter: T2IAdapterField = OutputField(description=FieldDescriptions.t2i_adapter, title="T2I Adapter")
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(
default="just_resize",
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
)
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(
image=self.image,
t2i_adapter_model=self.t2i_adapter_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
)
)

View File

@@ -4,12 +4,14 @@ from typing import Literal
import cv2 as cv
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@@ -22,13 +24,19 @@ ESRGAN_MODELS = Literal[
"RealESRGAN_x2plus.pth",
]
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
image: ImageField = InputField(description="The input image")
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
tile_size: int = InputField(
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -86,9 +94,11 @@ class ESRGANInvocation(BaseInvocation):
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
@@ -99,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,

View File

@@ -255,6 +255,7 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )

View File

@@ -4,7 +4,12 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
@@ -262,21 +267,31 @@ class EventServiceBase:
),
)
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
def emit_queue_item_status_changed(
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=session_queue_item.queue_id,
queue_item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
queue_id=queue_status.queue_id,
queue_item=dict(
queue_id=session_queue_item.queue_id,
item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
batch_status=batch_status.dict(),
queue_status=queue_status.dict(),
),
)

View File

@@ -2,7 +2,7 @@
import copy
import itertools
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
@@ -170,6 +170,18 @@ class NodeIdMismatchError(ValueError):
pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError):
pass
class UnknownGraphValidationError(ValueError):
pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
@@ -254,59 +266,6 @@ class Graph(BaseModel):
default_factory=list,
)
@root_validator
def validate_nodes_and_edges(cls, values):
"""Validates that all edges match nodes in the graph"""
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
edges = cast(Optional[list[Edge]], values.get("edges"))
if nodes is not None:
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
if edges is not None and nodes is not None:
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
)
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
return values
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -377,53 +336,108 @@ class Graph(BaseModel):
except KeyError:
pass
def is_valid(self) -> bool:
"""Validates the graph."""
def validate_self(self) -> None:
"""
Validates the graph.
Raises an exception if the graph is invalid:
- `DuplicateNodeIdError`
- `NodeIdMismatchError`
- `InvalidSubGraphError`
- `NodeNotFoundError`
- `NodeFieldNotFoundError`
- `CyclicalGraphError`
- `InvalidEdgeError`
"""
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in self.nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
if not gn.graph.is_valid():
return False
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
# Validate all edges reference nodes in the graph
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
# Validate that all edges match nodes and fields in the graph
for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = self.nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
# Validate there are no cycles
g = self.nx_graph_flat()
if not nx.is_directed_acyclic_graph(g):
return False
raise CyclicalGraphError("Graph contains cycles")
# Validate all edge connections are valid
if not all(
(
are_connections_compatible(
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
for e in self.edges:
if not are_connections_compatible(
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
):
raise InvalidEdgeError(
f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
)
for e in self.edges
)
# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for n in self.nodes.values():
if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
raise InvalidEdgeError(f"Invalid iterator node {n.id}")
if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
raise InvalidEdgeError(f"Invalid collector node {n.id}")
return None
def is_valid(self) -> bool:
"""
Checks if the graph is valid.
Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
"""
try:
self.validate_self()
return True
except (
DuplicateNodeIdError,
NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError,
NodeFieldNotFoundError,
CyclicalGraphError,
InvalidEdgeError,
):
return False
# Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
):
return False
# Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
):
return False
return True
except Exception as e:
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
@@ -804,6 +818,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
class Config:
schema_extra = {
"required": [

View File

@@ -9,6 +9,7 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@@ -79,6 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
__compress_level: int
def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict()
@@ -87,7 +89,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
# Validate required output folders at launch
self.__validate_storage_folders()
@@ -134,7 +136,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
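As a side note, `compress_level` is a standard Pillow option for PNG saving; lower values trade file size for speed. A small standalone illustration (file names are arbitrary):

```python
from PIL import Image

# Pillow's PNG compress_level ranges from 0 (no compression) to 9 (smallest,
# slowest); the Pillow default is 6.
img = Image.new("RGB", (1024, 1024), color=(127, 127, 127))
img.save("fast.png", "PNG", compress_level=1)
img.save("small.png", "PNG", compress_level=9)
```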

View File

@@ -1,3 +1,4 @@
import traceback
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
@@ -123,6 +124,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
if queue_item is not None:
self.__invoker.services.session_queue.cancel_queue_item(
queue_item.item_id, error=traceback.format_exc()
)
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:

View File

@@ -80,7 +80,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
"""Cancels a session queue item"""
pass

View File

@@ -123,6 +123,11 @@ class Batch(BaseModel):
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
@validator("graph")
def validate_graph(cls, v: Graph):
v.validate_self()
return v
class Config:
schema_extra = {
"required": [

View File

@@ -427,7 +427,13 @@ class SqliteSessionQueue(SessionQueueBase):
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -555,10 +561,11 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
@@ -608,7 +615,13 @@ class SqliteSessionQueue(SessionQueueBase):
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()
raise
@@ -654,7 +667,13 @@ class SqliteSessionQueue(SessionQueueBase):
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()
raise

View File

@@ -265,22 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
width: int,
height: int,
num_channels: int = 3,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
"""Pre-process images for ControlNets or T2I-Adapters.
Args:
image (Image): The PIL image to pre-process.
width (int): The target width in pixels.
height (int): The target height in pixels.
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting an
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (torch.dtype, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.
control_mode (str, optional): Defaults to "balanced".
resize_mode (str, optional): Defaults to "just_resize_simple".
Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range.
Returns:
torch.Tensor: The pre-processed input tensor.
"""
if (
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
@@ -289,10 +308,10 @@ def prepare_control_image(
image = image.convert("RGB")
if resize_mode == "just_resize_simple":
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple": # not yet implemented
pass
elif resize_mode == "fill_resize_simple": # not yet implemented
pass
elif resize_mode == "crop_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "fill_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
@@ -313,9 +332,11 @@ def prepare_control_image(
device=device,
)
else:
pass
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1)
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
if timage.shape[1] < num_channels or num_channels <= 0:
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
timage = timage[:, :num_channels, :, :]
timage = timage.to(device=device, dtype=dtype)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
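A hedged usage sketch of the updated `prepare_control_image` signature (the input path is hypothetical, and the function is assumed importable from the ControlNet utility module edited above):

```python
import torch
from PIL import Image

pose_image = Image.open("pose.png").convert("RGB")  # hypothetical input image
control_tensor = prepare_control_image(
    image=pose_image,
    width=1024,
    height=1024,
    num_channels=3,
    device="cuda",
    dtype=torch.float16,
    do_classifier_free_guidance=True,
    resize_mode="just_resize_simple",
)
# With classifier-free guidance enabled, the batch dimension is repeated,
# so the expected shape is roughly (2, 3, 1024, 1024).
print(control_tensor.shape)
```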

View File

@@ -335,7 +335,7 @@ class ModelInstall(object):
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith("v2/")]
files = [x for x in files if x.startswith(f"{subfolder}/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None

View File

@@ -8,6 +8,8 @@ import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
@@ -45,18 +47,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
assert len(weights) == len(scales)
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self._weights = weights
self._scales = scales
def __call__(
self,
@@ -67,16 +67,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@@ -128,23 +118,36 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ipa_embed
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)

View File

@@ -1,17 +1,15 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
@@ -61,7 +59,7 @@ class IPAdapter:
def __init__(
self,
state_dict: dict[torch.Tensor],
state_dict: dict[str, torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -73,12 +71,11 @@ class IPAdapter:
self._clip_image_processor = CLIPImageProcessor()
self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(state_dict["image_proj"])
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
self.attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"]).to(
self.device, dtype=self.dtype
)
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
@@ -86,99 +83,14 @@ class IPAdapter:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self):
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
@@ -218,6 +130,20 @@ class IPAdapterPlus(IPAdapter):
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL."""
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
heads=20,
num_queries=self._num_tokens,
ff_mult=4,
).to(self.device, dtype=self.dtype)
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
@@ -228,6 +154,14 @@ def build_ip_adapter(
is_plus = "proj.weight" not in state_dict["image_proj"]
if is_plus:
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
else:
return IPAdapter(state_dict, device=device, dtype=dtype)
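A short sketch of the dispatch behaviour added above (the checkpoint path is hypothetical; the import path follows the `ip_adapter` module referenced elsewhere in this PR):

```python
import torch

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

# A "Plus" checkpoint whose ip_adapter cross-attention dim is 2048 should be
# dispatched to IPAdapterPlusXL; 768 should yield IPAdapterPlus.
ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path="models/ip-adapter-plus_sdxl.bin",  # hypothetical path
    device=torch.device("cuda"),
    dtype=torch.float16,
)
print(type(ip_adapter).__name__)
```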

View File

@@ -0,0 +1,46 @@
import torch
class IPAttentionProcessorWeights(torch.nn.Module):
"""The IP-Adapter weights for a single attention processor.
This class is a torch.nn.Module sub-class to facilitate loading from a state_dict. It does not have a forward(...)
method.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
class IPAttentionWeights(torch.nn.Module):
"""A collection of all the `IPAttentionProcessorWeights` objects for an IP-Adapter model.
This class is a torch.nn.Module sub-class so that it inherits the `.to(...)` functionality. It does not have a
forward(...) method.
"""
def __init__(self, weights: torch.nn.ModuleDict):
super().__init__()
self._weights = weights
def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
"""Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
# Cast to int first, because we expect the key to represent an int. Then cast back to str, because
# `torch.nn.ModuleDict` only supports str keys.
return self._weights[str(int(idx))]
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
for tensor_name, tensor in state_dict.items():
if "to_k_ip.weight" in tensor_name:
index = str(int(tensor_name.split(".")[0]))
attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
attn_proc_weights_module.load_state_dict(state_dict)
return cls(attn_proc_weights_module)
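To make the expected key layout concrete, a small self-contained sketch feeding `from_state_dict` a synthetic state dict that follows the `<idx>.to_k_ip.weight` / `<idx>.to_v_ip.weight` pattern it parses (tensor shapes here are arbitrary):

```python
import torch

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights

# Synthetic weights for two attention processors (indices 1 and 3).
state_dict = {
    "1.to_k_ip.weight": torch.randn(320, 768),
    "1.to_v_ip.weight": torch.randn(320, 768),
    "3.to_k_ip.weight": torch.randn(640, 768),
    "3.to_v_ip.weight": torch.randn(640, 768),
}
weights = IPAttentionWeights.from_state_dict(state_dict)
proc = weights.get_attention_processor_weights(3)
print(proc.to_k_ip.weight.shape)  # torch.Size([640, 768])
```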

View File

@@ -0,0 +1,53 @@
from contextlib import contextmanager
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
class UNetPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
def __init__(self, ip_adapters: list[IPAdapter]):
self._ip_adapters = ip_adapters
self._scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float):
self._scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor"):
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._scales,
)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
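A hedged usage sketch for `UNetPatcher` (`unet`, `ip_adapter_a`, and `ip_adapter_b` are assumed to be loaded elsewhere):

```python
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher

# Two IP-Adapters applied to a single UNet, each with its own scale.
patcher = UNetPatcher([ip_adapter_a, ip_adapter_b])
patcher.set_scale(0, 0.8)
patcher.set_scale(1, 0.4)

with patcher.apply_ip_adapter_attention(unet):
    # Run the denoising loop here; the original attention processors are
    # restored when the context manager exits.
    ...
```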

View File

@@ -0,0 +1,27 @@
# Model Cache
## `glibc` Memory Allocator Fragmentation
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On Linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. an OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
This problem may also exist on other OSes and other `libc` implementations, but at the time of writing it has only been investigated on Linux with `glibc`.
To better understand how the `glibc` memory allocator works, see these references:
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
- Details: https://sourceware.org/glibc/wiki/MallocInternals
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena, making them vulnerable to fragmentation.
We can work around this memory fragmentation issue by setting the following env var:
```bash
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
MALLOC_MMAP_THRESHOLD_=1048576
```
See the following references for more information about the `malloc` tunable parameters:
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
- https://man7.org/linux/man-pages/man3/mallopt.3.html
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.
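As a rough illustration, these allocator stats can also be inspected directly with the `LibcUtil` helper added in this PR (glibc only; `LibcUtil` raises on `__init__` if `libc.so.6` is unavailable):

```python
from invokeai.backend.model_management.libc_util import LibcUtil

libc = LibcUtil()
print(libc.mallinfo2())  # arena vs. mmap usage before a model load
# ... load or move a model here ...
print(libc.mallinfo2())  # compare the allocator state afterwards
```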

View File

@@ -0,0 +1,75 @@
import ctypes
class Struct_mallinfo2(ctypes.Structure):
"""A ctypes Structure that matches the libc mallinfo2 struct.
Docs:
- https://man7.org/linux/man-pages/man3/mallinfo.3.html
- https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html
struct mallinfo2 {
size_t arena; /* Non-mmapped space allocated (bytes) */
size_t ordblks; /* Number of free chunks */
size_t smblks; /* Number of free fastbin blocks */
size_t hblks; /* Number of mmapped regions */
size_t hblkhd; /* Space allocated in mmapped regions (bytes) */
size_t usmblks; /* See below */
size_t fsmblks; /* Space in freed fastbin blocks (bytes) */
size_t uordblks; /* Total allocated space (bytes) */
size_t fordblks; /* Total free space (bytes) */
size_t keepcost; /* Top-most, releasable space (bytes) */
};
"""
_fields_ = [
("arena", ctypes.c_size_t),
("ordblks", ctypes.c_size_t),
("smblks", ctypes.c_size_t),
("hblks", ctypes.c_size_t),
("hblkhd", ctypes.c_size_t),
("usmblks", ctypes.c_size_t),
("fsmblks", ctypes.c_size_t),
("uordblks", ctypes.c_size_t),
("fordblks", ctypes.c_size_t),
("keepcost", ctypes.c_size_t),
]
def __str__(self):
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
s += (
f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
" (GB)\n"
)
s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
return s
class LibcUtil:
"""A utility class for interacting with the C Standard Library (`libc`) via ctypes.
Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
this shared library is not available.
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
"""Calls `libc` `mallinfo2`.
Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()

View File

@@ -0,0 +1,94 @@
import gc
from typing import Optional
import psutil
import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
class MemorySnapshot:
"""A snapshot of RAM and VRAM usage. All values are in bytes."""
def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
"""Initialize a MemorySnapshot.
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
Args:
process_ram (int): CPU RAM used by the current process.
vram (Optional[int]): VRAM used by torch.
malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
"""
self.process_ram = process_ram
self.vram = vram
self.malloc_info = malloc_info
@classmethod
def capture(cls, run_garbage_collector: bool = True):
"""Capture and return a MemorySnapshot.
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
Args:
run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
usage. Defaults to True.
Returns:
MemorySnapshot
"""
if run_garbage_collector:
gc.collect()
# According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
# supported on all platforms.
process_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
vram = torch.cuda.memory_allocated()
else:
# TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
# time to test it properly.
vram = None
try:
malloc_info = LibcUtil().mallinfo2()
except OSError:
# This is expected in environments that do not have the 'libc.so.6' shared library.
malloc_info = None
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
def get_msg_line(prefix: str, val1: int, val2: int):
diff = val2 - val1
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
msg = ""
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd)
msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks)
msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks)
libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd
libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2)
libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return msg
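A brief usage sketch for the snapshot helpers above (the import path matches the model cache changes later in this diff):

```python
from invokeai.backend.model_management.memory_snapshot import (
    MemorySnapshot,
    get_pretty_snapshot_diff,
)

before = MemorySnapshot.capture()
# ... perform a memory-heavy operation, e.g. load a model from disk ...
after = MemorySnapshot.capture()
print(get_pretty_snapshot_diff(before, after))
```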

View File

@@ -18,8 +18,10 @@ context. Use like this:
import gc
import hashlib
import math
import os
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
@@ -28,6 +30,8 @@ from typing import Any, Dict, Optional, Type, Union, types
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
from ..util.devices import choose_torch_device
from .models import BaseModelType, ModelBase, ModelType, SubModelType
@@ -44,6 +48,8 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
@dataclass
@@ -205,22 +211,44 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
f"Loading model {model_path}, type"
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
self_reported_model_size_before_load = model_info.get_size(submodel)
# Remove old models from the cache to make room for the new model.
self._make_cache_room(self_reported_model_size_before_load)
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = MemorySnapshot.capture()
end_load_time = time.time()
cache_entry = _CacheRecord(self, model, mem_used)
self_reported_model_size_after_load = model_info.get_size(submodel)
self.logger.debug(
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
# We only log a warning for over-reported (not under-reported) model sizes before load. There is a known
# issue where models report their fp32 size before load, and are then loaded as fp16. Once this issue is
# addressed, it would make sense to log a warning for both over-reported and under-reported model sizes.
if (self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
self.logger.warning(
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
)
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
self._cached_models[key] = cache_entry
else:
if self.stats:
@@ -240,6 +268,45 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def _move_model_to_device(self, key: str, target_device: torch.device):
cache_entry = self._cached_models[key]
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
cache_entry.model.to(target_device)
snapshot_after = MemorySnapshot.capture()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if snapshot_before.vram is not None and snapshot_after.vram is not None:
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.warning(
f"Moving model '{key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
@@ -269,11 +336,7 @@ class ModelCache(object):
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models(self.size_needed)
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
self.cache._move_model_to_device(self.key, self.cache.execution_device)
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
@@ -286,7 +349,7 @@ class ModelCache(object):
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.model.to(self.cache.storage_device)
self.cache._move_model_to_device(self.key, self.cache.storage_device)
return self.model
@@ -339,7 +402,8 @@ class ModelCache(object):
locked_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
@@ -354,7 +418,8 @@ class ModelCache(object):
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
@@ -387,7 +452,8 @@ class ModelCache(object):
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}"
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
# 2 refs:
@@ -423,11 +489,9 @@ class ModelCache(object):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
vram_in_use += mem.vram_used # note vram_used is negative
self._move_model_to_device(model_key, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
@@ -454,16 +518,3 @@ class ModelCache(object):
with open(hashpath, "w") as f:
f.write(hash)
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram

View File

@@ -0,0 +1,30 @@
from contextlib import contextmanager
import torch
def _no_op(*args, **kwargs):
pass
@contextmanager
def skip_torch_weight_init():
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
to skip weight initialization.
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
saved_functions = [m.reset_parameters for m in torch_modules]
try:
for torch_module in torch_modules:
torch_module.reset_parameters = _no_op
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
torch_module.reset_parameters = saved_function
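A self-contained sketch of the effect: constructing a large layer inside the context manager skips `reset_parameters`, so it should be noticeably faster than the default path (exact timings vary by machine):

```python
import time

import torch

from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init

start = time.time()
with skip_torch_weight_init():
    _ = torch.nn.Linear(8192, 8192)  # weight init skipped
print(f"init skipped: {time.time() - start:.3f}s")

start = time.time()
_ = torch.nn.Linear(8192, 8192)  # default path runs the full weight init
print(f"default init: {time.time() - start:.3f}s")
```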

View File

@@ -57,6 +57,7 @@ class ModelProbe(object):
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
@@ -408,6 +409,11 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -595,6 +601,26 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type in ("full_adapter", "light_adapter"):
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
@@ -603,6 +629,7 @@ ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInvers
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
@@ -611,5 +638,6 @@ ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInver
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -25,6 +25,7 @@ from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .t2i_adapter import T2IAdapterModel
from .textual_inversion import TextualInversionModel
from .vae import VaeModel
@@ -38,6 +39,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -48,6 +50,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
@@ -59,6 +62,7 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
@@ -70,6 +74,7 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
@@ -81,6 +86,7 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,

View File

@@ -53,6 +53,7 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):

View File

@@ -0,0 +1,102 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from diffusers import T2IAdapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class T2IAdapterModelFormat(str, Enum):
Diffusers = "diffusers"
class T2IAdapterModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[T2IAdapterModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.T2IAdapter
super().__init__(model_path, base_model, model_type)
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"T2IAdapter"}:
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> T2IAdapter:
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# Calculate a more accurate size after loading the model into memory.
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Model not found at '{path}'.")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return T2IAdapterModelFormat.Diffusers
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == T2IAdapterModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@@ -24,6 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from ..util import auto_detect_slice_size, normalize_device
@@ -173,6 +174,16 @@ class IPAdapterData:
end_step_percent: float = Field(default=1.0)
@dataclass
class T2IAdapterData:
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
adapter_state: list[torch.Tensor] = Field()
weight: Union[float, list[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
@@ -326,7 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@@ -379,6 +391,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
@@ -398,7 +411,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@@ -411,6 +425,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
ip_adapter_unet_patcher = None
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
@@ -421,11 +436,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model,
scale=weight,
)
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
@@ -454,6 +466,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
)
latents = step_output.prev_sample
@@ -499,7 +513,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@@ -512,26 +528,30 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
weight = (
ip_adapter_data.weight[step_index]
if isinstance(ip_adapter_data.weight, List)
else ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
ip_adapter_data.ip_adapter_model.set_scale(weight)
else:
# otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0)
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
weight = (
single_ip_adapter_data.weight[step_index]
if isinstance(single_ip_adapter_data.weight, List)
else single_ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
ip_adapter_unet_patcher.set_scale(i, weight)
else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0)
# handle ControlNet(s)
# default is no controlnet, so set controlnet processing output to None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
if control_data is not None:
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
# Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
elif control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data,
sample=latent_model_input,
timestep=timestep,
@@ -539,6 +559,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
)
elif t2i_adapter_data is not None:
accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data:
# Determine the T2I-Adapter weights for the current denoising step.
first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count)
last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count)
t2i_adapter_weight = (
single_t2i_adapter_data.weight[step_index]
if isinstance(single_t2i_adapter_data.weight, list)
else single_t2i_adapter_data.weight
)
if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step:
# If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0
# so it has no effect.
t2i_adapter_weight = 0.0
# Apply the t2i_adapter_weight, and accumulate.
if accum_adapter_state is None:
# Handle the first T2I-Adapter.
accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state]
else:
# Add to the previous adapter states.
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight
down_block_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
@@ -547,8 +593,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
)
guidance_scale = conditioning_data.guidance_scale
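A minimal, standalone sketch of the step-range gating and weighted accumulation implemented in the hunk above. AdapterSpec is a hypothetical stand-in for IPAdapterData/T2IAdapterData, and the residual tensors are assumed to be lists of per-resolution feature maps; this is not the pipeline code itself.

import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch


@dataclass
class AdapterSpec:
    begin_step_percent: float
    end_step_percent: float
    weight: Union[float, List[float]]  # scalar, or one weight per denoising step
    adapter_state: List[torch.Tensor]  # per-resolution residuals produced by the adapter


def step_weight(spec: AdapterSpec, step_index: int, total_step_count: int) -> float:
    """Return the adapter's weight for this step, or 0.0 outside its active step range."""
    first_step = math.floor(spec.begin_step_percent * total_step_count)
    last_step = math.ceil(spec.end_step_percent * total_step_count)
    if step_index < first_step or step_index > last_step:
        return 0.0
    return spec.weight[step_index] if isinstance(spec.weight, list) else spec.weight


def accumulate_adapter_states(
    specs: List[AdapterSpec], step_index: int, total_step_count: int
) -> Optional[List[torch.Tensor]]:
    """Element-wise weighted sum of the adapters' residuals, as passed to the UNet."""
    accum = None
    for spec in specs:
        w = step_weight(spec, step_index, total_step_count)
        if accum is None:
            accum = [val * w for val in spec.adapter_state]
        else:
            for i, val in enumerate(spec.adapter_state):
                accum[i] += val * w
    return accum

The same step-range gating applies to the IP-Adapters, except that the per-step weight is handed to the UNet patcher via set_scale(i, weight) rather than being multiplied into residuals.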

View File

@@ -81,7 +81,7 @@ class ConditioningData:
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
@property
def dtype(self):

View File

@@ -346,12 +346,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
"ip_adapter_image_prompt_embeds": [
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None
@@ -418,7 +416,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.uncond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None
@@ -444,7 +445,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.cond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
added_cond_kwargs = None

View File

@@ -0,0 +1,67 @@
import contextlib
from pathlib import Path
from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
@pytest.fixture(scope="session")
def torch_device():
return "cuda" if torch.cuda.is_available() else "cpu"
@pytest.fixture(scope="module")
def model_installer():
"""A global ModelInstall pytest fixture to be used by many tests."""
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
# a singleton.
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
def install_and_load_model(
model_installer: ModelInstall,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""Install a model if it is not already installed, then get the ModelInfo for that model.
This is intended as a utility function for tests.
Args:
model_installer (ModelInstall): The model installer.
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
is not already installed.
model_name (str): The model name, forwarded to ModelManager.get_model(...).
base_model (BaseModelType): The base model, forwarded to ModelManager.get_model(...).
model_type (ModelType): The model type, forwarded to ModelManager.get_model(...).
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
Returns:
ModelInfo
"""
# If the requested model is already installed, return its ModelInfo.
with contextlib.suppress(ModelNotFoundException):
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
# Install the requested model.
model_installer.heuristic_import(model_path_id_or_url)
try:
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
except ModelNotFoundException as e:
raise Exception(
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
)
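As a usage illustration only, a test built on these helpers might look like the sketch below. The model identifiers are placeholders, not part of this change; the enum members come from the same invokeai.backend.model_management.models.base module imported above.

def test_install_and_load_model(model_installer):
    # Placeholder HF repo ID and model name; any installable checkpoint would do.
    model_info = install_and_load_model(
        model_installer=model_installer,
        model_path_id_or_url="some-org/some-model",
        model_name="some-model",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel_type=SubModelType.TextEncoder,
    )
    assert model_info is not None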

View File

@@ -96,6 +96,22 @@ sd-1/controlnet/tile:
repo_id: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/t2i_adapter/canny-sd15:
repo_id: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:
repo_id: TencentARC/t2iadapter_sketch_sd15v2
sd-1/t2i_adapter/depth-sd15:
repo_id: TencentARC/t2iadapter_depth_sd15v2
sd-1/t2i_adapter/zoedepth-sd15:
repo_id: TencentARC/t2iadapter_zoedepth_sd15v1
sdxl/t2i_adapter/canny-sdxl:
repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0
sdxl/t2i_adapter/zoedepth-sdxl:
repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
sdxl/t2i_adapter/lineart-sdxl:
repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0
sdxl/t2i_adapter/sketch-sdxl:
repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
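The repo IDs added above are standard Hugging Face checkpoints; InvokeAI downloads and registers them through its own model installer. Purely as an illustration of what these adapters are, the SDXL entries can also be loaded directly with diffusers (a sketch, assuming a diffusers release that ships the T2IAdapter class):

# Illustrative only; InvokeAI manages these checkpoints via its model manager.
import torch
from diffusers import T2IAdapter

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0",
    torch_dtype=torch.float16,
)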

View File

@@ -98,15 +98,16 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.tabs = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"STARTER MODELS",
"MAIN MODELS",
"STARTERS",
"MAINS",
"CONTROLNETS",
"T2I-ADAPTERS",
"IP-ADAPTERS",
"LORA/LYCORIS",
"TEXTUAL INVERSION",
"LORAS",
"TI EMBEDDINGS",
],
value=[self.current_tab],
columns=6,
columns=7,
max_height=2,
relx=8,
scroll_exit=True,
@@ -131,6 +132,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.t2i_models = self.add_model_widgets(
model_type=ModelType.T2IAdapter,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.ipadapter_models = self.add_model_widgets(
model_type=ModelType.IPAdapter,
@@ -351,6 +358,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.t2i_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,
@@ -541,6 +549,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.t2i_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1 as F,i2 as G,i3 as W,i4 as K,aG as Y,i5 as Z,i6 as H}from"./index-94062f76.js";import{M as U}from"./MantineProvider-a057bfc9.js";var P=String.raw,E=P`
import{w as s,h_ as T,v as l,a2 as I,h$ as R,ae as V,i0 as z,i1 as j,i2 as D,i3 as F,i4 as G,i5 as W,i6 as K,aG as H,i7 as U,i8 as Y}from"./index-90346e33.js";import{M as Z}from"./MantineProvider-486d4834.js";var P=String.raw,E=P`
:root,
:host {
--chakra-vh: 100vh;
@@ -277,4 +277,4 @@ import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1
}
${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 
0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=Y(),n=o.dir(),r=l.useMemo(()=>ie({...Z,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(U,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:H,children:e})})}const ve=l.memo(he);export{ve as default};
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 
0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -3,6 +3,9 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<style>
@@ -12,7 +15,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-94062f76.js"></script>
<script type="module" crossorigin src="./assets/index-90346e33.js"></script>
</head>
<body dir="ltr">

View File

@@ -49,8 +49,10 @@
"cancel": "Cancel",
"close": "Close",
"communityLabel": "Community",
"controlNet": "Controlnet",
"controlNet": "ControlNet",
"controlAdapter": "Control Adapter",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -130,6 +132,16 @@
"upload": "Upload"
},
"controlnet": {
"controlAdapter": "Control Adapter",
"controlnet": "$t(controlnet.controlAdapter) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.t2iAdapter))",
"addControlNet": "Add $t(common.controlNet)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addT2IAdapter": "Add $t(common.t2iAdapter)",
"controlNetEnabledT2IDisabled": "$t(common.controlNet) enabled, $t(common.t2iAdapter)s disabled",
"t2iEnabledControlNetDisabled": "$t(common.t2iAdapter) enabled, $t(common.controlNet)s disabled",
"controlNetT2IMutexDesc": "$t(common.controlNet) and $t(common.t2iAdapter) at same time is currently unsupported.",
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
@@ -697,7 +709,7 @@
"noLoRAsAvailable": "No LoRAs available",
"noMatchingLoRAs": "No matching LoRAs",
"noMatchingModels": "No matching Models",
"noModelsAvailable": "No Modelss available",
"noModelsAvailable": "No models available",
"selectLoRA": "Select a LoRA",
"selectModel": "Select a Model"
},
@@ -785,6 +797,14 @@
"integerPolymorphic": "Integer Polymorphic",
"integerPolymorphicDescription": "A collection of integers.",
"invalidOutputSchema": "Invalid output schema",
"ipAdapter": "IP-Adapter",
"ipAdapterCollection": "IP-Adapters Collection",
"ipAdapterCollectionDescription": "A collection of IP-Adapters.",
"ipAdapterDescription": "An Image Prompt Adapter (IP-Adapter).",
"ipAdapterModel": "IP-Adapter Model",
"ipAdapterModelDescription": "IP-Adapter Model Field",
"ipAdapterPolymorphic": "IP-Adapter Polymorphic",
"ipAdapterPolymorphicDescription": "A collection of IP-Adapters.",
"latentsCollection": "Latents Collection",
"latentsCollectionDescription": "Latents may be passed between nodes.",
"latentsField": "Latents",
@@ -944,9 +964,10 @@
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} missing input",
"missingNodeTemplate": "Missing node template",
"noControlImageForControlNet": "ControlNet {{index}} has no control image",
"noControlImageForControlAdapter": "Control Adapter {{number}} has no control image",
"noInitialImageSelected": "No initial image selected",
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
"noModelForControlAdapter": "Control Adapter {{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter {{number}} model is invalid with main model.",
"noModelSelected": "No model selected",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
@@ -1085,7 +1106,8 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared": "Base model changed, cleared",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{number}} incompatible submodel",
"baseModelChangedCleared_many": "$t(toast.baseModelChangedCleared_one)s",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@@ -1105,7 +1127,6 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"incompatibleSubmodel": "incompatible submodel",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image",
"initialImageSet": "Initial Image Set",

View File

@@ -3,6 +3,9 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="favicon.ico" />
<style>

View File

@@ -157,6 +157,7 @@
"prettier": "^3.0.2",
"rollup-plugin-visualizer": "^5.9.2",
"ts-toolbelt": "^9.6.0",
"typescript": "^5.2.2",
"vite": "^4.4.9",
"vite-plugin-css-injected-by-js": "^3.3.0",
"vite-plugin-dts": "^3.5.2",

View File

@@ -49,8 +49,10 @@
"cancel": "Cancel",
"close": "Close",
"communityLabel": "Community",
"controlNet": "Controlnet",
"controlNet": "ControlNet",
"controlAdapter": "Control Adapter",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -130,6 +132,16 @@
"upload": "Upload"
},
"controlnet": {
"controlAdapter": "Control Adapter",
"controlnet": "$t(controlnet.controlAdapter) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.t2iAdapter))",
"addControlNet": "Add $t(common.controlNet)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addT2IAdapter": "Add $t(common.t2iAdapter)",
"controlNetEnabledT2IDisabled": "$t(common.controlNet) enabled, $t(common.t2iAdapter)s disabled",
"t2iEnabledControlNetDisabled": "$t(common.t2iAdapter) enabled, $t(common.controlNet)s disabled",
"controlNetT2IMutexDesc": "$t(common.controlNet) and $t(common.t2iAdapter) at same time is currently unsupported.",
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
@@ -785,6 +797,14 @@
"integerPolymorphic": "Integer Polymorphic",
"integerPolymorphicDescription": "A collection of integers.",
"invalidOutputSchema": "Invalid output schema",
"ipAdapter": "IP-Adapter",
"ipAdapterCollection": "IP-Adapters Collection",
"ipAdapterCollectionDescription": "A collection of IP-Adapters.",
"ipAdapterDescription": "An Image Prompt Adapter (IP-Adapter).",
"ipAdapterModel": "IP-Adapter Model",
"ipAdapterModelDescription": "IP-Adapter Model Field",
"ipAdapterPolymorphic": "IP-Adapter Polymorphic",
"ipAdapterPolymorphicDescription": "A collection of IP-Adapters.",
"latentsCollection": "Latents Collection",
"latentsCollectionDescription": "Latents may be passed between nodes.",
"latentsField": "Latents",
@@ -944,9 +964,10 @@
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} missing input",
"missingNodeTemplate": "Missing node template",
"noControlImageForControlNet": "ControlNet {{index}} has no control image",
"noControlImageForControlAdapter": "Control Adapter {{number}} has no control image",
"noInitialImageSelected": "No initial image selected",
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
"noModelForControlAdapter": "Control Adapter {{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter {{number}} model is invalid with main model.",
"noModelSelected": "No model selected",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
@@ -1085,7 +1106,8 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared": "Base model changed, cleared",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{number}} incompatible submodel",
"baseModelChangedCleared_many": "$t(toast.baseModelChangedCleared_one)s",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@@ -1105,7 +1127,6 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"incompatibleSubmodel": "incompatible submodel",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image",
"initialImageSet": "Initial Image Set",

View File

@@ -1,5 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
import { controlAdaptersPersistDenylist } from 'features/controlAdapters/store/controlAdaptersPersistDenylist';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
@@ -20,7 +20,7 @@ const serializationDenylist: {
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
ui: uiPersistDenylist,
controlNet: controlNetDenylist,
controlNet: controlAdaptersPersistDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist,
};

View File

@@ -1,5 +1,5 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialControlAdapterState } from 'features/controlAdapters/store/controlAdaptersSlice';
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
@@ -25,7 +25,7 @@ const initialStates: {
config: initialConfigState,
ui: initialUIState,
hotkeys: initialHotkeysState,
controlNet: initialControlNetState,
controlAdapters: initialControlAdapterState,
dynamicPrompts: initialDynamicPromptsState,
sdxl: initialSDXLState,
};

View File

@@ -72,6 +72,7 @@ import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
import { addControlAdapterAddedOrEnabledListener } from './listeners/controlAdapterAddedOrEnabled';
export const listenerMiddleware = createListenerMiddleware();
@@ -199,3 +200,7 @@ addTabChangedListener();
// Dynamic prompts
addDynamicPromptsListener();
// Display a toast when a ControlNet or T2I Adapter is enabled
// TODO: Remove when they can both be enabled at the same time
addControlAdapterAddedOrEnabledListener();

View File

@@ -1,8 +1,5 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
@@ -20,8 +17,7 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasInitialImageReset = false;
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wasControlNetReset = false;
let wasIPAdapterReset = false;
let wereControlAdaptersReset = false;
const state = getState();
deleted_images.forEach((image_name) => {
@@ -42,14 +38,9 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
wasNodeEditorReset = true;
}
if (imageUsage.isControlNetImage && !wasControlNetReset) {
dispatch(controlNetReset());
wasControlNetReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
}
});
},

View File

@@ -1,20 +1,21 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob;
let blob: Blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
@@ -50,8 +51,8 @@ export const addCanvasImageToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@@ -9,11 +9,11 @@ import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
@@ -61,8 +61,8 @@ export const addCanvasMaskToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);

View File

@@ -0,0 +1,87 @@
import { isAnyOf } from '@reduxjs/toolkit';
import {
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterIsEnabledChanged,
controlAdapterRecalled,
selectControlAdapterAll,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlAdapterType } from 'features/controlAdapters/store/types';
import { addToast } from 'features/system/store/systemSlice';
import i18n from 'i18n';
import { startAppListening } from '..';
const isAnyControlAdapterAddedOrEnabled = isAnyOf(
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterRecalled,
controlAdapterIsEnabledChanged
);
/**
* Until we can have both a ControlNet and a T2I Adapter enabled at once, they are mutually exclusive.
* This displays a toast when one is enabled while the other is already enabled, or when one is added
* with the other enabled.
*/
export const addControlAdapterAddedOrEnabledListener = () => {
startAppListening({
matcher: isAnyControlAdapterAddedOrEnabled,
effect: async (action, { dispatch, getOriginalState }) => {
const controlAdapters = getOriginalState().controlAdapters;
const hasEnabledControlNets = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 'controlnet');
const hasEnabledT2IAdapters = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 't2i_adapter');
let caType: ControlAdapterType | null = null;
if (controlAdapterAdded.match(action)) {
caType = action.payload.type;
}
if (controlAdapterAddedFromImage.match(action)) {
caType = action.payload.type;
}
if (controlAdapterRecalled.match(action)) {
caType = action.payload.type;
}
if (controlAdapterIsEnabledChanged.match(action)) {
const _caType = selectControlAdapterById(
controlAdapters,
action.payload.id
)?.type;
if (!_caType) {
return;
}
caType = _caType;
}
if (
(caType === 'controlnet' && hasEnabledT2IAdapters) ||
(caType === 't2i_adapter' && hasEnabledControlNets)
) {
const title =
caType === 'controlnet'
? i18n.t('controlnet.controlNetEnabledT2IDisabled')
: i18n.t('controlnet.t2iEnabledControlNetDisabled');
const description = i18n.t('controlnet.controlNetT2IMutexDesc');
dispatch(
addToast({
title,
description,
status: 'warning',
})
);
}
},
});
};

View File

@@ -1,15 +1,24 @@
import { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlNetAutoConfigToggled,
controlNetImageChanged,
controlNetModelChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (
action,
@@ -17,35 +26,37 @@ const predicate: AnyListenerPredicate<RootState> = (
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) ||
controlNetAutoConfigToggled.match(action);
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
if (controlNetAutoConfigToggled.match(action)) {
const { id } = action.payload;
const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
const ca = selectControlAdapterById(state.controlAdapters, id);
if (
!prevCA ||
!isControlNetOrT2IAdapter(prevCA) ||
!ca ||
!isControlNetOrT2IAdapter(ca)
) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true
) {
if (prevCA.shouldAutoConfig === true) {
return false;
}
}
const cn = state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
@@ -67,7 +78,7 @@ export const addControlNetAutoProcessListener = () => {
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { controlNetId } = action.payload;
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
@@ -75,7 +86,7 @@ export const addControlNetAutoProcessListener = () => {
// Delay before starting actual work
await delay(300);
dispatch(controlNetImageProcessed({ controlNetId }));
dispatch(controlAdapterImageProcessed({ id }));
},
});
};

View File

@@ -1,11 +1,11 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
clearPendingControlImages,
controlNetImageChanged,
controlNetProcessedImageChanged,
} from 'features/controlNet/store/controlNetSlice';
pendingControlImagesCleared,
controlAdapterImageChanged,
selectControlAdapterById,
controlAdapterProcessedImageChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
@@ -15,28 +15,34 @@ import { isImageOutput } from 'services/api/guards';
import { Graph, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addControlNetImageProcessedListener = () => {
startAppListening({
actionCreator: controlNetImageProcessed,
actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!controlNet?.controlImage) {
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image');
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off processing graph is just the processor node, no edges.
// Also we need to grab the image.
const graph: Graph = {
nodes: {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
[ca.processorNode.id]: {
...ca.processorNode,
is_intermediate: true,
image: { image_name: controlNet.controlImage },
image: { image_name: ca.controlImage },
},
[SAVE_IMAGE]: {
id: SAVE_IMAGE,
@@ -48,7 +54,7 @@ export const addControlNetImageProcessedListener = () => {
edges: [
{
source: {
node_id: controlNet.processorNode.id,
node_id: ca.processorNode.id,
field: 'image',
},
destination: {
@@ -103,8 +109,8 @@ export const addControlNetImageProcessedListener = () => {
// Update the processed image in the store
dispatch(
controlNetProcessedImageChanged({
controlNetId,
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
})
);
@@ -126,10 +132,8 @@ export const addControlNetImageProcessedListener = () => {
duration: 15000,
})
);
dispatch(clearPendingControlImages());
dispatch(
controlNetImageChanged({ controlNetId, controlImage: null })
);
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
return;
}
}

View File

@@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetProcessedImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
@@ -17,6 +17,7 @@ import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
@@ -90,35 +91,28 @@ export const addRequestedSingleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
@@ -215,35 +209,28 @@ export const addRequestedMultipleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {

View File

@@ -3,11 +3,9 @@ import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
@@ -90,39 +88,26 @@ export const addImageDroppedListener = () => {
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { controlNetId } = overData.context;
const { id } = overData.context;
dispatch(
controlNetImageChanged({
controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
})
);
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
return;
}
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
return;
}
/**
* Image dropped on Canvas
*/

View File

@@ -18,8 +18,7 @@ export const addImageToDeleteSelectedListener = () => {
const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {

View File

@@ -2,11 +2,9 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
@@ -87,17 +85,17 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
dispatch(
controlNetImageChanged({
controlNetId,
controlAdapterImageChanged({
id,
controlImage: imageDTO.image_name,
})
);
@@ -110,18 +108,6 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO));
dispatch(

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
controlNetRemoved,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterIsEnabledChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -15,9 +15,9 @@ import {
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { t } from 'i18next';
export const addModelSelectedListener = () => {
startAppListening({
@@ -60,33 +60,27 @@ export const addModelSelectedListener = () => {
}
// handle incompatible controlnets
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== base_model) {
dispatch(
controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })
);
modelsCleared += 1;
}
});
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) {
dispatch(
addToast(
makeToast({
title: `${t(
'toast.baseModelChangedCleared'
)} ${modelsCleared} ${t('toast.incompatibleSubmodel')}${
modelsCleared === 1 ? '' : 's'
}`,
title: t(
modelsCleared === 1
? 'toast.baseModelChangedCleared_one'
: 'toast.baseModelChangedCleared_many',
{
number: modelsCleared,
}
),
status: 'warning',
})
)

View File

@@ -1,8 +1,10 @@
import { logger } from 'app/logging/logger';
import {
controlNetRemoved,
ipAdapterModelChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterModelCleared,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
@@ -19,14 +21,12 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
ipAdapterModelsAdapter,
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => {
startAppListening({
@@ -221,21 +221,45 @@ export const addModelsLoadedListener = () => {
`ControlNet models loaded (${action.payload.ids.length})`
);
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isControlNetAvailable) {
if (isModelAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// T2I Adapter models loaded - need to remove missing T2I Adapters from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`T2I Adapter models loaded (${action.payload.ids.length})`
);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
@@ -249,38 +273,20 @@ export const addModelsLoadedListener = () => {
`IP Adapter models loaded (${action.payload.ids.length})`
);
const { model } = getState().controlNet.ipAdapterInfo;
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === model?.model_name &&
m?.base_model === model?.base_model
);
if (isModelAvailable) {
return;
}
const firstModel = ipAdapterModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
return;
}
dispatch(ipAdapterModelChanged(result.data));
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({

View File

@@ -8,6 +8,7 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { imagesAdapter } from 'services/api/util';
@@ -70,11 +71,21 @@ export const addInvocationCompleteEventListener = () => {
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData(
'getBoardImagesTotal',
imageDTO.board_id ?? 'none',
(draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
}
)
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'BoardImagesTotal', id: imageDTO.board_id },
{ type: 'BoardAssetsTotal', id: imageDTO.board_id },
{ type: 'Board', id: imageDTO.board_id },
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
])
);

View File

@@ -11,44 +11,70 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
const log = logger('socketio');
const {
queue_item_id: item_id,
queue_batch_id,
status,
} = action.payload.data;
// we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data;
log.debug(
action.payload,
`Queue item ${item_id} status updated: ${status}`
`Queue item ${queue_item.item_id} status updated: ${queue_item.status}`
);
dispatch(appSocketQueueItemStatusChanged(action.payload));
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: item_id,
changes: action.payload.data,
id: queue_item.item_id,
changes: queue_item,
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(
queueApi.util.updateQueryData(
'getBatchStatus',
{ batch_id: batch_status.batch_id },
() => batch_status
)
);
// Update the queue item status (this is the full queue item, including the session)
dispatch(
queueApi.util.updateQueryData(
'getQueueItem',
queue_item.item_id,
(draft) => {
if (!draft) {
return;
}
Object.assign(draft, queue_item);
}
)
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: queue_batch_id },
])
);
const req = dispatch(
queueApi.endpoints.getQueueStatus.initiate(undefined, {
forceRefetch: true,
})
);
await req.unwrap();
req.unsubscribe();
// Pass the event along
dispatch(appSocketQueueItemStatusChanged(action.payload));
},
});
};
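For reference, a rough sketch of the event payload shape the listener above relies on. Only queue_item, batch_status, queue_status and the fields actually read in the code (item_id, status, batch_id, pending, in_progress) come from the diff; the remaining fields are assumptions about the backend DTOs:

// Assumed shape of socketQueueItemStatusChanged's payload.data (illustrative only).
type QueueItemStatus =
  | 'pending'
  | 'in_progress'
  | 'completed'
  | 'failed'
  | 'canceled';

type QueueItemStatusChangedData = {
  // the full queue item, including the session; fields beyond item_id/status are assumed
  queue_item: { item_id: number; batch_id: string; status: QueueItemStatus };
  // per-batch counters; the canvas slice further below checks pending and in_progress
  batch_status: {
    batch_id: string;
    pending: number;
    in_progress: number;
    completed: number;
    failed: number;
    canceled: number;
  };
  // queue-wide counters (the processor status is not part of this event)
  queue_status: {
    pending: number;
    in_progress: number;
    completed: number;
    failed: number;
    canceled: number;
  };
};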

View File

@@ -7,7 +7,7 @@ import {
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import controlAdaptersReducer from 'features/controlAdapters/store/controlAdaptersSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
@@ -44,7 +44,7 @@ const allReducers = {
config: configReducer,
ui: uiReducer,
hotkeys: hotkeysReducer,
controlNet: controlNetReducer,
controlAdapters: controlAdaptersReducer,
dynamicPrompts: dynamicPromptsReducer,
deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer,
@@ -68,7 +68,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'postprocessing',
'system',
'ui',
'controlNet',
'controlAdapters',
'dynamicPrompts',
'lora',
'modelmanager',

View File

@@ -1,4 +1,4 @@
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';

View File

@@ -15,7 +15,7 @@ type UseImageUploadButtonArgs = {
* @example
* const { getUploadButtonProps, getUploadInputProps, openUploader } = useImageUploadButton({
* postUploadAction: {
* type: 'SET_CONTROLNET_IMAGE',
* type: 'SET_CONTROL_ADAPTER_IMAGE',
* controlNetId: '12345',
* },
* isDisabled: getIsUploadDisabled(),

View File

@@ -2,16 +2,18 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isInvocationNode } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, map } from 'lodash-es';
import { forEach } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
const selector = createSelector(
[stateSelector, activeTabNameSelector],
(
{ controlNet, generation, system, nodes, dynamicPrompts },
{ controlAdapters, generation, system, nodes, dynamicPrompts },
activeTabName
) => {
const { initialImage, model } = generation;
@@ -87,30 +89,39 @@ const selector = createSelector(
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
}
if (controlNet.isEnabled) {
map(controlNet.controlNets).forEach((controlNet, i) => {
if (!controlNet.isEnabled) {
return;
}
if (!controlNet.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 })
);
}
selectControlAdapterAll(controlAdapters).forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (
!controlNet.controlImage ||
(!controlNet.processedControlImage &&
controlNet.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlNet', {
index: i + 1,
})
);
}
});
}
if (!ca.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlAdapter', {
number: i + 1,
})
);
} else if (ca.model.base_model !== model?.base_model) {
// This should never happen, just a sanity check
reasons.push(
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
number: i + 1,
})
);
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) &&
!ca.processedControlImage &&
ca.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlAdapter', {
number: i + 1,
})
);
}
});
}
return { isReady: !reasons.length, reasons };

View File

@@ -1,5 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@@ -22,10 +21,10 @@ export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasMaskToControlAdapter = createAction<{ id: string }>(
'canvas/canvasMaskToControlAdapter'
);
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');
export const canvasImageToControlAdapter = createAction<{ id: string }>(
'canvas/canvasImageToControlAdapter'
);

View File

@@ -29,6 +29,7 @@ import {
isCanvasBaseImage,
isCanvasMaskLine,
} from './canvasTypes';
import { appSocketQueueItemStatusChanged } from 'services/events/actions';
export const initialLayerState: CanvasLayerState = {
objects: [],
@@ -786,6 +787,18 @@ export const canvasSlice = createSlice({
},
},
extraReducers: (builder) => {
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
const batch_status = action.payload.data.batch_status;
if (!state.batchIds.includes(batch_status.batch_id)) {
return;
}
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
state.batchIds = state.batchIds.filter(
(id) => id !== batch_status.batch_id
);
}
});
builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload;
if (ratio) {
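The new extraReducer above keeps state.batchIds limited to canvas batches that still have work outstanding. A one-line reading of its completion check (the helper name and minimal status shape are illustrative):

// A canvas batch is considered settled once nothing is pending or in progress;
// settled batches are dropped from state.batchIds in the reducer above.
type MinimalBatchStatus = { batch_id: string; pending: number; in_progress: number };

const isBatchSettled = (batchStatus: MinimalBatchStatus): boolean =>
  batchStatus.in_progress === 0 && batchStatus.pending === 0;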

View File

@@ -3,92 +3,64 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetDuplicated,
controlNetRemoved,
controlNetIsEnabledChanged,
} from '../store/controlNetSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
controlAdapterDuplicated,
controlAdapterIsEnabledChanged,
controlAdapterRemoved,
} from '../store/controlAdaptersSlice';
import ParamControlAdapterModel from './parameters/ParamControlAdapterModel';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
import ControlAdapterImagePreview from './ControlAdapterImagePreview';
import ControlAdapterProcessorComponent from './ControlAdapterProcessorComponent';
import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import ParamControlAdapterBeginEnd from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
type ControlNetProps = {
controlNet: ControlNetConfig;
};
const ControlNet = (props: ControlNetProps) => {
const { controlNet } = props;
const { controlNetId } = controlNet;
const ControlAdapterConfig = (props: { id: string; number: number }) => {
const { id, number } = props;
const controlAdapterType = useControlAdapterType(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const cn = controlNet.controlNets[controlNetId];
if (!cn) {
return {
isEnabled: false,
shouldAutoConfig: false,
};
}
const { isEnabled, shouldAutoConfig } = cn;
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isEnabled = useControlAdapterIsEnabled(id);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
dispatch(controlAdapterRemoved({ id }));
}, [id, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [controlNetId, dispatch]);
dispatch(controlAdapterDuplicated(id));
}, [id, dispatch]);
const handleToggleIsEnabled = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
controlNetIsEnabledChanged({
controlNetId,
controlAdapterIsEnabledChanged({
id,
isEnabled: e.target.checked,
})
);
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!controlAdapterType) {
return null;
}
return (
<Flex
sx={{
@@ -103,27 +75,31 @@ const ControlNet = (props: ControlNetProps) => {
},
}}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<Flex
sx={{ gap: 2, alignItems: 'center', justifyContent: 'space-between' }}
>
<IAISwitch
tooltip={t('controlnet.toggleControlNet')}
label={t(`controlnet.${controlAdapterType}`, { number })}
aria-label={t('controlnet.toggleControlNet')}
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
formControlProps={{ w: 'full' }}
formLabelProps={{ fontWeight: 600 }}
/>
</Flex>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<Box
sx={{
w: 'full',
minW: 0,
// opacity: isEnabled ? 1 : 0.5,
// pointerEvents: isEnabled ? 'auto' : 'none',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNet={controlNet} />
<ParamControlAdapterModel id={id} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
<ControlNetCanvasImageImports id={id} />
)}
<IAIIconButton
size="sm"
@@ -174,23 +150,6 @@ const ControlNet = (props: ControlNetProps) => {
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
position: 'absolute',
w: 1.5,
h: 1.5,
borderRadius: 'full',
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
@@ -207,8 +166,8 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight controlNet={controlNet} />
<ParamControlNetBeginEnd controlNet={controlNet} />
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>
{!isExpanded && (
<Flex
@@ -220,26 +179,26 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={controlNet} isSmall />
<ControlAdapterImagePreview id={id} isSmall />
</Flex>
)}
</Flex>
<Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNet={controlNet} />
<ParamControlNetResizeMode controlNet={controlNet} />
<ParamControlAdapterControlMode id={id} />
<ParamControlAdapterResizeMode id={id} />
</Flex>
<ParamControlNetProcessorSelect controlNet={controlNet} />
<ParamControlAdapterProcessorSelect id={id} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNet={controlNet} />
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
<ControlNetProcessorComponent controlNet={controlNet} />
<ControlAdapterImagePreview id={id} />
<ControlAdapterShouldAutoConfig id={id} />
<ControlAdapterProcessorComponent id={id} />
</>
)}
</Flex>
);
};
export default memo(ControlNet);
export default memo(ControlAdapterConfig);
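ControlAdapterConfig no longer receives the whole config object; it reads each field through useControlAdapterX(id) hooks. A minimal sketch of how one of them (useControlAdapterIsEnabled) could be built, assuming the controlAdapters slice exposes an entity-adapter style selectControlAdapterById; this illustrates the pattern rather than the actual implementation:

import { createSelector } from '@reduxjs/toolkit';
import { useMemo } from 'react';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
// selectControlAdapterById is assumed; the slice is known to export
// selectControlAdapterAll and selectControlAdapterIds, so a by-id selector is likely.
import { selectControlAdapterById } from 'features/controlAdapters/store/controlAdaptersSlice';

export const useControlAdapterIsEnabled = (id: string) => {
  const selector = useMemo(
    () =>
      createSelector(
        stateSelector,
        ({ controlAdapters }) =>
          selectControlAdapterById(controlAdapters, id)?.isEnabled ?? false,
        defaultSelectorOptions
      ),
    [id]
  );

  return useAppSelector(selector);
};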

View File

@@ -23,20 +23,20 @@ import {
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import {
ControlNetConfig,
controlNetImageChanged,
} from '../store/controlNetSlice';
import { controlAdapterImageChanged } from '../store/controlAdaptersSlice';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from '../hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from '../hooks/useControlAdapterProcessorType';
type Props = {
controlNet: ControlNetConfig;
id: string;
isSmall?: boolean;
};
const selector = createSelector(
stateSelector,
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
({ controlAdapters, gallery }) => {
const { pendingControlImages } = controlAdapters;
const { autoAddBoardId } = gallery;
return {
@@ -47,13 +47,10 @@ const selector = createSelector(
defaultSelectorOptions
);
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const {
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
controlNetId,
} = controlNet;
const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const controlImageName = useControlAdapterControlImage(id);
const processedControlImageName = useControlAdapterProcessedControlImage(id);
const processorType = useControlAdapterProcessorType(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -75,8 +72,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [id, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
@@ -133,32 +130,32 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id: controlNetId,
id,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, controlNetId]);
}, [controlImage, id]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
id,
actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { id },
}),
[controlNetId]
[id]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
[controlNetId]
() => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }),
[id]
);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!pendingControlImages.includes(controlNetId) &&
!pendingControlImages.includes(id) &&
processorType !== 'none';
return (
@@ -222,7 +219,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
/>
</>
{pendingControlImages.includes(controlNetId) && (
{pendingControlImages.includes(id) && (
<Flex
sx={{
position: 'absolute',
@@ -250,4 +247,4 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
);
};
export default memo(ControlNetImagePreview);
export default memo(ControlAdapterImagePreview);

View File

@@ -1,26 +1,26 @@
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions';
import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
type Props = {
controlNet: ControlNetConfig;
id: string;
};
const ControlNetPreprocessButton = (props: Props) => {
const { controlNetId, controlImage } = props.controlNet;
const ControlAdapterPreprocessButton = ({ id }: Props) => {
const controlImage = useControlAdapterControlImage(id);
const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => {
dispatch(
controlNetImageProcessed({
controlNetId,
controlAdapterImageProcessed({
id,
})
);
}, [controlNetId, dispatch]);
}, [id, dispatch]);
return (
<IAIButton
@@ -33,4 +33,4 @@ const ControlNetPreprocessButton = (props: Props) => {
);
};
export default memo(ControlNetPreprocessButton);
export default memo(ControlAdapterPreprocessButton);

View File

@@ -1,5 +1,4 @@
import { memo } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
@@ -13,18 +12,25 @@ import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from '../hooks/useControlAdapterProcessorNode';
export type ControlNetProcessorProps = {
controlNet: ControlNetConfig;
export type Props = {
id: string;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const ControlAdapterProcessorComponent = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
if (!processorNode) {
return null;
}
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -34,7 +40,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'color_map_image_processor') {
return (
<ColorMapProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -44,7 +50,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -54,7 +60,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_image_processor') {
return (
<LineartProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -64,7 +70,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'content_shuffle_image_processor') {
return (
<ContentShuffleProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -74,7 +80,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'lineart_anime_image_processor') {
return (
<LineartAnimeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -84,7 +90,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mediapipe_face_processor') {
return (
<MediapipeFaceProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -94,7 +100,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'midas_depth_image_processor') {
return (
<MidasDepthProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -104,7 +110,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'mlsd_image_processor') {
return (
<MlsdImageProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -114,7 +120,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'normalbae_image_processor') {
return (
<NormalBaeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -124,7 +130,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'openpose_image_processor') {
return (
<OpenposeProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -134,7 +140,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'pidi_image_processor') {
return (
<PidiProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -144,7 +150,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
if (processorNode.type === 'zoe_depth_image_processor') {
return (
<ZoeDepthProcessor
controlNetId={controlNetId}
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -154,4 +160,4 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
return null;
};
export default memo(ControlNetProcessorComponent);
export default memo(ControlAdapterProcessorComponent);

View File

@@ -0,0 +1,39 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig';
import { isNil } from 'lodash-es';
type Props = {
id: string;
};
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
if (isNil(shouldAutoConfig)) {
return null;
}
return (
<IAISwitch
label={t('controlnet.autoConfigure')}
aria-label={t('controlnet.autoConfigure')}
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isEnabled}
/>
);
};
export default memo(ControlAdapterShouldAutoConfig);

View File

@@ -0,0 +1,114 @@
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
import {
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
selectControlAdapterIds,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react';
import { FaPlus } from 'react-icons/fa';
import { useAddControlAdapter } from '../hooks/useAddControlAdapter';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ controlAdapters }) => {
const activeLabel: string[] = [];
const ipAdapterCount = selectAllIPAdapters(controlAdapters).length;
if (ipAdapterCount > 0) {
activeLabel.push(`${ipAdapterCount} IP`);
}
const controlNetCount = selectAllControlNets(controlAdapters).length;
if (controlNetCount > 0) {
activeLabel.push(`${controlNetCount} ControlNet`);
}
const t2iAdapterCount = selectAllT2IAdapters(controlAdapters).length;
if (t2iAdapterCount > 0) {
activeLabel.push(`${t2iAdapterCount} T2I`);
}
const controlAdapterIds =
selectControlAdapterIds(controlAdapters).map(String);
return {
controlAdapterIds,
activeLabel: activeLabel.join(', '),
};
},
defaultSelectorOptions
);
const ControlAdaptersCollapse = () => {
const { t } = useTranslation();
const { controlAdapterIds, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const [addControlNet, isAddControlNetDisabled] =
useAddControlAdapter('controlnet');
const [addIPAdapter, isAddIPAdapterDisabled] =
useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] =
useAddControlAdapter('t2i_adapter');
if (isControlNetDisabled) {
return null;
}
return (
<IAICollapse label="Control Adapters" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ButtonGroup size="sm" w="full" justifyContent="space-between">
<IAIButton
tooltip={t('controlnet.addControlNet')}
leftIcon={<FaPlus />}
onClick={addControlNet}
data-testid="add controlnet"
flexGrow={1}
isDisabled={isAddControlNetDisabled}
>
{t('common.controlNet')}
</IAIButton>
<IAIButton
tooltip={t('controlnet.addIPAdapter')}
leftIcon={<FaPlus />}
onClick={addIPAdapter}
data-testid="add ip adapter"
flexGrow={1}
isDisabled={isAddIPAdapterDisabled}
>
{t('common.ipAdapter')}
</IAIButton>
<IAIButton
tooltip={t('controlnet.addT2IAdapter')}
leftIcon={<FaPlus />}
onClick={addT2IAdapter}
data-testid="add t2i adapter"
flexGrow={1}
isDisabled={isAddT2IAdapterDisabled}
>
{t('common.t2iAdapter')}
</IAIButton>
</ButtonGroup>
{controlAdapterIds.map((id, i) => (
<Fragment key={id}>
<Divider />
<ControlAdapterConfig id={id} number={i + 1} />
</Fragment>
))}
</Flex>
</IAICollapse>
);
};
export default memo(ControlAdaptersCollapse);
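The three add buttons above consume useAddControlAdapter, which returns a tuple of the add callback and a disabled flag. The assumed signature, written out for clarity (the tuple labels are illustrative):

// Assumed shape of the hook used above; only the three type strings and the
// [callback, isDisabled] tuple are taken from the component code.
type ControlAdapterType = 'controlnet' | 'ip_adapter' | 't2i_adapter';

type UseAddControlAdapter = (
  type: ControlAdapterType
) => [addAdapter: () => void, isDisabled: boolean];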

View File

@@ -0,0 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { controlAdapterProcessorParamsChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlAdapterProcessorNode } from 'features/controlAdapters/store/types';
import { useCallback } from 'react';
export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback(
(id: string, params: Partial<ControlAdapterProcessorNode>) => {
dispatch(
controlAdapterProcessorParamsChanged({
id,
params,
})
);
},
[dispatch]
);
return handleProcessorNodeChanged;
};
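Downstream, the per-processor components (Canny, HED, MLSD, etc., in the hunks further below) use this hook to push partial parameter updates into an adapter's processor node. An illustrative caller, with the low_threshold parameter name assumed from the Canny processor's settings:

// Illustrative only: how a processor component might consume the hook above.
import { useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

export const useCannyLowThreshold = (controlNetId: string) => {
  const processorChanged = useProcessorNodeChanged();

  // Dispatches controlAdapterProcessorParamsChanged with a partial params object,
  // which is merged into the adapter's processor node.
  return useCallback(
    (v: number) => {
      processorChanged(controlNetId, { low_threshold: v });
    },
    [controlNetId, processorChanged]
  );
};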

View File

@@ -2,32 +2,31 @@ import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
canvasImageToControlAdapter,
canvasMaskToControlAdapter,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
id: string;
};
const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const { id } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);
dispatch(canvasImageToControlAdapter({ id }));
}, [id, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);
dispatch(canvasMaskToControlAdapter({ id }));
}, [id, dispatch]);
return (
<Flex

View File

@@ -11,44 +11,49 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import { useControlAdapterBeginEndStepPct } from 'features/controlAdapters/hooks/useControlAdapterBeginEndStepPct';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import {
ControlNetConfig,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
controlNet: ControlNetConfig;
id: string;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { beginStepPct, endStepPct, isEnabled, controlNetId } =
props.controlNet;
const ParamControlAdapterBeginEnd = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const stepPcts = useControlAdapterBeginEndStepPct(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(
controlNetBeginStepPctChanged({
controlNetId,
controlAdapterBeginStepPctChanged({
id,
beginStepPct: v[0] as number,
})
);
dispatch(
controlNetEndStepPctChanged({
controlNetId,
controlAdapterEndStepPctChanged({
id,
endStepPct: v[1] as number,
})
);
},
[controlNetId, dispatch]
[dispatch, id]
);
if (!stepPcts) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetBeginEnd">
<FormControl isDisabled={!isEnabled}>
@@ -56,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
aria-label={['Begin Step %', 'End Step %']}
value={[beginStepPct, endStepPct]}
value={[stepPcts.beginStepPct, stepPcts.endStepPct]}
onChange={handleStepPctChanged}
min={0}
max={1}
@@ -67,10 +72,18 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
<Tooltip
label={formatPct(stepPcts.beginStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<Tooltip
label={formatPct(stepPcts.endStepPct)}
placement="top"
hasArrow
>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
@@ -107,4 +120,4 @@ const ParamControlNetBeginEnd = (props: Props) => {
);
};
export default memo(ParamControlNetBeginEnd);
export default memo(ParamControlAdapterBeginEnd);

View File

@@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
ControlNetConfig,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterControlMode } from 'features/controlAdapters/hooks/useControlAdapterControlMode';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterControlModeChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlMode } from 'features/controlAdapters/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlMode, isEnabled, controlNetId } = props.controlNet;
export default function ParamControlAdapterControlMode({ id }: Props) {
const isEnabled = useControlAdapterIsEnabled(id);
const controlMode = useControlAdapterControlMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -28,19 +26,23 @@ export default function ParamControlNetControlMode(
];
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
(controlMode: ControlMode) => {
dispatch(controlAdapterControlModeChanged({ id, controlMode }));
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!controlMode) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetControlMode">
<IAIMantineSelect
disabled={!isEnabled}
label={t('controlnet.controlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
value={controlMode}
onChange={handleControlModeChange}
/>
</IAIInformationalPopover>

View File

@@ -1,23 +1,21 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import {
ControlNetConfig,
controlNetModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNet: ControlNetConfig;
type ParamControlAdapterModelProps = {
id: string;
};
const selector = createSelector(
@@ -29,23 +27,31 @@ const selector = createSelector(
defaultSelectorOptions
);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const { mainModel } = useAppSelector(selector);
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
const models = useControlAdapterModels(controlAdapterType);
const data = useMemo(() => {
if (!controlNetModels) {
if (!models) {
return [];
}
const data: SelectItem[] = [];
const data: {
value: string;
label: string;
group: string;
disabled: boolean;
tooltip?: string;
}[] = [];
forEach(controlNetModels.entities, (model, id) => {
models.forEach((model) => {
if (!model) {
return;
}
@@ -53,7 +59,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
value: model.id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
@@ -63,20 +69,23 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
});
});
data.sort((a, b) =>
// sort incompatible (disabled) models to the bottom, the rest alphabetically
a.disabled ? 1 : b.disabled ? -1 : a.label.localeCompare(b.label)
);
return data;
}, [controlNetModels, mainModel?.base_model, t]);
}, [mainModel?.base_model, models, t]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
models.find(
(m) =>
m?.id ===
`${model?.base_model}/${controlAdapterType}/${model?.model_name}`
),
[controlAdapterType, model?.base_model, model?.model_name, models]
);
const handleModelChanged = useCallback(
@@ -91,11 +100,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
dispatch(controlAdapterModelChanged({ id, model: newControlNetModel }));
},
[controlNetId, dispatch]
[dispatch, id]
);
return (
@@ -114,4 +121,4 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
);
};
export default memo(ParamControlNetModel);
export default memo(ParamControlAdapterModel);
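The selectedModel lookup above matches entries by a composite id. An illustrative example of the assumed id format (the values are made up):

// Assumed id format for control adapter model entities: `${base_model}/${type}/${model_name}`
const exampleModelId = ['sd-1', 'controlnet', 'control_v11p_sd15_canny'].join('/');
// -> 'sd-1/controlnet/control_v11p_sd15_canny'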

View File

@@ -5,19 +5,18 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from 'features/controlAdapters/hooks/useControlAdapterProcessorNode';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import {
ControlNetConfig,
controlNetProcessorTypeChanged,
} from '../../store/controlNetSlice';
import { ControlNetProcessorType } from '../../store/types';
import { useTranslation } from 'react-i18next';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlAdapterProcessortTypeChanged } from '../../store/controlAdaptersSlice';
import { ControlAdapterProcessorType } from '../../store/types';
type ParamControlNetProcessorSelectProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
const selector = createSelector(
@@ -41,7 +40,7 @@ const selector = createSelector(
.filter(
(d) =>
!config.sd.disabledControlNetProcessors.includes(
d.value as ControlNetProcessorType
d.value as ControlAdapterProcessorType
)
);
@@ -50,26 +49,29 @@ const selector = createSelector(
defaultSelectorOptions
);
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const ParamControlAdapterProcessorSelect = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
const dispatch = useAppDispatch();
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const controlNetProcessors = useAppSelector(selector);
const { t } = useTranslation();
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: v as ControlNetProcessorType,
controlAdapterProcessortTypeChanged({
id,
processorType: v as ControlAdapterProcessorType,
})
);
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!processorNode) {
return null;
}
return (
<IAIMantineSearchableSelect
label={t('controlnet.processor')}
@@ -81,4 +83,4 @@ const ParamControlNetProcessorSelect = (
);
};
export default memo(ParamControlNetProcessorSelect);
export default memo(ParamControlAdapterProcessorSelect);

View File

@@ -1,22 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlNetConfig,
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterResizeMode } from 'features/controlAdapters/hooks/useControlAdapterResizeMode';
import { controlAdapterResizeModeChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { ResizeMode } from 'features/controlAdapters/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetResizeModeProps = {
controlNet: ControlNetConfig;
type Props = {
id: string;
};
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
export default function ParamControlAdapterResizeMode({ id }: Props) {
const isEnabled = useControlAdapterIsEnabled(id);
const resizeMode = useControlAdapterResizeMode(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -27,19 +25,23 @@ export default function ParamControlNetResizeMode(
];
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {
dispatch(controlNetResizeModeChanged({ controlNetId, resizeMode }));
(resizeMode: ResizeMode) => {
dispatch(controlAdapterResizeModeChanged({ id, resizeMode }));
},
[controlNetId, dispatch]
[id, dispatch]
);
if (!resizeMode) {
return null;
}
return (
<IAIInformationalPopover feature="controlNetResizeMode">
<IAIMantineSelect
disabled={!isEnabled}
label={t('controlnet.resizeMode')}
data={RESIZE_MODE_DATA}
value={String(resizeMode)}
value={resizeMode}
onChange={handleResizeModeChange}
/>
</IAIInformationalPopover>

View File

@@ -1,28 +1,34 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import {
ControlNetConfig,
controlNetWeightChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterWeight } from 'features/controlAdapters/hooks/useControlAdapterWeight';
import { controlAdapterWeightChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetWeightProps = {
controlNet: ControlNetConfig;
type ParamControlAdapterWeightProps = {
id: string;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { weight, isEnabled, controlNetId } = props.controlNet;
const ParamControlAdapterWeight = ({ id }: ParamControlAdapterWeightProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const weight = useControlAdapterWeight(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
dispatch(controlAdapterWeightChanged({ id, weight }));
},
[controlNetId, dispatch]
[dispatch, id]
);
if (isNil(weight)) {
// should never happen
return null;
}
return (
<IAIInformationalPopover feature="controlNetWeight">
<IAISlider
@@ -40,4 +46,4 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
);
};
export default memo(ParamControlNetWeight);
export default memo(ParamControlAdapterWeight);

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredColorMapImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredColorMapImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,7 +1,7 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,7 +1,7 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,6 +1,6 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,7 +1,7 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

View File

@@ -1,7 +1,7 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlAdapters/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';

Some files were not shown because too many files have changed in this diff.