Compare commits


167 Commits

Author SHA1 Message Date
Ryan Dick
0781fdf3b0 WIP - simplify ModelLoadRegistry 2024-07-02 20:36:36 -04:00
Ryan Dick
8d7ca9c1b7 More refactoring to help with circular imports. 2024-07-02 16:49:03 -04:00
Ryan Dick
a61e0bd2dd Remove symbol re-exports that were contributing to circular import issues. 2024-07-02 15:26:08 -04:00
Ryan Dick
798e73969c Tidy handling of SCHEDULER_NAME_VALUES to help with circular import errors. 2024-07-02 15:12:59 -04:00
Ryan Dick
44f62944ee Fix circular import caused by the organization of the model size utils. 2024-07-02 11:55:05 -04:00
Ryan Dick
e9936c27fb Make the VAE tile size configurable for tiled VAE (#6555)
## Summary

- This PR exposes a `tile_size` field on `ImageToLatentsInvocation` and
`LatentsToImageInvocation`.
  - Setting `tile_size = 0` preserves the default behaviour.
- This feature is primarily intended to support upscaling workflows that
require VAE encoding/decoding high resolution images. In the future, we
may want to expose the tile size as a global application config, but
that's a separate conversation.
- As a general rule, larger tile sizes produce better results at the
cost of higher memory usage.
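
A minimal sketch of how the new `tile_size` value is applied inside these invocations, assuming a loaded `vae` (an `AutoencoderKL`) and a prepared `image_tensor` are already in scope; it mirrors the `patch_vae_tiling_params` context manager introduced in this PR:

```python
import torch
from contextlib import nullcontext

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR  # 8
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params

tile_size = 1024  # image-space pixels; 0 keeps the model's default tiling params

tiling_context = nullcontext()
if tile_size > 0:
    # Patch the VAE's tiling params so each tile is `tile_size` px in image space.
    tiling_context = patch_vae_tiling_params(
        vae,  # assumed: an AutoencoderKL loaded elsewhere
        tile_sample_min_size=tile_size,
        tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
        tile_overlap_factor=0.25,
    )

vae.enable_tiling()
with torch.inference_mode(), tiling_context:
    latents = vae.encode(image_tensor).latent_dist.sample()
```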

### Example:

Original (5472x5472)

![orig](https://github.com/invoke-ai/InvokeAI/assets/14897797/af0a975d-11ed-4f3c-9e53-84f3da6c997e)

VAE roundtrip with 512x512 tiles (note the discoloration)

![vae_roundtrip_512x512](https://github.com/invoke-ai/InvokeAI/assets/14897797/d589ae3e-fe93-410a-904c-f61f0fc0f1f2)

VAE roundtrip with 1024x1024 tiles (some discoloration still present,
but less severe than at 512x512)

![vae_roundtrip_1024x1024](https://github.com/invoke-ai/InvokeAI/assets/14897797/d0bb9752-3bfa-444f-88c9-39a3ca89c748)


## Related Issues / Discussions

Related: #6144 

## QA Instructions

- [x] Test image generation via the Linear tab
- [x] Test VAE roundtrip with tiling disabled
- [x] Test VAE roundtrip with tiling and tile_size = 0
- [x] Test VAE roundtrip with tiling and tile_size > 0

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-07-02 09:16:07 -04:00
Ryan Dick
3752509066 Expose the VAE tile_size on the VAE encode and decode invocations. 2024-07-02 09:07:03 -04:00
Ryan Dick
a1b7dbfa54 Add unit test for patch_vae_tiling_params(). 2024-07-02 09:07:03 -04:00
Ryan Dick
79640ba14e Add context manager for overriding VAE tiling params. 2024-07-02 09:07:03 -04:00
psychedelicious
4075a81676 feat(ui): gallery image selection ux
The selection logic is a bit complicated. We have image selection and pagination, both of which can be triggered using the mouse or hotkeys. We also have viewer image selection and comparison image selection; which of the two is targeted is determined by the alt key.

This change ties the room together with these behaviours:

- Changing the page using pagination buttons never changes the selection.
- Changing the selected image using arrows may change the page, if the arrow key pressed would select an image off the current page.
  - `right` on the last image of the current page goes to the next page
  - `down` on the last row of images goes to the next page
  - `left` on the first image of the current page goes to the previous page
  - `up` on the first row of images goes to the previous page
- If `alt` is held when using arrow keys, we change the page, but we only change the comparison image selection.
- When using arrow keys, if the page has changed since the last image was selected, the selection is reset to the first image on the page.
- The next/previous buttons on the image viewer do the same thing as `left` and `right` without `alt`.
- When clicking an image in the gallery (sketched below):
  - If no modifier keys are held, the image is exclusively selected.
  - If `ctrl` or `meta` are held, the image's selection status is toggled.
  - If `shift` is held, all images from the last-selected image to the image are selected. If there are no images on the current page, the selection is unchanged.
  - If `alt` is held, the image is set as the compare image.
- `ctrl+a` and `meta+a` add the current page to the selection.

The logic for gallery navigation and selection is now pretty hairy. It's spread across three hooks, a listener, a redux slice, and several components.

When we next make changes to this part of the app, we should consider consolidating some of the related logic. Probably most of it can go into a single listener and make it much simpler to grok.
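
For illustration, a rough Python sketch of the click-selection rules above (the real logic lives in the TypeScript gallery slice and listeners; the names here are hypothetical):

```python
def on_image_click(image, page_images, selection, last_selected,
                   *, ctrl=False, meta=False, shift=False, alt=False):
    """Returns (new_selection, new_compare_image or None)."""
    if alt:
        # alt: set the compare image; selection is unchanged
        return set(selection), image
    if ctrl or meta:
        # ctrl/meta: toggle this image's membership in the selection
        return set(selection) ^ {image}, None
    if shift:
        # shift: select the range from the last-selected image to this one
        if last_selected in page_images and image in page_images:
            lo, hi = sorted((page_images.index(last_selected), page_images.index(image)))
            return set(selection) | set(page_images[lo : hi + 1]), None
        return set(selection), None  # nothing to range-select; leave selection unchanged
    # no modifiers: exclusive selection
    return {image}, None
```
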
2024-07-02 13:52:32 +10:00
psychedelicious
4d39976909 feat(ui): restore loading spinner in search box
@maryhipp you were right, after trying loading bars and different placements, this feels like the best place for it.
2024-07-02 13:52:32 +10:00
Mary Hipp
d14894b3ae (ui) clarify auto-add options 2024-07-02 06:44:09 +10:00
Mary Hipp
6f5c5b0757 lint fix 2024-07-01 15:36:06 -04:00
Mary Hipp
93caa23ef8 undo 2024-07-01 15:36:06 -04:00
Mary Hipp
977a77f4e6 fix(ui): dont mess up redux if 403 gets thrown 2024-07-01 15:36:06 -04:00
Mary Hipp
57c0fcb93d (ui) clarify auto-add options 2024-07-01 15:36:06 -04:00
Kent Keirsey
8b55900035 Update README.md
Updated to include more context confirming the community edition is in fact free for commercial use.
2024-07-01 09:12:31 -07:00
psychedelicious
b1cc413bbd tidy(ui): remove search term fetching indicator
Don't like this UI (even though I suggested it). No need to prevent the user from interacting with the search term field during fetching. Let's figure out a nicer way to present this in a followup.
2024-07-01 20:06:28 +10:00
psychedelicious
face94ce33 feat(ui): tweak search term placeholder verbiage 2024-07-01 20:06:28 +10:00
psychedelicious
f0b1f0e5b6 feat(ui): pass search term as-is to query
The images service does not add the query filter if the search term is an empty string.
2024-07-01 20:06:28 +10:00
psychedelicious
390dc47db5 feat(app): transform search term to lowercase 2024-07-01 20:06:28 +10:00
Mary Hipp
20d5c3a8bf (ui): improve loader/fetching state while searching, make search term a string in redux 2024-07-01 20:06:28 +10:00
maryhipp
134d831ebf (api) simplify query 2024-07-01 20:06:28 +10:00
maryhipp
b65ed8e8f2 fix commented out migration 2024-07-01 20:06:28 +10:00
maryhipp
93951dcf82 (api) ruff 2024-07-01 20:06:28 +10:00
Mary Hipp
da05034e20 feat(ui): debounced gallery search 2024-07-01 20:06:28 +10:00
Mary Hipp
d579aefb3e feat(api): add optional search_term query param to image list to search metadata 2024-07-01 20:06:28 +10:00
blessedcoolant
5d1f6db414 fix(app): fix SQL query w/ enum for python 3.11 (#6557)
## Summary

Python 3.11 has a wonderfully devious breaking change: enum classes that
inherit from `str` or `int` _sometimes_ do not behave the same way as they
do in 3.10 when used in string formatting/interpolation.

This breaks the new gallery sort queries. The fix is to use
`order_dir.value` instead of `order_dir` in the query.

This was not an issue during development because the feature was
developed w/ python 3.10.
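
To illustrate the difference (the enum values and column names below are assumed for illustration; `SQLiteDirection` is the real class used in the query):

```python
from enum import Enum

class SQLiteDirection(str, Enum):  # values assumed for illustration
    Ascending = "ASC"
    Descending = "DESC"

order_dir = SQLiteDirection.Descending

# Python 3.10: str.__format__ is used, so this renders "... ORDER BY created_at DESC;"
# Python 3.11: Enum.__format__ now matches __str__, so this renders
# "... ORDER BY created_at SQLiteDirection.Descending;" -- invalid SQL.
broken = f"SELECT image_name FROM images ORDER BY created_at {order_dir};"

# Using .value is unambiguous on both versions.
fixed = f"SELECT image_name FROM images ORDER BY created_at {order_dir.value};"
```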

## Related Issues / Discussions

Thanks to @JPPhoto for reporting and troubleshooting:
https://discord.com/channels/1020123559063990373/1149513625321603162/1256211815982039173

## QA Instructions

JP's fancy python 3.11 system should work on this PR.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-06-29 18:50:16 +05:30
psychedelicious
f9961eceb7 fix(app): fix SQL query w/ enum for python 3.11 2024-06-29 11:07:39 +10:00
psychedelicious
10076fb1e8 feat(ui): tweak gallery settings popover divider styling 2024-06-28 18:01:01 +10:00
psychedelicious
d6e85e5f67 tidy(ui): rename GalleryBulkSelect -> GallerySelectionCountTag 2024-06-28 18:01:01 +10:00
psychedelicious
1ce459198c chore(ui): knip 2024-06-28 18:01:01 +10:00
psychedelicious
17d337169d fix(ui): do not reset limit when changing gallery view 2024-06-28 18:01:01 +10:00
psychedelicious
1468f4d37e perf(ui): split out gallery settings popover components
This was taking over 15ms (!) to render each time a setting changed, wtf
2024-06-28 18:01:01 +10:00
psychedelicious
2b744480d6 feat(ui): update UI for sorting 2024-06-28 18:01:01 +10:00
psychedelicious
abb8d34b56 chore(ui): typegen 2024-06-28 18:01:01 +10:00
psychedelicious
9e664d7c58 feat(api): remove order_by in favor of starred_first for images records 2024-06-28 18:01:01 +10:00
psychedelicious
c96ccae70b feat(app): remove order_by in favor of starred_first for images records 2024-06-28 18:01:01 +10:00
maryhipp
f268fe126e feat(api): add order_by and order_dir to list images for sorting 2024-06-28 18:01:01 +10:00
Mary Hipp
6109a06f04 feat(ui): gallery sort by created at or starred, asc or desc 2024-06-28 18:01:01 +10:00
Kent Keirsey
5df2a79549 Update starter models 2024-06-28 17:49:45 +10:00
Kent Keirsey
10b9088312 update controlnet starter models 2024-06-28 17:49:45 +10:00
psychedelicious
41f46b846b chore: ruff 2024-06-28 10:36:05 +10:00
psychedelicious
6dfc406c52 tests: update test_bulk_download.py after addition of archived field 2024-06-28 10:36:05 +10:00
psychedelicious
0d4b80780b feat(ui): handle edge cases when archiving/deleting boards
If the currently selected or auto-add board is archived or deleted, we should reset them. There are some edge cases that weren't handled in the previous implementation.

All handling of this logic is moved to the (renamed) listener.
2024-06-28 10:36:05 +10:00
psychedelicious
15b9ece411 chore(ui): typegen 2024-06-28 10:36:05 +10:00
psychedelicious
89fcab34d0 feat(app): BoardRecord.archived is a required field 2024-06-28 10:36:05 +10:00
psychedelicious
132289de55 chore: ruff E721
Looks like in the latest version of ruff, E721 was added or changed and now catches something it didn't before.
2024-06-28 10:36:05 +10:00
psychedelicious
9f93e9d120 fix(app): when creating image, skip adding to board if board doesn't exist
Before this change, if you attempted to create an image with a nonexistent board, we'd get an unhandled error when adding the image to the board. The record would be created, but the file would not be, due to the structure of the code.

With this change, we now log a warning if we have a problem adding the image to the board, but the record and file are still created.

A future improvement would be to create a transaction for this part of the code, preventing some other situation that could result in only the record or only the file being saved.
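
Roughly, the shape of the new handling (method and variable names here are illustrative, not the exact service API):

```python
def add_image_to_board_if_possible(board_records, logger, image_name: str, board_id: str | None) -> None:
    """Best-effort board assignment: a failure here must not block image creation."""
    if board_id is None:
        return
    try:
        # Hypothetical service call; the real boards-service method name differs.
        board_records.add_image_to_board(board_id=board_id, image_name=image_name)
    except Exception as e:
        logger.warning(f"Failed to add image {image_name} to board {board_id}: {e}")
```
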
2024-06-28 10:36:05 +10:00
Mary Hipp
b5f23292d4 lint fix 2024-06-28 10:36:05 +10:00
maryhipp
a63dbb2c2d (api) change query param to include_archived 2024-06-28 10:36:05 +10:00
Mary Hipp
740bf80f3e (ui): update query param to include_archived, fix cache when archiving boards 2024-06-28 10:36:05 +10:00
Mary Hipp
dc90de600d (ui) allow auto-add on archived boards, reset to uncategorized if auto-add board is not currently visible due to archived view 2024-06-28 10:36:05 +10:00
psychedelicious
5709f82e5f feat(ui): separate context menu for no board board
Much easier to not need to handle the board being optional in the component.
2024-06-28 10:36:05 +10:00
psychedelicious
20042d99ec tidy(ui): archived icon component 2024-06-28 10:36:05 +10:00
Mary Hipp
8fce168dc5 fix tsc errors 2024-06-28 10:36:05 +10:00
maryhipp
a7ea096b28 ruff format 2024-06-28 10:36:05 +10:00
Mary Hipp
29eb3c8b62 lint fix 2024-06-28 10:36:05 +10:00
Mary Hipp
071e8bcee4 feat(ui): make archiving and auto-add mutually exclusive 2024-06-28 10:36:05 +10:00
Mary Hipp
68c0aa898f feat(ui): add ability to archive/unarchive boards, add toggle to gallery settings to show/hide archived boards in list 2024-06-28 10:36:05 +10:00
maryhipp
5120a76ce5 cleanup 2024-06-28 10:36:05 +10:00
maryhipp
38a948ac9f feat(api): add archived query param to board list endpoint to include them in the response 2024-06-28 10:36:05 +10:00
maryhipp
c33111468e feat(api): ability to archive boards 2024-06-28 10:36:05 +10:00
Lincoln Stein
3e0fb45dd7 Load single-file checkpoints directly without conversion (#6510)
* use model_class.load_singlefile() instead of converting; works, but performance is poor

* adjust the convert api - not right just yet

* working, needs sql migrator update

* rename migration_11 before conflict merge with main

* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* implement lightweight version-by-version config migration

* simplified config schema migration code

* associate sdxl config with sdxl VAEs

* remove use of original_config_file in load_single_file()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-06-27 17:31:28 -04:00
Ryan Dick
aba16085a5 fix(backend): mps should not use non_blocking (#6549)
## Summary

We can get black outputs when moving tensors from CPU to MPS. It appears
MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
-
https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility
takes a torch device and returns the flag to be used for non_blocking
when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.
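
A minimal sketch of the utility described above, under the assumption that it looks roughly like this (the real `TorchDevice` class has additional members):

```python
import torch

class TorchDevice:
    @staticmethod
    def get_non_blocking(to_device: torch.device) -> bool:
        # Non-blocking copies *to* MPS can produce black/corrupted outputs
        # (see the linked PyTorch issues), so only allow them for other targets.
        return False if to_device.type == "mps" else True

# Usage when moving a tensor onto a target device:
device = torch.device("cpu")
t = torch.zeros(4)
t = t.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
```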

## Related Issues / Discussions

Fixes: #6545

## QA Instructions

For both MPS and CUDA:
- Generate at least 5 images using LoRAs
- Generate at least 5 images using IP Adapters

## Merge Plan

We have pagination merged into `main` but aren't ready for that to be
released.

Once this fix is tested and merged, we will probably want to create a
`v4.2.5post1` branch off the `v4.2.5` tag, cherry-pick the fix and do a
release from the hotfix branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_ @RyanJDick @lstein This
feels testable but I'm not sure how.
- [ ] _Documentation added / updated (if applicable)_
2024-06-27 10:11:53 -04:00
Ryan Dick
14775cc9c4 ruff format 2024-06-27 09:45:13 -04:00
psychedelicious
c7562dd6c0 fix(backend): mps should not use non_blocking
We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.

Fixes: #6545
2024-06-27 19:15:23 +10:00
psychedelicious
a0a0c57789 chore(ui): knip 2024-06-27 13:48:40 +10:00
psychedelicious
32ebf82d1a feat(ui): better pagination buttons 2024-06-27 13:48:40 +10:00
psychedelicious
2dd172c2c6 feat(ui): gallery bulk select styling 2024-06-27 13:48:40 +10:00
psychedelicious
280ec9d4b3 fix(ui): invalidate getImageDTO caches when images are mutated 2024-06-27 13:48:40 +10:00
psychedelicious
fde8fc7575 perf(ui): optimistic updates for getImageDTO query cache 2024-06-27 13:48:40 +10:00
psychedelicious
6dcdc87eb1 fix(ui): control adapter image preview 2024-06-27 13:48:40 +10:00
Mary Hipp
93ffcb642e lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
4c914ef2e8 use correct query params for boardIdSelected listener 2024-06-27 13:48:40 +10:00
Mary Hipp
c0ad5bc4a4 fix when deleting first image in list 2024-06-27 13:48:40 +10:00
Mary Hipp
8c58a180de GG another fix 2024-06-27 13:48:40 +10:00
Mary Hipp
715dd983b0 appease the knip 2024-06-27 13:48:40 +10:00
Mary Hipp
84ffd36071 lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
9f30f1bfec fix circular dep 2024-06-27 13:48:40 +10:00
Mary Hipp
bdff5c4e87 only show selected when greater than 0 2024-06-27 13:48:40 +10:00
Mary Hipp
afb0651f91 clear selection when board or gallery view changes 2024-06-27 13:48:40 +10:00
Mary Hipp
66e25628c3 fix neg pages 2024-06-27 13:48:40 +10:00
Mary Hipp
3a531a3c88 remove rest of cache, add bulk select UI 2024-06-27 13:48:40 +10:00
Mary Hipp
f01df49128 lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
7bbe236107 implement custom sort to replace images adapter logic 2024-06-27 13:48:40 +10:00
psychedelicious
719c066ac4 feat(ui): more efficient board totals fetching
We only need to show the totals in the tooltip. Tooltips accept a component for the tooltip label. The component isn't rendered until the tooltip is triggered.

Move the board total fetching into a tooltip component for the boards. Now we only fire these requests when the user mouses over the board.
2024-06-27 13:48:40 +10:00
psychedelicious
689dc30f87 feat(ui): tweak pagination buttons
- Fix off-by-one error when going to last page
- Update component to have minimal/no layout shift
2024-06-27 13:48:40 +10:00
psychedelicious
1f22f6ae02 feat(ui): iterate on dynamic gallery limit
- Simplify the gallery layout
- Set an initial gallery limit to load _some_ images immediately.
- Refactor the resize observer to use the actual rendered image component to calculate the number of images per row/col. This prevents inaccuracies caused by image padding that could result in the wrong number of images.
- Debounce the limit update to not thrash the API
- Use absolute positioning trick to ensure the gallery container is always exactly the right size
- Minimum of `imagesPerRow` images loaded at all times
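
A back-of-the-envelope sketch of the limit calculation described above (hypothetical helper; the real code measures a rendered image component in TypeScript):

```python
import math

def gallery_limit(container_w: int, container_h: int, image_w: int, image_h: int) -> int:
    """Number of images to request so the visible gallery area stays filled."""
    images_per_row = max(1, math.floor(container_w / image_w))
    images_per_col = max(1, math.ceil(container_h / image_h))
    # At least one full row (`imagesPerRow` images) is always requested.
    return images_per_row * images_per_col
```
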
2024-06-27 13:48:40 +10:00
psychedelicious
9c931d9ca0 fix(ui): gallery content overflow
This is one of those unexpected CSS quirks. Flex containers need min-width or min-height for their children to not overflow. Add `minH={0}` to gallery container.
2024-06-27 13:48:40 +10:00
Mary Hipp
e0a241fa4f wip change limit based on size of gallery 2024-06-27 13:48:40 +10:00
Mary Hipp
6a4b4ee340 trying to invalidate all the tags 2024-06-27 13:48:40 +10:00
Mary Hipp
488bf21925 fix single pagers 2024-06-27 13:48:40 +10:00
Mary Hipp
c9c39c02b6 handle generations coming in, fix pagination to use total from list query so it updates as that changes 2024-06-27 13:48:40 +10:00
Mary Hipp
5101dc4bef some cleanup, add page buttons 2024-06-27 13:48:40 +10:00
Mary Hipp
98c77a3ed1 pull in Spencer's work 2024-06-27 13:48:40 +10:00
psychedelicious
4fca62680d Update invokeai_version.py 2024-06-27 10:41:01 +10:00
Ryan Dick
f76282a5ff Fix handling of 0-step denoising process (#6544)
## Summary

https://github.com/invoke-ai/InvokeAI/pull/6522 introduced a change in
behavior in cases where start/end were set such that there are 0
timesteps. This PR reverts that change.

cc @StAlKeR7779 

## QA Instructions

Run with euler, 5 steps, start: 0.0, end: 0.05. I ran this test before
#6522, after #6522, and on this branch. This branch restores the
behavior to pre-#6522 i.e. noise is injected even if no denoising steps
are applied.
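
A sketch of the restored behavior, under assumed variable names (not the actual pipeline code): when the start/end window leaves zero timesteps, the initial noise is still applied to the latents before returning.

```python
import torch

def maybe_short_circuit(latents: torch.Tensor, noise: torch.Tensor | None,
                        timesteps: torch.Tensor, init_timestep: torch.Tensor,
                        scheduler) -> torch.Tensor | None:
    """If there are no timesteps to run, return noised latents immediately (else None)."""
    if timesteps.numel() == 0:
        if noise is not None:
            # Pre-#6522 behavior: noise is injected even though no steps are applied.
            latents = scheduler.add_noise(latents, noise, init_timestep)
        return latents
    return None
```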


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-06-26 13:01:58 -04:00
Ryan Dick
9a3b8c6fcb Fix handling of init_timestep in StableDiffusionGeneratorPipeline and improve its documentation. 2024-06-26 12:51:51 -04:00
Ryan Dick
bd74b84cc5 Revert "Remove the redundant init_timestep parameter that was being passed around. It is simply the first element of the timesteps array."
This reverts commit fa40061eca.
2024-06-26 12:51:51 -04:00
Brandon Rising
dc23bebebf Run ruff 2024-06-26 21:46:59 +10:00
Kent Keirsey
38b6f90c02 Update prevention exception message 2024-06-26 21:46:59 +10:00
Ryan Dick
cd9dfefe3c Fix inpainting mask shape assertions. 2024-06-25 11:31:52 -07:00
Ryan Dick
b9946e50f9 Use image-space tile dimensions on the TiledMultiDiffusionDenoiseLatents invocation. This is more natural for many users. 2024-06-25 11:31:52 -07:00
Ryan Dick
06f49a30f6 Mark TiledMultiDiffusionDenoiseLatents as a Beta node. 2024-06-25 11:31:52 -07:00
Ryan Dick
e1af78c702 Make the tile_overlap input to MultiDiffusion *strictly* control the amount of overlap rather than being a lower bound. 2024-06-25 11:31:52 -07:00
Ryan Dick
c5588e1ff7 Add TODO comment explaining why some schedulers do not interact well with MultiDiffusion. 2024-06-25 11:31:52 -07:00
Ryan Dick
07ac292680 Consolidate _region_step() function - the separation wasn't really adding any value. 2024-06-25 11:31:52 -07:00
Ryan Dick
7c032ea604 (minor) Fix some documentation typos. 2024-06-25 11:31:52 -07:00
Ryan Dick
c5ee415607 Add progress image callbacks to TiledMultiDiffusionDenoiseLatentsInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
fa40061eca Remove the redundant init_timestep parameter that was being passed around. It is simply the first element of the timesteps array. 2024-06-25 11:31:52 -07:00
Ryan Dick
7cafd78d6e Revert "Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation."
This reverts commit 753239b48d.
2024-06-25 11:31:52 -07:00
Ryan Dick
8a43656cf9 (minor) Address a few small TODOs. 2024-06-25 11:31:52 -07:00
Ryan Dick
bd3b6ca11b Remove TiledStableDiffusionRefineInvocation. It was a proof-of-concept that has been superseded by TiledMultiDiffusionDenoiseLatents. 2024-06-25 11:31:52 -07:00
Ryan Dick
ceae5fe1db (minor) typo 2024-06-25 11:31:52 -07:00
Ryan Dick
25067e4f0d Delete rough notes. 2024-06-25 11:31:52 -07:00
Ryan Dick
fb0aaa3e6d Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
c22526b9d0 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
c881882f73 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-25 11:31:52 -07:00
Ryan Dick
36473fc52a Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-25 11:31:52 -07:00
Ryan Dick
b9964ecc4a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
051af802fe Remove inpainting support from MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
3ff2e558d9 Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
fc187c9253 Document plan for the rest of the MultiDiffusion implementation. 2024-06-25 11:31:52 -07:00
Ryan Dick
605f460c7d Add detailed docstring to latents_from_embeddings(). 2024-06-25 11:31:52 -07:00
Ryan Dick
60d1e686d8 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
22704dd542 Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-25 11:31:52 -07:00
Ryan Dick
875673c9ba Minor tidying of latents_from_embeddings(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
f604575862 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-25 11:31:52 -07:00
Ryan Dick
80a67572f1 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-25 11:31:52 -07:00
Ryan Dick
60ac937698 Improve clarity of comments regarding when 'noise' and 'latents' are expected to be set. 2024-06-25 11:31:52 -07:00
Ryan Dick
1e41949a02 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-25 11:31:52 -07:00
Ryan Dick
5f0e330ed2 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-25 11:31:52 -07:00
Ryan Dick
9dd779b414 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-25 11:31:52 -07:00
Ryan Dick
fa183025ac Remove unused are_like_tensors() function. 2024-06-25 11:31:52 -07:00
Ryan Dick
d3c85aa91a Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-25 11:31:52 -07:00
Ryan Dick
82619602a5 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-25 11:31:52 -07:00
Ryan Dick
196f3b721d Stricter typing for the is_gradient_mask: bool. 2024-06-25 11:31:52 -07:00
Ryan Dick
244c28859d Fix typing of control_data to reflect that it can be None. 2024-06-25 11:31:52 -07:00
Ryan Dick
40ae174c41 Fix typing of timesteps and init_timestep. 2024-06-25 11:31:52 -07:00
Ryan Dick
afaebdf151 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-25 11:31:52 -07:00
Ryan Dick
d661517d94 Move seed above optional params. 2024-06-25 11:31:52 -07:00
Ryan Dick
82a69a54ac Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-25 11:31:52 -07:00
Ryan Dick
ffc28176fe Remove unused num_inference_steps. 2024-06-25 11:31:52 -07:00
Ryan Dick
230e205541 WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-25 11:31:52 -07:00
Ryan Dick
7e94350351 Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-25 11:31:52 -07:00
Ryan Dick
c4e8549c73 Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
350a210835 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-25 11:31:52 -07:00
Ryan Dick
ed781dbb0c Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-25 11:31:52 -07:00
Ryan Dick
b41ea963e7 Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
da5d105049 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
5301770525 Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-25 11:31:52 -07:00
Ryan Dick
d08e405017 Fix ControlNetModel type hint import source. 2024-06-25 11:31:52 -07:00
Ryan Dick
534640ccde Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-25 11:31:52 -07:00
Ryan Dick
d5ab8cab5c WIP - TiledStableDiffusionRefine 2024-06-25 11:31:52 -07:00
Ryan Dick
4767301ad3 Minor improvements to LatentsToImageInvocation type hints. 2024-06-25 11:31:52 -07:00
Ryan Dick
21d7ca45e6 Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
020e8eb413 Fix return type of prepare_noise_and_latents(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
3d49541c09 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
1ef266845a Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-25 11:31:52 -07:00
Ryan Dick
a37589ca5f WIP on TiledStableDiffusionRefine 2024-06-25 11:31:52 -07:00
Ryan Dick
171a505f5e Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
8004a0d5f5 Simplify the logic in prepare_noise_and_latents(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
610a1fd611 Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-25 11:31:52 -07:00
Ryan Dick
43108eec13 (minor) Add a TODO note to get_scheduler(...). 2024-06-25 11:31:52 -07:00
Lincoln Stein
b03073d888 [MM] Add support for probing and loading SDXL VAE checkpoint files (#6524)
* add support for probing and loading SDXL VAE checkpoint files

* broaden regexp probe for SDXL VAEs

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-06-20 02:57:27 +00:00
steffylo
a43d602f16 fix(queue): add clear_queue_on_startup config to clear problematic queues 2024-06-19 11:39:25 +10:00
134 changed files with 2832 additions and 3368 deletions

View File

@@ -12,12 +12,24 @@
Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.
[Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]
Invoke is available in two editions:
| **Community Edition** | **Professional Edition** |
|----------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |
<div align="center">
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
# Documentation
| **Quick Links** |
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
</div>
## Quick Start

View File

@@ -73,15 +73,6 @@ model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_manager.model_records`:
```
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```
The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.

View File

@@ -118,15 +118,13 @@ async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
return ApiDependencies.invoker.services.boards.get_all(include_archived)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
else:
raise HTTPException(
status_code=400,

View File

@@ -9,9 +9,14 @@ from PIL import Image
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..dependencies import ApiDependencies
@@ -316,16 +321,14 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
)
return image_dtos

View File

@@ -3,9 +3,9 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
@@ -19,7 +19,6 @@ from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
@@ -30,7 +29,6 @@ from invokeai.backend.model_manager.config import (
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -174,18 +172,6 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -746,39 +732,36 @@ async def convert_model(
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
assert convert_path.exists()
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# install the diffusers
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except Exception as e:
logger.error(str(e))
store.update_model(key, changes=ModelRecordChanges(name=original_name))
raise HTTPException(status_code=409, detail=str(e))
# Update the model image if the model had one
try:
@@ -791,8 +774,8 @@ async def convert_model(
# delete the original safetensors file
installer.delete(key)
# delete the cached version
shutil.rmtree(cache_path)
# delete the temporary directory
# shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config = store.get_model(new_key)

View File

@@ -1,6 +1,5 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
@@ -11,9 +10,6 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

View File

@@ -19,8 +19,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

View File

@@ -17,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
@@ -53,7 +53,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
@@ -693,7 +693,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("'latents' or 'noise' must be provided!")
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
@@ -736,7 +736,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to dispatch_progress()
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:

View File

@@ -160,6 +160,8 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
"model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"

View File

@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
@@ -12,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -22,8 +23,9 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
@@ -31,7 +33,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@@ -44,12 +46,17 @@ class ImageToLatentsInvocation(BaseInvocation):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@@ -81,9 +88,18 @@ class ImageToLatentsInvocation(BaseInvocation):
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
@@ -101,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)

View File

@@ -1,3 +1,5 @@
from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
@@ -8,10 +10,9 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -23,8 +24,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@@ -33,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -47,22 +48,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if use_fp32:
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@@ -87,15 +87,24 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if use_tiling or context.config.get().force_tiled_decode:
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
@@ -107,21 +116,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -1,5 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
@@ -7,6 +6,7 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")

View File

@@ -7,8 +7,8 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
@@ -24,11 +24,12 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
@@ -55,15 +56,15 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
# TODO(ryand): Reset to 1.0.0 right before release.
classification=Classification.Beta,
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
@@ -85,21 +86,24 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
description=FieldDescriptions.latents,
input=Input.Connection,
)
# TODO(ryand): Add multiple-of validation.
# TODO(ryand): Smaller defaults might make more sense.
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16,
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField(
default=0.65,
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
@@ -150,7 +154,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(), # TODO: oh...
vae=FakeVae(),
text_encoder=None,
tokenizer=None,
unet=unet,
@@ -162,19 +166,29 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
@@ -205,8 +219,8 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=self.tile_height,
latent_width=self.tile_width,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
@@ -233,7 +247,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile.coords,
region=tile,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
@@ -251,17 +265,17 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
# TODO(ryand): Add proper callback.
callback=lambda x: None,
callback=step_callback,
)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
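
Note on the change above: the tile fields are specified in pixel space, but MultiDiffusion operates on latents, so `invoke()` now divides them by `LATENT_SCALE_FACTOR` before calling `calc_tiles_min_overlap()`. A minimal sketch of that mapping, assuming `LATENT_SCALE_FACTOR == 8` (the SD VAE downsampling factor used elsewhere in InvokeAI) and purely illustrative tile sizes:

```python
LATENT_SCALE_FACTOR = 8  # assumption: the SD VAE downsamples by 8x in each dimension

# Illustrative pixel-space inputs (not the node's defaults).
tile_height, tile_width, tile_overlap = 1024, 1024, 32

latent_tile_height = tile_height // LATENT_SCALE_FACTOR    # 128 latent rows
latent_tile_width = tile_width // LATENT_SCALE_FACTOR      # 128 latent cols
latent_tile_overlap = tile_overlap // LATENT_SCALE_FACTOR  # 4 latent rows/cols of overlap

# These latent-space values are what calc_tiles_min_overlap() and
# get_conditioning_data() receive after this change.
```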


@@ -1,380 +0,0 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling, i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But we awkwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)
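
One idea worth preserving from the removed invocation above is how it kept noise consistent across tiles: noise is sampled once for the full latent image and then cropped per tile, so overlapping regions see identical noise. A small self-contained sketch of that behaviour (shapes and tile coordinates are illustrative; the real code uses `get_noise()` and `crop_latents_to_tile()`):

```python
import torch

LATENT_SCALE_FACTOR = 8  # assumption: SD VAE downsampling factor

# Noise for a full 1024x1024 image, generated once in latent space.
global_noise = torch.randn(1, 4, 1024 // LATENT_SCALE_FACTOR, 1024 // LATENT_SCALE_FACTOR)

def crop_noise(noise: torch.Tensor, top: int, left: int, bottom: int, right: int) -> torch.Tensor:
    """Crop latent-space noise to an image-space tile (coords must be multiples of 8)."""
    return noise[
        ...,
        top // LATENT_SCALE_FACTOR : bottom // LATENT_SCALE_FACTOR,
        left // LATENT_SCALE_FACTOR : right // LATENT_SCALE_FACTOR,
    ]

# Two horizontally-overlapping 512x512 tiles share identical noise in their 64px overlap.
tile_a = crop_noise(global_noise, top=0, left=0, bottom=512, right=512)
tile_b = crop_noise(global_noise, top=0, left=448, bottom=512, right=960)
assert torch.equal(tile_a[..., :, 56:], tile_b[..., :, :8])
```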


@@ -40,16 +40,12 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass


@@ -22,6 +22,8 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -35,6 +37,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
return BoardRecord(
board_id=board_id,
@@ -43,12 +46,14 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
)
class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordNotFoundException(Exception):


@@ -125,6 +125,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -134,35 +145,49 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
"""
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
)
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
count = cast(int, self._cursor.fetchone()[0])
@@ -174,20 +199,25 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
)
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
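
For clarity, a condensed, runnable restatement of what the new `archived_filter` produces for the paginated boards query (the actual code formats a multi-line SQL template; the WHERE clause is the only part that changes):

```python
def boards_page_query(include_archived: bool) -> str:
    # Mirrors the {archived_filter} substitution in get_many()/get_all().
    parts = ["SELECT * FROM boards"]
    if not include_archived:
        parts.append("WHERE archived = 0")
    parts.append("ORDER BY created_at DESC LIMIT ? OFFSET ?;")
    return " ".join(parts)

assert boards_page_query(False) == (
    "SELECT * FROM boards WHERE archived = 0 ORDER BY created_at DESC LIMIT ? OFFSET ?;"
)
assert boards_page_query(True) == "SELECT * FROM boards ORDER BY created_at DESC LIMIT ? OFFSET ?;"
```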


@@ -44,16 +44,12 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
"""Gets all boards."""
pass


@@ -48,8 +48,10 @@ class BoardService(BoardServiceABC):
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -63,8 +65,8 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all()
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)


@@ -3,6 +3,7 @@
from __future__ import annotations
import copy
import locale
import os
import re
@@ -25,14 +26,13 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.1"
CONFIG_SCHEMA_VERSION = "4.0.2"
def get_default_ram_cache_size() -> float:
@@ -85,7 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
@@ -102,7 +102,6 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
@@ -113,6 +112,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -147,7 +147,7 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
@@ -169,9 +169,8 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
@@ -186,6 +185,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@@ -355,14 +355,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.0.
"""
parsed_config_dict: dict[str, Any] = {}
@@ -396,32 +396,41 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.1.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict
def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
A config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
@@ -435,27 +444,31 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
migrated = False
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
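
To make the chaining above concrete, here is a condensed, self-contained sketch of the two new v4 migration steps and how they compose (the function bodies restate the logic in the diff; the sample dict is illustrative only):

```python
import copy
from typing import Any

def migrate_v4_0_0_to_4_0_1(d: dict[str, Any]) -> dict[str, Any]:
    d = copy.deepcopy(d)
    if d.get("precision") == "autocast":  # "autocast" was replaced by "auto" in v4.0.1
        d["precision"] = "auto"
    d["schema_version"] = "4.0.1"
    return d

def migrate_v4_0_1_to_4_0_2(d: dict[str, Any]) -> dict[str, Any]:
    d = copy.deepcopy(d)
    d.pop("convert_cache", None)  # the convert_cache setting was removed in v4.0.2
    d["schema_version"] = "4.0.2"
    return d

cfg = {"schema_version": "4.0.0", "precision": "autocast", "convert_cache": 20.0}
cfg = migrate_v4_0_1_to_4_0_2(migrate_v4_0_0_to_4_0_1(cfg))
assert cfg == {"schema_version": "4.0.2", "precision": "auto"}
```

Each step only advances the schema by one version, which is why `load_and_migrate_config()` can run the checks sequentially and write a single backup plus one migrated file at the end.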


@@ -4,6 +4,7 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
@@ -37,10 +38,13 @@ class ImageRecordStorageBase(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass


@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
@@ -144,10 +145,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@@ -208,9 +212,21 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
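
A small sketch of the ORDER BY clauses the new `starred_first`/`order_dir` parameters produce (assuming `SQLiteDirection` values render as `"ASC"`/`"DESC"` via `.value`, as the f-strings above suggest):

```python
def order_by(starred_first: bool, order_dir: str) -> str:
    if starred_first:
        return f"ORDER BY images.starred DESC, images.created_at {order_dir}"
    return f"ORDER BY images.created_at {order_dir}"

assert order_by(True, "DESC") == "ORDER BY images.starred DESC, images.created_at DESC"
assert order_by(False, "ASC") == "ORDER BY images.created_at ASC"
```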


@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class ImageServiceABC(ABC):
@@ -116,10 +117,13 @@ class ImageServiceABC(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass


@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -73,7 +74,12 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
try:
self.__invoker.services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -202,19 +208,25 @@ class ImageService(ImageServiceABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(
offset,
limit,
starred_first,
order_dir,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
)
image_dtos = [


@@ -6,8 +6,7 @@ from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -28,11 +27,6 @@ class ModelLoadServiceBase(ABC):
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None


@@ -2,7 +2,7 @@
"""Implementation of model loader service."""
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Callable, Optional
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
@@ -11,14 +11,9 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -33,8 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
registry: ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
@@ -42,7 +36,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._registry = registry
def start(self, invoker: Invoker) -> None:
@@ -53,11 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -76,7 +64,6 @@ class ModelLoadService(ModelLoadServiceBase):
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):


@@ -1,17 +0,0 @@
"""Initialization file for model manager service."""
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]


@@ -7,7 +7,8 @@ import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -86,11 +87,9 @@ class ModelManagerService(ModelManagerServiceBase):
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(


@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
if self.__invoker.services.configuration.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()


@@ -652,7 +652,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) != list:
if get_origin(input_field) is not list:
return False
# Validate that all outputs match the input type


@@ -14,6 +14,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -45,6 +47,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.run_migrations()
return db


@@ -0,0 +1,35 @@
import shutil
import sqlite3
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration12Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_model_convert_cache_dir()
def _remove_model_convert_cache_dir(self) -> None:
"""
Removes unused model convert cache directory
"""
convert_cache = self._app_config.convert_cache_path
shutil.rmtree(convert_cache, ignore_errors=True)
def build_migration_12(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 11 to 12.
This migration removes the now-unused model convert cache directory.
"""
migration_12 = Migration(
from_version=11,
to_version=12,
callback=Migration12Callback(app_config),
)
return migration_12


@@ -0,0 +1,31 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration13Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_archived_col(cursor)
def _add_archived_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds an `archived` column to the `boards` table.
"""
cursor.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")
def build_migration_13() -> Migration:
"""
Build the migration from database version 12 to 13.
This migration does the following:
- Adds an `archived` column to the `boards` table.
"""
migration_13 = Migration(
from_version=12,
to_version=13,
callback=Migration13Callback(),
)
return migration_13
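
For context, a minimal sketch of what `Migration13Callback` does to an existing database (shown on a plain in-memory `sqlite3` connection; the real migration runs through `SqliteMigrator`, and `DEFAULT FALSE` requires SQLite >= 3.23):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT PRIMARY KEY, board_name TEXT);")
conn.execute("INSERT INTO boards (board_id, board_name) VALUES ('abc', 'My Board');")

# Equivalent of Migration13Callback._add_archived_col(); pre-existing rows pick up the default.
conn.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")
assert conn.execute("SELECT archived FROM boards WHERE board_id = 'abc'").fetchone()[0] == 0
```

Boards created before this migration therefore read back as not archived, which matches the `board_dict.get("archived", False)` fallback added to `deserialize_board_record()`.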


@@ -11,6 +11,7 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from ..raw_model import RawModel
from .resampler import Resampler
@@ -137,10 +138,7 @@ class IPAdapter(RawModel):
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]


@@ -10,6 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.devices import TorchDevice
from .raw_model import RawModel
@@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype, non_blocking=True)
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
model.layers[layer_key] = layer
return model


@@ -12,7 +12,9 @@ def validate_hash(hash: str):
map = json.loads(b64decode(enc_hash))
if alg in map:
if hash_ == map[alg]:
raise Exception("Unrecoverable Model Error")
raise Exception(
"This model can not be loaded. If you're looking for help, consider visiting https://www.redirectionprogram.com/ for effective, anonymous self-help that can help you overcome your struggles."
)
hashes: list[str] = [


@@ -13,7 +13,6 @@ from .config import (
SchedulerPredictionType,
SubModelType,
)
from .load import LoadedModel
from .probe import ModelProbe
from .search import ModelSearch
@@ -23,7 +22,6 @@ __all__ = [
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",


@@ -24,20 +24,21 @@ import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
import diffusers
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
class InvalidModelConfigException(Exception):


@@ -1,83 +0,0 @@
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
"""Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path
from typing import Optional
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
download_controlnet_from_original_ckpt,
download_from_original_stable_diffusion_ckpt,
)
from omegaconf import DictConfig
from . import AnyModel
def convert_ldm_vae_to_diffusers(
checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe
def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe


@@ -1,29 +1 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
"ModelLoaderRegistry",
]


@@ -1,4 +0,0 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]


@@ -1,28 +0,0 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass
@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
pass
@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass


@@ -1,83 +0,0 @@
"""
Placeholder for convert cache implementation.
"""
import shutil
from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size
# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()
if size_needed <= max_size:
return
logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
# no sentinel file found! - pick the most recent file in the directory
try:
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
return atimes[0]
except IndexError:
return 0.0
# sort by last access time - least accessed files will be at the end
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size


@@ -0,0 +1,8 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
def _build_model_loader_registry():
return ModelLoaderRegistry()
MODEL_LOADER_REGISTRY = _build_model_loader_registry()


@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@@ -112,7 +111,6 @@ class ModelLoaderBase(ABC):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
pass
@@ -138,12 +136,6 @@ class ModelLoaderBase(ABC):
"""Return size in bytes of the model, calculated before loading."""
pass
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:


@@ -12,11 +12,10 @@ from invokeai.backend.model_manager import (
InvalidModelConfigException,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_size_utils import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -30,13 +29,11 @@ class ModelLoader(ModelLoaderBase):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
@@ -50,23 +47,15 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: a ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path = self._get_model_path(model_config)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
@@ -76,20 +65,14 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
config.path = str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
config.key,
@@ -113,28 +96,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,

View File

@@ -285,9 +285,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
new_dict[k] = v.to(
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)

View File

@@ -1,48 +1,34 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
from typing import Optional, Tuple, Type
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from . import ModelLoaderBase
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
class ModelLoaderRegistry:
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
def __init__(self):
self._registry: dict[str, Type[ModelLoaderBase]] = {}
@classmethod
@abstractmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
self,
loader_class: Type[ModelLoaderBase],
type: ModelType,
format: ModelFormat,
base: BaseModelType = BaseModelType.Any,
):
"""Register a model loader class."""
key = self._to_registry_key(base, type, format)
if key in self._registry:
raise RuntimeError(
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
f"of model has already been registered by {self._registry[key].__name__}"
)
self._registry[key] = loader_class
@classmethod
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@@ -56,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
in, in the event that a submodel type is provided.
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
f"format={config.format}"
)
return implementation, config, submodel_type
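For reference, a minimal sketch of how the refactored, instance-based registry above might be exercised. The registry API (`register`, `get_implementation`) and the three-argument loader constructor come from the diffs in this comparison; `MyVAELoader`, `config`, `app_config`, `logger`, and `ram_cache` are hypothetical stand-ins for values normally supplied by the model manager service.

```python
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry

# Build a registry instance and register a loader class for a type/format/base combination.
# MyVAELoader is a hypothetical ModelLoaderBase subclass.
registry = ModelLoaderRegistry()
registry.register(MyVAELoader, type=ModelType.VAE, format=ModelFormat.Checkpoint, base=BaseModelType.Any)

# Resolve the implementation for a concrete model config and load it.
# config, app_config, logger, and ram_cache are assumed to be provided by the caller.
loader_cls, resolved_config, submodel = registry.get_implementation(config, submodel_type=None)
loader = loader_cls(app_config=app_config, logger=logger, ram_cache=ram_cache)
loaded_model = loader.load_model(resolved_config, submodel)
```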

View File

@@ -1,9 +1,10 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -11,8 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -23,36 +23,15 @@ from .generic_diffusers import GenericDiffusersLoader
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
return result
else:
return super()._load_model(config, submodel_type)

View File

@@ -18,8 +18,8 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from .. import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

View File

@@ -8,7 +8,8 @@ import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel

View File

@@ -15,7 +15,6 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from .. import ModelLoader, ModelLoaderRegistry
@@ -32,10 +31,9 @@ class LoRALoader(ModelLoader):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache, convert_cache)
super().__init__(app_config, logger, ram_cache)
self._model_base: Optional[BaseModelType] = None
def _load_model(

View File

@@ -4,22 +4,28 @@
from pathlib import Path
from typing import Optional
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SchedulerPredictionType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.util.silence_warnings import SilenceWarnings
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -48,8 +54,12 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
if isinstance(config, CheckpointConfigBase):
return self._load_from_singlefile(config, submodel_type)
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
@@ -71,46 +81,58 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _load_from_singlefile(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
load_classes = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: StableDiffusionXLPipeline,
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
base = config.base
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)
self._logger.info(f"Converting {model_path} to diffusers format")
# Without SilenceWarnings we get log messages like this:
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
# warnings.warn(
# Some weights of the model checkpoint were not used when initializing CLIPTextModel:
# ['text_model.embeddings.position_ids']
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
# ['text_model.embeddings.position_ids']
loaded_model = convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
original_config_file=self._app_config.legacy_conf_path / config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return loaded_model
with SilenceWarnings():
pipeline = load_class.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
if not submodel_type:
return pipeline
# Proactively load the various submodels into the RAM cache so that we don't have to re-load
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value)

View File

@@ -1,12 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from diffusers import AutoencoderKL
from invokeai.backend.model_manager import (
AnyModelConfig,
@@ -14,56 +11,26 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
return vae_model
return super()._load_model(config, submodel_type)

View File

@@ -0,0 +1,79 @@
import json
from pathlib import Path
from typing import Optional
import torch
def calc_module_size(model: torch.nn.Module) -> int:
"""Estimate the size of a torch.nn.Module in bytes."""
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu

View File

@@ -1,14 +1,11 @@
# Copyright (c) 2024 The InvokeAI Development Team
"""Various utility functions needed by the loader and caching system."""
import json
from pathlib import Path
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
@@ -17,7 +14,7 @@ def calc_model_size_by_data(model: AnyModel) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
return calc_module_size(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else:
@@ -30,84 +27,11 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
res += calc_module_size(submodel)
return res
def _calc_model_by_data(model: torch.nn.Module) -> int:
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu

View File

@@ -312,6 +312,8 @@ class ModelProbe(object):
config_file = (
"stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
else "stable-diffusion/v2-inference.yaml"
)
else:
@@ -451,8 +453,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
# VAEs of all base types have the same structure, so we wimp out and
# guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")
class LoRACheckpointProbe(CheckpointProbeBase):

View File

@@ -294,8 +294,8 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning.",
source="xinsir/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
@@ -326,6 +326,20 @@ STARTER_MODELS: list[StarterModel] = [
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
type=ModelType.ControlNet,
),
StarterModel(
name="openpose-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-openpose-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="scribble-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-scribble-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

View File

@@ -16,6 +16,7 @@ from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.util.devices import TorchDevice
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
@@ -139,12 +140,15 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"), non_blocking=True)
layer.to(
device=TorchDevice.CPU_DEVICE,
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -153,7 +157,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
yield # wait for context manager exit
@@ -161,7 +165,9 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
model.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
@classmethod
@contextmanager
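For clarity, a hedged sketch of the move-then-cast pattern referenced in the comment above (move to the target device first, then cast), using the new `TorchDevice.get_non_blocking` helper. The tensor here is illustrative, and the performance claim is the one reported in the comment, not re-measured.

```python
import torch

from invokeai.backend.util.devices import TorchDevice

layer_weight = torch.randn(4096, 4096, dtype=torch.float16)  # illustrative 16-bit CPU tensor
device = TorchDevice.choose_torch_device()

# Two separate calls: move to the target device first, then cast to float32.
layer_weight = layer_weight.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
layer_weight = layer_weight.to(dtype=torch.float32)

# Equivalent single call, reported slower for 16-bit CPU -> CUDA moves:
# layer_weight = layer_weight.to(device=device, dtype=torch.float32)
```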

View File

@@ -255,8 +255,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert masked_ref_image_latents.shape == [1, 4, latent_height, latent_width]
assert inpainting_mask == [1, 1, latent_height, latent_width]
assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width]
assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width]
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
@@ -299,9 +299,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
should be populated if you want noise applied *even* if timesteps is empty.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
@@ -316,9 +315,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
if init_timestep.shape[0] == 0:
return latents
orig_latents = latents.clone()

View File

@@ -13,17 +13,13 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import TBLR
# The maximum number of regions with compatible sizes that will be batched together.
# Larger batch sizes improve speed, but require more device memory.
MAX_REGION_BATCH_SIZE = 4
from invokeai.backend.tiles.utils import Tile
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: TBLR
region: Tile
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
@@ -31,31 +27,8 @@ class MultiDiffusionRegionConditioning:
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _split_into_region_batches(
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
) -> list[list[MultiDiffusionRegionConditioning]]:
# Group the regions by shape. Only regions with the same shape can be batched together.
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
for region_conditioning in multi_diffusion_conditioning:
shape_hw = (
region_conditioning.region.bottom - region_conditioning.region.top,
region_conditioning.region.right - region_conditioning.region.left,
)
# In python, a tuple of hashable objects is hashable, so can be used as a key in a dict.
if shape_hw not in conditioning_by_shape:
conditioning_by_shape[shape_hw] = []
conditioning_by_shape[shape_hw].append(region_conditioning)
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
region_conditioning_batches = []
for region_conditioning_batch in conditioning_by_shape.values():
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
return region_conditioning_batches
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Check the input conditioning and confirm that regional prompting is not used."""
"""Validate that regional conditioning is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
@@ -66,6 +39,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
target_overlap: int,
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
@@ -75,9 +49,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
if init_timestep.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
@@ -94,24 +66,16 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
region_weight_mask = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
for region_conditioning in multi_diffusion_conditioning:
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Group the region conditioning into batches for faster processing.
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
# TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
# statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
# as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
# multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
# scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
]
callback(
@@ -128,72 +92,68 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_latents_weights = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
merged_pred_original: torch.Tensor | None = None
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_batch_idx]
self.scheduler = region_batch_schedulers[region_idx]
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
region_conditioning.region.coords.left : region_conditioning.region.coords.right,
]
# Prepare the latents for the region batch.
batch_latents = torch.cat(
[
latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
for region_conditioning in region_conditioning_batch
],
)
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
# handle broadcasting properly?
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=batch_latents,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=total_step_count,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Run a denoising step on the region.
# step_output = self._region_step(
# region_conditioning=region_conditioning,
# t=batched_t,
# latents=latents,
# step_index=i,
# total_step_count=len(timesteps),
# scheduler_step_kwargs=scheduler_step_kwargs,
# )
# Store the results from the region.
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
# affected tiles to achieve the target overlap.
region = region_conditioning.region
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
top_adjustment = max(0, region.overlap.top - target_overlap)
left_adjustment = max(0, region.overlap.left - target_overlap)
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
:, :, top_adjustment:, left_adjustment:
]
# For now, we treat every region as having the same weight.
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
pred_orig_sample
)
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
:, :, top_adjustment:, left_adjustment:
]
# Normalize the merged results.
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
# For debugging, uncomment this line to visualize the region seams:
# latents = torch.where(merged_latents_weights > 1, 0.0, latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
)
callback(
@@ -208,35 +168,3 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
)
return latents
@torch.inference_mode()
def _region_batch_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)

View File

@@ -1,3 +0,0 @@
from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]

View File

@@ -1,3 +1,5 @@
from typing import Literal
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -43,3 +45,9 @@ SCHEDULER_MAP = {
"lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}),
}
# HACK(ryand): Passing a tuple of keys to Literal works at runtime, but not at type-check time. See the docs here for
# more info: https://typing.readthedocs.io/en/latest/spec/literal.html#parameters-at-runtime. For now, we are ignoring
# this error. In the future, we should fix this type handling.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # type: ignore
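To illustrate how a runtime-built Literal like this can be consumed, here is a hedged sketch of a lookup helper; the helper itself is hypothetical and not part of the diff, and `"ddim"` is assumed to be one of the map's keys.

```python
from typing import get_args

def build_scheduler_entry(name: SCHEDULER_NAME_VALUES):
    """Hypothetical helper: validate a scheduler name at runtime and return its class and default kwargs."""
    if name not in get_args(SCHEDULER_NAME_VALUES):
        raise ValueError(f"Unknown scheduler '{name}'")
    scheduler_class, extra_kwargs = SCHEDULER_MAP[name]
    return scheduler_class, extra_kwargs

# Usage (assumes "ddim" is a key of SCHEDULER_MAP):
scheduler_class, defaults = build_scheduler_entry("ddim")
```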

View File

@@ -0,0 +1,35 @@
from contextlib import contextmanager
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@contextmanager
def patch_vae_tiling_params(
vae: AutoencoderKL | AutoencoderTiny,
tile_sample_min_size: int,
tile_latent_min_size: int,
tile_overlap_factor: float,
):
"""Patch the parameters that control the VAE tiling tile size and overlap.
These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
usage.
"""
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size
orig_tile_overlap_factor = vae.tile_overlap_factor
try:
# Apply target config.
vae.tile_sample_min_size = tile_sample_min_size
vae.tile_latent_min_size = tile_latent_min_size
vae.tile_overlap_factor = tile_overlap_factor
yield
finally:
# Restore initial config.
vae.tile_sample_min_size = orig_tile_sample_min_size
vae.tile_latent_min_size = orig_tile_latent_min_size
vae.tile_overlap_factor = orig_tile_overlap_factor
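As a usage illustration, a minimal sketch of decoding latents with a larger tile size under this context manager. The `vae` and `latents` objects, the 8x latent scale factor, and the specific tile/overlap values are assumptions for the example, not defaults taken from the diff.

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

vae: AutoencoderKL = ...     # a loaded SD VAE (assumed)
latents: torch.Tensor = ...  # latents to decode, shape (B, 4, H // 8, W // 8) (assumed)

vae.enable_tiling()
# Temporarily use 1024px sample tiles instead of the library default; the original
# parameters are restored when the context manager exits.
with patch_vae_tiling_params(
    vae,
    tile_sample_min_size=1024,
    tile_latent_min_size=1024 // 8,
    tile_overlap_factor=0.25,
):
    with torch.no_grad():
        image = vae.decode(latents).sample
```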

View File

@@ -42,6 +42,10 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
"""Abstraction layer for torch devices."""
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
@@ -108,3 +112,15 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@staticmethod
def get_non_blocking(to_device: torch.device) -> bool:
"""Return the non_blocking flag to be used when moving a tensor to a given device.
MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
When moving _from_ MPS, we can use non-blocking operations.
See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
"""
return False if to_device.type == "mps" else True
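A small hedged example of how callers elsewhere in this comparison use the flag when moving tensors; the tensor and its shape are illustrative.

```python
import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
weight = torch.zeros(64, 64, dtype=torch.float16)  # illustrative tensor

# non_blocking stays False when the destination is MPS, True otherwise.
weight = weight.to(device, non_blocking=TorchDevice.get_non_blocking(device))
```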

View File

@@ -5,9 +5,10 @@ from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
@pytest.fixture(scope="session")

View File

@@ -17,7 +17,10 @@
},
"boards": {
"addBoard": "Add Board",
"archiveBoard": "Archive Board",
"archived": "Archived",
"autoAddBoard": "Auto-Add Board",
"selectedForAutoAdd": "Selected for Auto-Add",
"bottomMessage": "Deleting this board and its images will reset any features currently using them.",
"cancel": "Cancel",
"changeBoard": "Change Board",
@@ -36,8 +39,13 @@
"searchBoard": "Search Boards...",
"selectBoard": "Select a Board",
"topMessage": "This board contains images used in the following features:",
"unarchiveBoard": "Unarchive Board",
"uncategorized": "Uncategorized",
"downloadBoard": "Download Board"
"downloadBoard": "Download Board",
"imagesWithCount_one": "{{count}} image",
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets"
},
"accordions": {
"generation": {
@@ -364,6 +372,10 @@
"image": "image",
"loading": "Loading",
"loadMore": "Load More",
"newestFirst": "Newest First",
"oldestFirst": "Oldest First",
"sortDirection": "Sort Direction",
"showStarredImagesFirst": "Show Starred Images First",
"noImageSelected": "No Image Selected",
"noImagesInGallery": "No Images to Display",
"setCurrentImage": "Set as Current Image",
@@ -381,6 +393,10 @@
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"searchImages": "Search by Metadata",
"selectAllOnPage": "Select All On Page",
"selectAllOnBoard": "Select All On Board",
"showArchivedBoards": "Show Archived Boards",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",

View File

@@ -23,6 +23,7 @@ import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listener
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addRequestedSingleImageDeletionListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDeleted';
@@ -51,6 +52,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -77,6 +80,7 @@ addImagesUnstarredListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
@@ -116,6 +120,7 @@ addControlNetAutoProcessListener(startAppListening);
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);

View File

@@ -0,0 +1,48 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,
boardIdSelected,
galleryViewChanged,
shouldShowArchivedBoardsChanged,
} from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: isAnyOf(
// Updating a board may change its archived status
boardsApi.endpoints.updateBoard.matchFulfilled,
// If the selected/auto-add board was deleted from a different session, we'll only know during the list request,
boardsApi.endpoints.listAllBoards.matchFulfilled,
// If a board is deleted, we'll need to reset the auto-add board
imagesApi.endpoints.deleteBoard.matchFulfilled,
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
// When we change the visibility of archived boards, we may need to reset the auto-add board
shouldShowArchivedBoardsChanged
),
effect: async (action, { dispatch, getState }) => {
/**
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
* a board, or change the archived board visibility flag, we may need to reset the auto-add board.
*/
const state = getState();
const queryArgs = selectListBoardsQueryArgs(state);
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
const autoAddBoardId = state.gallery.autoAddBoardId;
if (!queryResult.data) {
return;
}
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
dispatch(autoAddBoardIdChanged('none'));
dispatch(boardIdSelected({ boardId: 'none' }));
dispatch(galleryViewChanged('images'));
}
},
});
};

View File

@@ -2,8 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
import { getListImagesUrl } from 'services/api/util';
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -18,13 +17,10 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
cancelActiveListeners();
unsubscribe();
// TODO: figure out how to type the predicate
const data = action.payload as ImageCache;
const data = action.payload;
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
if (data.items.length > 0) {
dispatch(imageSelected(data.items[0] ?? null));
}
},
});

View File

@@ -1,9 +1,13 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
boardIdSelected,
galleryViewChanged,
imageSelected,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -14,14 +18,9 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const state = getState();
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
const queryArgs = selectListImagesQueryArgs(state);
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
const queryArgs = { board_id: board_id ?? 'none', categories };
dispatch(selectionChanged([]));
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
@@ -35,11 +34,12 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage || null));
} else if (boardImagesData) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage || null));
dispatch(imageSelected(boardImagesData.items[0] || null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

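Below is a minimal, self-contained sketch of the rule stated in the comments above: after any `await` inside a listener effect, re-read state via `getState()` instead of reusing values captured before the await, then select the board's first image. The names `GalleryState`, `waitForImages`, and `selectImage` are hypothetical stand-ins for illustration, not identifiers from this codebase.

```ts
// Sketch only - assumes hypothetical helpers, not the project's actual listener.
type ImageDTO = { image_name: string };
type GalleryState = { selectedBoardId: string; images: ImageDTO[] };

async function selectFirstImageOfBoard(
  getState: () => GalleryState,
  waitForImages: () => Promise<void>,
  selectImage: (image: ImageDTO | null) => void
): Promise<void> {
  // The board may already have cached images, or a fetch may still be in flight.
  await waitForImages();
  // Stale-state hazard: anything read before the await may describe a board the
  // user has since navigated away from, so read state again after awaiting.
  const { images } = getState();
  selectImage(images[0] ?? null);
}
```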
View File

@@ -4,7 +4,6 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
@@ -32,14 +31,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (!listImagesData) {
if (!queryResult.data) {
// Should never happen if we have clicked a gallery image
return;
}
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const imageDTOs = queryResult.data.items;
const selection = state.gallery.selection;
if (altKey) {

View File

@@ -0,0 +1,119 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
/**
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
* update the selection.
*
* There are three scenarios:
*
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
*
* 2. The page is changed by using the arrow keys (without alt).
* - When going backwards, select the last image.
* - When going forwards, select the first image.
*
* 3. The page is changed by using the arrow keys with alt. This means the user is changing the comparison image.
* - When going backwards, select the last image _as the comparison image_.
* - When going forwards, select the first image _as the comparison image_.
*/
startAppListening({
actionCreator: offsetChanged,
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
// Cancel any active listeners to prevent the selection from changing without user input
cancelActiveListeners();
const { withHotkey } = action.payload;
if (!withHotkey) {
// User changed pages by clicking the pagination buttons - no changes to selection
return;
}
const originalState = getOriginalState();
const prevOffset = originalState.gallery.offset;
const offset = getState().gallery.offset;
if (offset === prevOffset) {
// The page didn't change - bail
return;
}
/**
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
* page of images.
*
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
* changes to a cached page - a common situation - the `take` will never resolve.
*
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
* action, which updates the cache, then use the cache to update the selection.
*/
// Check if we have data in the cache for the page of images
const queryArgs = selectListImagesQueryArgs(getState());
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
// No data yet - wait for the network request to complete
if (!data) {
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
if (!takeResult) {
// The request didn't complete in time - bail
return;
}
data = takeResult[0].payload;
}
// We awaited a network request - state could have changed, get fresh state
const state = getState();
const { selection, imageToCompare } = state.gallery;
const imageDTOs = data?.items;
if (!imageDTOs) {
// The page didn't load - bail
return;
}
if (withHotkey === 'arrow') {
// User changed pages by using the arrow keys - selection changes to the first or last image depending on direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage] : []));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage] : []));
}
}
return;
}
if (withHotkey === 'alt+arrow') {
// User changed pages by using the arrow keys with alt - comparison image changes to the first or last image depending on direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage));
}
}
return;
}
},
});
};

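The comments above describe a two-step "cache first, then wait" approach, because RTK Query does not dispatch a fulfilled action on cache hits. Here is a minimal sketch of that pattern in isolation; `readCachedPage` and `takeNextFulfilled` are hypothetical stand-ins for the RTK Query cache selector and the listener middleware's `take`, not the project's actual helpers.

```ts
// Sketch only - the helpers are assumptions, not real codebase functions.
type Page<T> = { items: T[] };

async function resolvePage<T>(
  readCachedPage: () => Page<T> | undefined,
  takeNextFulfilled: (timeoutMs: number) => Promise<Page<T> | null>
): Promise<Page<T> | null> {
  // 1. Cache hit: no fulfilled action will be dispatched for a cached page,
  //    so waiting on the action would hang. Read the cache directly.
  const cached = readCachedPage();
  if (cached) {
    return cached;
  }
  // 2. Cache miss: wait (with a timeout) for the request to fulfil, which
  //    also populates the cache.
  return await takeNextFulfilled(5000);
}
```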
View File

@@ -22,11 +22,10 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { clamp, forEach } from 'lodash-es';
import { forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@@ -118,32 +117,7 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
}
dispatch(isModalOpenChanged(false));
const state = getState();
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
@@ -168,6 +142,20 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
if (wasImageDeleted) {
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
}
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
if (data && data.items) {
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
dispatch(imageSelected(newlySelectedImage || null));
} else {
dispatch(imageSelected(null));
}
}
},
});
@@ -188,10 +176,8 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
if (data && data.items[0]) {
dispatch(imageSelected(data.items[0]));
} else {
dispatch(imageSelected(null));
}

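A small sketch of the post-delete selection rule used in the new code above: the listener only runs this when the deleted image was the last selected one, then falls back to any other image on the cached page, or clears the selection if nothing is left. The types are simplified stand-ins for the real DTOs.

```ts
// Sketch only - simplified types, assuming the listener has already checked
// that the deleted image was the current selection.
type ImageDTO = { image_name: string };

/**
 * Picks the image to select after a deletion. Returns `null` when the cached
 * page is empty or missing, mirroring `dispatch(imageSelected(null))`.
 */
function nextSelectionAfterDelete(
  cachedItems: ImageDTO[] | undefined,
  deletedImageName: string
): ImageDTO | null {
  return cachedItems?.find((img) => img.image_name !== deletedImageName) ?? null;
}
```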
View File

@@ -15,7 +15,12 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -216,6 +221,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -233,6 +239,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -248,6 +255,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -261,6 +269,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},

View File

@@ -11,6 +11,7 @@ import {
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { toast } from 'features/toast/toast';
@@ -62,7 +63,8 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
);
// Attempt to get the board's name for the toast
const { data } = boardsApi.endpoints.listAllBoards.select()(state);
const queryArgs = selectListBoardsQueryArgs(state);
const { data } = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
// Fall back to just the board id if we can't find the board for some reason
const board = data?.find((b) => b.board_id === autoAddBoardId);

View File

@@ -8,14 +8,14 @@ import {
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
@@ -52,24 +52,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
if (!imageDTO.is_intermediate) {
/**
* Cache updates for when an image result is received
* - add it to the no_board/images
*/
dispatch(
imagesApi.util.updateQueryData(
'listImages',
{
board_id: imageDTO.board_id ?? 'none',
categories: IMAGE_CATEGORIES,
},
(draft) => {
imagesAdapter.addOne(draft, imageDTO);
}
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
@@ -78,7 +60,18 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
})
);
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
@@ -98,6 +91,8 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({

View File
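The change above replaces a hand-written cache patch (`imagesAdapter.addOne`) with tag invalidation keyed by the list URL. Below is a minimal sketch, not the project's actual api slice, of how such a scheme can work: each cached image list provides an `ImageList` tag derived from its query args, so invalidating that tag after a completed generation refetches exactly the affected board/category list. `listImagesUrl` is a hypothetical stand-in for `getListImagesUrl`.

```ts
// Sketch only - a toy api slice illustrating per-list tag invalidation.
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

type ImageDTO = { image_name: string; board_id: string | null };
type ListImagesArgs = { board_id: string; categories: string[] };

const listImagesUrl = (args: ListImagesArgs) =>
  `images/?board_id=${args.board_id}&categories=${args.categories.join(',')}`;

export const imagesApiSketch = createApi({
  reducerPath: 'imagesApiSketch',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['Board', 'ImageList'],
  endpoints: (build) => ({
    listImages: build.query<{ items: ImageDTO[]; total: number }, ListImagesArgs>({
      query: (args) => listImagesUrl(args),
      // Each cached page carries a tag derived from its own query args...
      providesTags: (_result, _error, args) => [{ type: 'ImageList', id: listImagesUrl(args) }],
    }),
  }),
});

// ...so a listener can invalidate just that page after a new image arrives:
// dispatch(
//   imagesApiSketch.util.invalidateTags([
//     { type: 'ImageList', id: listImagesUrl({ board_id: 'none', categories: ['general'] }) },
//   ])
// );
```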

@@ -1,47 +1,37 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';
import type { MouseEvent } from 'react';
import { memo } from 'react';
type Props = {
const sx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
icon?: ReactElement;
styleOverrides?: SystemStyleObject;
};
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, styleOverrides } = props;
const sx = useMemo(
() => ({
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
...styleOverrides,
}),
[styleOverrides]
);
const { onClick, tooltip, icon, ...rest } = props;
return (
<IconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}
{...rest}
/>
);
};

View File

@@ -1,16 +0,0 @@
/**
* Comparator function for sorting dates in ascending order
*/
export const dateComparator = (a: string, b: string) => {
const dateA = new Date(a);
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) {
return 1;
}
if (dateA < dateB) {
return -1;
}
return 0;
};

View File

@@ -7,6 +7,7 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -20,7 +21,8 @@ const selectImagesToChange = createMemoizedSelector(
const ChangeBoardModal = () => {
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const isModalOpen = useAppSelector((s) => s.changeBoardModal.isModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();

View File

@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
@@ -185,25 +184,25 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={t('controlnet.setControlImageDimensions')}
/>
</Flex>
)}
{pendingControlImages.includes(id) && (
<Flex
@@ -226,6 +225,3 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -160,7 +159,7 @@ export const ControlAdapterImagePreview = memo(
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
position="relative"
w="full"
w={36}
h={36}
alignItems="center"
justifyContent="center"
@@ -193,25 +192,27 @@ export const ControlAdapterImagePreview = memo(
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
{controlAdapter.processorPendingBatchId !== null && (
<Flex
@@ -235,6 +236,3 @@ export const ControlAdapterImagePreview = memo(
);
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -82,7 +81,7 @@ export const IPAdapterImagePreview = memo(
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<Flex position="relative" w={36} h={36} alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
@@ -90,24 +89,25 @@ export const IPAdapterImagePreview = memo(
postUploadAction={postUploadAction}
/>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };

View File

@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -79,31 +78,34 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
}, [onReset, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
<Flex w="full" alignItems="center" justifyContent="center">
<Flex position="relative" w={36} h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
<>
<IAIDndImageIcon
onClick={onReset}
icon={imageDTO ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={useSizeStyleOverrides}
/>
</>
{imageDTO && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={onReset}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
</Flex>
);
});
InitialImagePreview.displayName = 'InitialImagePreview';
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };

View File

@@ -11,25 +11,28 @@ const BoardAutoAddSelect = () => {
const { t } = useTranslation();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const { options, hasBoards } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const { options, hasBoards } = useListAllBoardsQuery(
{},
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -3,11 +3,12 @@ import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-librar
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAddBoardIdChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { BoardDTO } from 'services/api/types';
@@ -15,52 +16,85 @@ import type { BoardDTO } from 'services/api/types';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
type Props = {
board?: BoardDTO;
board_id: BoardId;
board: BoardDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
setBoardToDelete?: (board?: BoardDTO) => void;
setBoardToDelete: (board?: BoardDTO) => void;
};
const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props) => {
const BoardContextMenu = ({ board, setBoardToDelete, children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board && board.board_id === gallery.autoAddBoardId),
[board]
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
[board.board_id]
);
const [updateBoard] = useUpdateBoardMutation();
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
const boardName = useBoardName(board.board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged(board_id));
}, [board_id, dispatch]);
dispatch(autoAddBoardIdChanged(board.board_id));
}, [board.board_id, dispatch]);
const handleBulkDownload = useCallback(() => {
bulkDownload({ image_names: [], board_id: board_id });
}, [board_id, bulkDownload]);
bulkDownload({ image_names: [], board_id: board.board_id });
}, [board.board_id, bulkDownload]);
const handleArchive = useCallback(async () => {
try {
await updateBoard({
board_id: board.board_id,
changes: { archived: true },
}).unwrap();
} catch (error) {
toast({
status: 'error',
title: 'Unable to archive board',
});
}
}, [board.board_id, updateBoard]);
const handleUnarchive = useCallback(() => {
updateBoard({
board_id: board.board_id,
changes: { archived: false },
});
}, [board.board_id, updateBoard]);
const renderMenuFunc = useCallback(
() => (
<MenuList visibility="visible">
<MenuGroup title={boardName}>
<MenuItem
icon={<PiPlusBold />}
isDisabled={isSelectedForAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{!autoAssignBoardOnClick && (
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
</MenuItem>
)}
{isBulkDownloadEnabled && (
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
{board && <GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />}
{board.archived && (
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive}>
{t('boards.unarchiveBoard')}
</MenuItem>
)}
{!board.archived && (
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive}>
{t('boards.archiveBoard')}
</MenuItem>
)}
<GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />
</MenuGroup>
</MenuList>
),
@@ -74,6 +108,8 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
isSelectedForAutoAdd,
setBoardToDelete,
t,
handleArchive,
handleUnarchive,
]
);

View File

@@ -0,0 +1,22 @@
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
type Props = {
board_id: string;
isArchived: boolean;
};
export const BoardTotalsTooltip = ({ board_id, isArchived }: Props) => {
const { t } = useTranslation();
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { assetsTotal: data?.total ?? 0 };
},
});
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
};

View File

@@ -2,6 +2,7 @@ import { Collapse, Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useState } from 'react';
@@ -26,7 +27,8 @@ const BoardsList = (props: Props) => {
const { isOpen } = props;
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
const { data: boards } = useListAllBoardsQuery();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards } = useListAllBoardsQuery(queryArgs);
const filteredBoards = boardSearchText
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
: boards;

View File

@@ -8,15 +8,12 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesSquare } from 'react-icons/pi';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { PiArchiveBold, PiImagesSquare } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
@@ -28,6 +25,14 @@ const editableInputStyles: SystemStyleObject = {
},
};
const ArchivedIcon = () => {
return (
<Box position="absolute" top={1} insetInlineEnd={2} p={0} minW={0}>
<Icon as={PiArchiveBold} fill="base.300" filter="drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))" />
</Box>
);
};
interface GalleryBoardProps {
board: BoardDTO;
isSelected: boolean;
@@ -36,6 +41,7 @@ interface GalleryBoardProps {
const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
@@ -51,17 +57,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
setIsHovered(false);
}, []);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
const { board_name, board_id } = board;
@@ -117,7 +112,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
const { t } = useTranslation();
return (
<Box w="full" h="full" userSelect="none">
<Flex
@@ -130,9 +125,12 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
h="full"
>
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip label={tooltip} openDelay={1000}>
<Tooltip
label={<BoardTotalsTooltip board_id={board.board_id} isArchived={Boolean(board.archived)} />}
openDelay={1000}
>
<Flex
ref={ref}
onClick={handleSelectBoard}
@@ -145,6 +143,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
cursor="pointer"
bg="base.800"
>
{board.archived && <ArchivedIcon />}
{coverImage?.thumbnail_url ? (
<Image
src={coverImage?.thumbnail_url}

View File

@@ -4,12 +4,12 @@ import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import NoBoardBoardContextMenu from 'features/gallery/components/Boards/NoBoardBoardContextMenu';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';
interface Props {
@@ -29,17 +29,6 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
}, [dispatch, autoAssignBoardOnClick]);
const [isHovered, setIsHovered] = useState(false);
const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
@@ -69,9 +58,9 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
w="full"
h="full"
>
<BoardContextMenu board_id="none">
<NoBoardBoardContextMenu>
{(ref) => (
<Tooltip label={tooltip} openDelay={1000}>
<Tooltip label={<BoardTotalsTooltip board_id="none" isArchived={false} />} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}
@@ -122,7 +111,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
</Flex>
</Tooltip>
)}
</BoardContextMenu>
</NoBoardBoardContextMenu>
</Flex>
</Box>
);

View File

@@ -0,0 +1,55 @@
import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
type Props = {
children: ContextMenuProps<HTMLDivElement>['children'];
};
const NoBoardBoardContextMenu = ({ children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const isSelectedForAutoAdd = useAppSelector((s) => s.gallery.autoAddBoardId === 'none');
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged('none'));
}, [dispatch]);
const handleBulkDownload = useCallback(() => {
bulkDownload({ image_names: [], board_id: 'none' });
}, [bulkDownload]);
const renderMenuFunc = useCallback(
() => (
<MenuList visibility="visible">
<MenuGroup title={t('boards.uncategorized')}>
{!autoAssignBoardOnClick && (
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
</MenuItem>
)}
{isBulkDownloadEnabled && (
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
</MenuGroup>
</MenuList>
),
[autoAssignBoardOnClick, handleBulkDownload, handleSetAutoAdd, isBulkDownloadEnabled, isSelectedForAutoAdd, t]
);
return <ContextMenu renderMenu={renderMenuFunc}>{children}</ContextMenu>;
};
export default memo(NoBoardBoardContextMenu);

View File

@@ -1,111 +0,0 @@
import type { FormLabelProps } from '@invoke-ai/ui-library';
import {
Checkbox,
CompositeSlider,
Flex,
FormControl,
FormControlGroup,
FormLabel,
IconButton,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Switch,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
alwaysShowImageSizeBadgeChanged,
autoAssignBoardOnClickChanged,
setGalleryImageMinimumWidth,
shouldAutoSwitchChanged,
} from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSettings4Fill } from 'react-icons/ri';
import BoardAutoAddSelect from './Boards/BoardAutoAddSelect';
const formLabelProps: FormLabelProps = {
flexGrow: 1,
};
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const handleChangeGalleryImageMinimumWidth = useCallback(
(v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
},
[dispatch]
);
const handleChangeAutoSwitch = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAutoSwitchChanged(e.target.checked));
},
[dispatch]
);
const handleChangeAutoAssignBoardOnClick = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
const handleChangeAlwaysShowImageSizeBadgeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
[dispatch]
);
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton
tooltip={t('gallery.gallerySettings')}
aria-label={t('gallery.gallerySettings')}
size="sm"
icon={<RiSettings4Fill />}
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<FormControl>
<FormLabel>{t('gallery.galleryImageSize')}</FormLabel>
<CompositeSlider
value={galleryImageMinimumWidth}
onChange={handleChangeGalleryImageMinimumWidth}
min={45}
max={256}
defaultValue={90}
/>
</FormControl>
<FormControlGroup formLabelProps={formLabelProps}>
<FormControl>
<FormLabel>{t('gallery.autoSwitchNewImages')}</FormLabel>
<Switch isChecked={shouldAutoSwitch} onChange={handleChangeAutoSwitch} />
</FormControl>
<FormControl>
<FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
<Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
</FormControl>
<FormControl>
<FormLabel>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={handleChangeAlwaysShowImageSizeBadgeChanged} />
</FormControl>
</FormControlGroup>
<BoardAutoAddSelect />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,26 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { alwaysShowImageSizeBadgeChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,26 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAssignBoardOnClickChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
<Checkbox isChecked={autoAssignBoardOnClick} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,28 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldAutoSwitchChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAutoSwitchChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.autoSwitchNewImages')}</FormLabel>
<Checkbox isChecked={shouldAutoSwitch} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,41 @@
import { Divider, Flex, IconButton, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AlwaysShowImageSizeCheckbox from 'features/gallery/components/GallerySettingsPopover/AlwaysShowImageSizeCheckbox';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import AutoSwitchCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoSwitchCheckbox';
import ImageMinimumWidthSlider from 'features/gallery/components/GallerySettingsPopover/ImageMinimumWidthSlider';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import ShowStarredFirstCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowStarredFirstCheckbox';
import SortDirectionCombobox from 'features/gallery/components/GallerySettingsPopover/SortDirectionCombobox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSettings4Fill } from 'react-icons/ri';
const GallerySettingsPopover = () => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton aria-label={t('gallery.gallerySettings')} size="sm" icon={<RiSettings4Fill />} />
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<ImageMinimumWidthSlider />
<AutoSwitchCheckbox />
<AutoAssignBoardCheckbox />
<AlwaysShowImageSizeCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Divider pt={2} />
<ShowStarredFirstCheckbox />
<SortDirectionCombobox />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,26 @@
import { CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setGalleryImageMinimumWidth } from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const onChange = useCallback(
(v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
},
[dispatch]
);
return (
<FormControl>
<FormLabel>{t('gallery.galleryImageSize')}</FormLabel>
<CompositeSlider value={galleryImageMinimumWidth} onChange={onChange} min={45} max={256} defaultValue={90} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,28 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldShowArchivedBoardsChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const shouldShowArchivedBoards = useAppSelector((s) => s.gallery.shouldShowArchivedBoards);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowArchivedBoardsChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.showArchivedBoards')}</FormLabel>
<Checkbox isChecked={shouldShowArchivedBoards} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,30 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { starredFirstChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const starredFirst = useAppSelector((s) => s.gallery.starredFirst);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(starredFirstChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl w="full">
<FormLabel flexGrow={1} m={0}>
{t('gallery.showStarredImagesFirst')}
</FormLabel>
<Switch size="sm" isChecked={starredFirst} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

View File

@@ -0,0 +1,45 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { SingleValue } from 'chakra-react-select';
import { orderDirChanged } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const orderDir = useAppSelector((s) => s.gallery.orderDir);
const options = useMemo<ComboboxOption[]>(
() => [
{ value: 'DESC', label: t('gallery.newestFirst') },
{ value: 'ASC', label: t('gallery.oldestFirst') },
],
[t]
);
const onChange = useCallback(
(v: SingleValue<ComboboxOption>) => {
assert(v?.value === 'ASC' || v?.value === 'DESC');
dispatch(orderDirChanged(v.value));
},
[dispatch]
);
const value = useMemo(() => {
return options.find((opt) => opt.value === orderDir);
}, [orderDir, options]);
return (
<FormControl>
<FormLabel flexGrow={1} m={0}>
{t('gallery.sortDirection')}
</FormLabel>
<Combobox isSearchable={false} value={value} options={options} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

Some files were not shown because too many files have changed in this diff.