Compare commits

..

165 Commits

Author SHA1 Message Date
Ryan Dick
0781fdf3b0 WIP - simplify ModelLoadRegistry 2024-07-02 20:36:36 -04:00
Ryan Dick
8d7ca9c1b7 More refactoring to help with circular imports. 2024-07-02 16:49:03 -04:00
Ryan Dick
a61e0bd2dd Remove symbol re-exports that were contributing to circular import issues. 2024-07-02 15:26:08 -04:00
Ryan Dick
798e73969c Tidy handling of SCHEDULER_NAME_VALUES to help with circular import errors. 2024-07-02 15:12:59 -04:00
Ryan Dick
44f62944ee Fix circular import caused by the organization of the model size utils. 2024-07-02 11:55:05 -04:00
Ryan Dick
e9936c27fb Make the VAE tile size configurable for tiled VAE (#6555)
## Summary

- This PR exposes a `tile_size` field on `ImageToLatentsInvocation` and
`LatentsToImageInvocation`.
  - Setting `tile_size = 0` preserves the default behaviour.
- This feature is primarily intended to support upscaling workflows that
require VAE encoding/decoding of high-resolution images. In the future, we
may want to expose the tile size as a global application config, but
that's a separate conversation.
- As a general rule, larger tile sizes produce better results at the
cost of higher memory usage.

### Example:

Original (5472x5472)

![orig](https://github.com/invoke-ai/InvokeAI/assets/14897797/af0a975d-11ed-4f3c-9e53-84f3da6c997e)

VAE roundtrip with 512x512 tiles (note the discoloration)

![vae_roundtrip_512x512](https://github.com/invoke-ai/InvokeAI/assets/14897797/d589ae3e-fe93-410a-904c-f61f0fc0f1f2)

VAE roundtrip with 1024x1024 tiles (some discoloration still present,
but less severe than at 512x512)

![vae_roundtrip_1024x1024](https://github.com/invoke-ai/InvokeAI/assets/14897797/d0bb9752-3bfa-444f-88c9-39a3ca89c748)


## Related Issues / Discussions

Related: #6144 

## QA Instructions

- [x] Test image generation via the Linear tab
- [x] Test VAE roundtrip with tiling disabled
- [x] Test VAE roundtrip with tiling and tile_size = 0
- [x] Test VAE roundtrip with tiling and tile_size > 0

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-07-02 09:16:07 -04:00
Ryan Dick
3752509066 Expose the VAE tile_size on the VAE encode and decode invocations. 2024-07-02 09:07:03 -04:00
Ryan Dick
a1b7dbfa54 Add unit test for patch_vae_tiling_params(). 2024-07-02 09:07:03 -04:00
Ryan Dick
79640ba14e Add context manager for overriding VAE tiling params. 2024-07-02 09:07:03 -04:00
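For reference, a minimal sketch of what a context manager like `patch_vae_tiling_params` might look like, assuming it simply saves and restores the tiling attributes that diffusers' `AutoencoderKL` exposes (`tile_sample_min_size`, `tile_latent_min_size`, `tile_overlap_factor`); the real implementation may differ in details:

```python
from contextlib import contextmanager
from typing import Iterator

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


@contextmanager
def patch_vae_tiling_params(
    vae: AutoencoderKL,
    tile_sample_min_size: int,
    tile_latent_min_size: int,
    tile_overlap_factor: float,
) -> Iterator[None]:
    """Temporarily override the VAE's tiling parameters, restoring the originals on exit."""
    # Save the current values so the VAE is left untouched after the block exits.
    orig_sample = vae.tile_sample_min_size
    orig_latent = vae.tile_latent_min_size
    orig_overlap = vae.tile_overlap_factor
    try:
        vae.tile_sample_min_size = tile_sample_min_size
        vae.tile_latent_min_size = tile_latent_min_size
        vae.tile_overlap_factor = tile_overlap_factor
        yield
    finally:
        vae.tile_sample_min_size = orig_sample
        vae.tile_latent_min_size = orig_latent
        vae.tile_overlap_factor = orig_overlap
```

The invocation diff further down calls this with `tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR` and a fixed `tile_overlap_factor=0.25`, so the latent-space tile size tracks the image-space tile size.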
psychedelicious
4075a81676 feat(ui): gallery image selection ux
The selection logic is a bit complicated. We have image selection and pagination, both of which can be triggered using the mouse or hotkeys. We have viewer image selection and comparison image selection, which are distinguished by the alt key.

This change ties the room together with these behaviours:

- Changing the page using pagination buttons never changes the selection.
- Changing the selected image using arrows may change the page, if the arrow key pressed would select an image off the current page.
  - `right` on the last image of the current page goes to the next page
  - `down` on the last row of images goes to the next page
  - `left` on the first image of the current page goes to the previous page
  - `up` on the first row of images goes to the previous page
- If `alt` is held when using arrow keys, we change the page, but we only change the comparison image selection.
- When using arrow keys, if the page has changed since the last image was selected, the selection is reset to the first image on the page.
- The next/previous buttons on the image viewer do the same thing as `left` and `right` without `alt`.
- When clicking an image in the gallery:
  - If no modifier keys are held, the image is exclusively selected.
  - If `ctrl` or `meta` are held, the image's selection status is toggled.
  - If `shift` is held, all images from the last-selected image to the image are selected. If there are no images on the current page, the selection is unchanged.
  - If `alt` is held, the image is set as the compare image.
- `ctrl+a` and `meta+a` add the current page to the selection.

The logic for gallery navigation and selection is now pretty hairy. It's spread across 3 hooks, a listener, a redux slice, and several components.

When we next make changes to this part of the app, we should consider consolidating some of the related logic. Probably most of it can go into a single listener and make it much simpler to grok.
2024-07-02 13:52:32 +10:00
psychedelicious
4d39976909 feat(ui): restore loading spinner in search box
@maryhipp you were right: after trying loading bars and different placements, this feels like the best place for it.
2024-07-02 13:52:32 +10:00
Mary Hipp
d14894b3ae (ui) clarify auto-add options 2024-07-02 06:44:09 +10:00
Mary Hipp
6f5c5b0757 lint fix 2024-07-01 15:36:06 -04:00
Mary Hipp
93caa23ef8 undo 2024-07-01 15:36:06 -04:00
Mary Hipp
977a77f4e6 fix(ui): dont mess up redux if 403 gets thrown 2024-07-01 15:36:06 -04:00
Mary Hipp
57c0fcb93d (ui) clarify auto-add options 2024-07-01 15:36:06 -04:00
Kent Keirsey
8b55900035 Update README.md
Updated to include more context confirming the community edition is in fact free for commercial use.
2024-07-01 09:12:31 -07:00
psychedelicious
b1cc413bbd tidy(ui): remove search term fetching indicator
Don't like this UI (even though I suggested it). No need to prevent the user from interacting with the search term field during fetching. Let's figure out a nicer way to present this in a followup.
2024-07-01 20:06:28 +10:00
psychedelicious
face94ce33 feat(ui): tweak search term placeholder verbiage 2024-07-01 20:06:28 +10:00
psychedelicious
f0b1f0e5b6 feat(ui): pass search term as-is to query
The images service does not add the query filter if the search term is an empty string.
2024-07-01 20:06:28 +10:00
psychedelicious
390dc47db5 feat(app): transform search term to lowercase 2024-07-01 20:06:28 +10:00
Mary Hipp
20d5c3a8bf (ui): improve loader/fetching state while searching, make search term a string in redux 2024-07-01 20:06:28 +10:00
maryhipp
134d831ebf (api) simplify query 2024-07-01 20:06:28 +10:00
maryhipp
b65ed8e8f2 fix commented out migration 2024-07-01 20:06:28 +10:00
maryhipp
93951dcf82 (api) ruff 2024-07-01 20:06:28 +10:00
Mary Hipp
da05034e20 feat(ui): debounced gallery search 2024-07-01 20:06:28 +10:00
Mary Hipp
d579aefb3e feat(api): add optional search_term query param to image list to search metadata 2024-07-01 20:06:28 +10:00
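A rough sketch of how the optional `search_term` filter might be applied to the image list query; the table and column names here are illustrative rather than the project's actual schema:

```python
from typing import Optional


def build_image_list_query(search_term: Optional[str]) -> tuple[str, list[str]]:
    """Build an image list query with an optional metadata search filter.

    Table/column names are illustrative; the real schema may differ.
    """
    query = "SELECT image_name FROM images"
    params: list[str] = []

    # An empty or missing search term adds no filter at all, matching the
    # behaviour described in the commits above.
    if search_term:
        # Lowercase both sides so matching is case-insensitive.
        query += " WHERE LOWER(images.metadata) LIKE ?"
        params.append(f"%{search_term.lower()}%")

    query += " ORDER BY created_at DESC;"
    return query, params
```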
blessedcoolant
5d1f6db414 fix(app): fix SQL query w/ enum for python 3.11 (#6557)
## Summary

Python 3.11 has a wonderfully devious breaking change: enum classes that
inherit from `str` or `int` _sometimes_ do not behave the same way as they
do in 3.10 when used in string formatting/interpolation.

This breaks the new gallery sort queries. The fix is to use
`order_dir.value` instead of `order_dir` in the query.

This was not an issue during development because the feature was
developed w/ python 3.10.
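
A minimal reproduction of the difference, using a hypothetical direction enum (the real code uses `SQLiteDirection`):

```python
from enum import Enum


class OrderDir(str, Enum):
    """A str-mixin enum, similar in spirit to the direction enum used by the gallery queries."""

    Ascending = "ASC"
    Descending = "DESC"


order_dir = OrderDir.Descending

# On Python 3.10 this produces "... ORDER BY created_at DESC".
# On Python 3.11+, Enum.__format__ falls back to Enum.__str__ for plain mixin enums,
# so the same f-string produces "... ORDER BY created_at OrderDir.Descending".
fragile_query = f"SELECT * FROM images ORDER BY created_at {order_dir};"

# Interpolating the value explicitly behaves the same on both versions.
safe_query = f"SELECT * FROM images ORDER BY created_at {order_dir.value};"
```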

## Related Issues / Discussions

Thanks to @JPPhoto for reporting and troubleshooting:
https://discord.com/channels/1020123559063990373/1149513625321603162/1256211815982039173

## QA Instructions

JP's fancy python 3.11 system should work on this PR.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-06-29 18:50:16 +05:30
psychedelicious
f9961eceb7 fix(app): fix SQL query w/ enum for python 3.11 2024-06-29 11:07:39 +10:00
psychedelicious
10076fb1e8 feat(ui): tweak gallery settings popover divider styling 2024-06-28 18:01:01 +10:00
psychedelicious
d6e85e5f67 tidy(ui): rename GalleryBulkSelect -> GallerySelectionCountTag 2024-06-28 18:01:01 +10:00
psychedelicious
1ce459198c chore(ui): knip 2024-06-28 18:01:01 +10:00
psychedelicious
17d337169d fix(ui): do not reset limit when changing gallery view 2024-06-28 18:01:01 +10:00
psychedelicious
1468f4d37e perf(ui): split out gallery settings popover components
This was taking over 15ms (!) to render each time a setting changed, wtf
2024-06-28 18:01:01 +10:00
psychedelicious
2b744480d6 feat(ui): update UI for sorting 2024-06-28 18:01:01 +10:00
psychedelicious
abb8d34b56 chore(ui): typegen 2024-06-28 18:01:01 +10:00
psychedelicious
9e664d7c58 feat(api): remove order_by in favor of starred_first for images records 2024-06-28 18:01:01 +10:00
psychedelicious
c96ccae70b feat(app): remove order_by in favor of starred_first for images records 2024-06-28 18:01:01 +10:00
maryhipp
f268fe126e feat(api): add order_by and order_dir to list images for sorting 2024-06-28 18:01:01 +10:00
Mary Hipp
6109a06f04 feat(ui): gallery sort by created at or starred, asc or desc 2024-06-28 18:01:01 +10:00
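A sketch of how `starred_first` and `order_dir` might combine into an ORDER BY clause; the column names are illustrative, and `order_dir` is assumed to already be a validated "ASC"/"DESC" string (e.g. the `.value` of a direction enum):

```python
def build_order_by_clause(starred_first: bool, order_dir: str) -> str:
    """Build the ORDER BY fragment for the image list query (illustrative column names)."""
    if order_dir not in ("ASC", "DESC"):
        raise ValueError(f"Unexpected order_dir: {order_dir}")

    if starred_first:
        # Starred images sort to the top regardless of the created_at direction.
        return f"ORDER BY starred DESC, created_at {order_dir}"
    return f"ORDER BY created_at {order_dir}"
```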
Kent Keirsey
5df2a79549 Update starter models 2024-06-28 17:49:45 +10:00
Kent Keirsey
10b9088312 update controlnet starter models 2024-06-28 17:49:45 +10:00
psychedelicious
41f46b846b chore: ruff 2024-06-28 10:36:05 +10:00
psychedelicious
6dfc406c52 tests: update test_bulk_download.py after addition of archived field 2024-06-28 10:36:05 +10:00
psychedelicious
0d4b80780b feat(ui): handle edge cases when archiving/deleting boards
If the currently selected or auto-add board is archived or deleted, we should reset them. There are some edge cases that weren't handled in the previous implementation.

All handling of this logic is moved to the (renamed) listener.
2024-06-28 10:36:05 +10:00
psychedelicious
15b9ece411 chore(ui): typegen 2024-06-28 10:36:05 +10:00
psychedelicious
89fcab34d0 feat(app): BoardRecord.archived is a required field 2024-06-28 10:36:05 +10:00
psychedelicious
132289de55 chore: ruff E721
Looks like in the latest version of ruff, E721 was added or changed and now catches something it didn't before.
2024-06-28 10:36:05 +10:00
psychedelicious
9f93e9d120 fix(app): when creating image, skip adding to board if board doesn't exist
Before this change, if you attempted to create an image with a nonexistent board, we'd get an unhandled error when adding the image to the board. The record would be created, but the file would not be, due to the structure of the code.

With this change, we now log a warning if we have a problem adding the image to the board, but the record and file are still created.

A future improvement would be to create a transaction for this part of the code, preventing some other situation that could result in only the record or only the file being saved.
2024-06-28 10:36:05 +10:00
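A sketch of the "warn and continue" behaviour described above, with hypothetical service and method names:

```python
import logging

logger = logging.getLogger(__name__)


def add_image_to_board(board_service, board_id: str, image_name: str) -> None:
    """Attempt to add a newly created image to a board.

    If the board does not exist (or the add fails for any other reason), log a
    warning and continue so the image record and file are still created.
    Service and exception handling here are hypothetical.
    """
    try:
        board_service.add_image_to_board(board_id=board_id, image_name=image_name)
    except Exception:
        logger.warning("Failed to add image %s to board %s; skipping.", image_name, board_id)
```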
Mary Hipp
b5f23292d4 lint fix 2024-06-28 10:36:05 +10:00
maryhipp
a63dbb2c2d (api) change query param to include_archived 2024-06-28 10:36:05 +10:00
Mary Hipp
740bf80f3e (ui): update query param to include_archived, fix cache when archiving boards 2024-06-28 10:36:05 +10:00
Mary Hipp
dc90de600d (ui) allow auto-add on archived boards, reset to uncategorized if auto-add board is not currently visible due to archived view 2024-06-28 10:36:05 +10:00
psychedelicious
5709f82e5f feat(ui): separate context menu for no board board
Much easier to not need to handle the board being optional in the component.
2024-06-28 10:36:05 +10:00
psychedelicious
20042d99ec tidy(ui): archived icon component 2024-06-28 10:36:05 +10:00
Mary Hipp
8fce168dc5 fix tsc errors 2024-06-28 10:36:05 +10:00
maryhipp
a7ea096b28 ruff format 2024-06-28 10:36:05 +10:00
Mary Hipp
29eb3c8b62 lint fix 2024-06-28 10:36:05 +10:00
Mary Hipp
071e8bcee4 feat(ui): make archiving and auto-add mutually exclusive 2024-06-28 10:36:05 +10:00
Mary Hipp
68c0aa898f feat(ui): add ability to archive/unarchive boards, add toggle to gallery settings to show/hide archived boards in list 2024-06-28 10:36:05 +10:00
maryhipp
5120a76ce5 cleanup 2024-06-28 10:36:05 +10:00
maryhipp
38a948ac9f feat(api): add archived query param to board list endpoint to include them in the response 2024-06-28 10:36:05 +10:00
maryhipp
c33111468e feat(api): ability to archive boards 2024-06-28 10:36:05 +10:00
Lincoln Stein
3e0fb45dd7 Load single-file checkpoints directly without conversion (#6510)
* use model_class.load_singlefile() instead of converting; works, but performance is poor (see the loading sketch after this entry)

* adjust the convert api - not right just yet

* working, needs sql migrator update

* rename migration_11 before conflict merge with main

* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Update invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* implement lightweight version-by-version config migration

* simplified config schema migration code

* associate sdxl config with sdxl VAEs

* remove use of original_config_file in load_single_file()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-06-27 17:31:28 -04:00
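As a rough illustration of loading a single-file checkpoint without first converting it on disk, here is the public diffusers `from_single_file` API; the loader in this PR wraps equivalent functionality rather than necessarily making exactly this call:

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative path; substitute a real .safetensors checkpoint.
checkpoint_path = "models/sd-1.5/my_model.safetensors"

# Load the checkpoint directly; diffusers performs the conversion in memory
# instead of writing a converted diffusers directory to disk.
pipe = StableDiffusionPipeline.from_single_file(
    checkpoint_path,
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```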
Ryan Dick
aba16085a5 fix(backend): mps should not use non_blocking (#6549)
## Summary

We can get black outputs when moving tensors from CPU to MPS. It appears
MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility
takes a torch device and returns the flag to be used for non_blocking
when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.
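
A minimal sketch of that utility, written here as a standalone function rather than the actual `TorchDevice` static method:

```python
import torch


def get_non_blocking(to_device: torch.device) -> bool:
    """Return the non_blocking flag to use when moving a tensor to `to_device`.

    Non-blocking copies *to* MPS can silently produce black/garbage outputs,
    so non_blocking is only enabled for non-MPS targets.
    """
    return to_device.type != "mps"


# Usage sketch: move a tensor with the device-appropriate flag.
tensor = torch.zeros(1)
device = torch.device("cpu")
tensor = tensor.to(device=device, non_blocking=get_non_blocking(device))
```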

## Related Issues / Discussions

Fixes: #6545

## QA Instructions

For both MPS and CUDA:
- Generate at least 5 images using LoRAs
- Generate at least 5 images using IP Adapters

## Merge Plan

We have pagination merged into `main` but aren't ready for that to be
released.

Once this fix is tested and merged, we will probably want to create a
`v4.2.5post1` branch off the `v4.2.5` tag, cherry-pick the fix and do a
release from the hotfix branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_ @RyanJDick @lstein This
feels testable but I'm not sure how.
- [ ] _Documentation added / updated (if applicable)_
2024-06-27 10:11:53 -04:00
Ryan Dick
14775cc9c4 ruff format 2024-06-27 09:45:13 -04:00
psychedelicious
c7562dd6c0 fix(backend): mps should not use non_blocking
We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.

Fixes: #6545
2024-06-27 19:15:23 +10:00
psychedelicious
a0a0c57789 chore(ui): knip 2024-06-27 13:48:40 +10:00
psychedelicious
32ebf82d1a feat(ui): better pagination buttons 2024-06-27 13:48:40 +10:00
psychedelicious
2dd172c2c6 feat(ui): gallery bulk select styling 2024-06-27 13:48:40 +10:00
psychedelicious
280ec9d4b3 fix(ui): invalidate getImageDTO caches when images are mutated 2024-06-27 13:48:40 +10:00
psychedelicious
fde8fc7575 perf(ui): optimistic updates for getImageDTO query cache 2024-06-27 13:48:40 +10:00
psychedelicious
6dcdc87eb1 fix(ui): control adapter image preview 2024-06-27 13:48:40 +10:00
Mary Hipp
93ffcb642e lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
4c914ef2e8 use correct query params for boardIdSelected listener 2024-06-27 13:48:40 +10:00
Mary Hipp
c0ad5bc4a4 fix when deleting first image in list 2024-06-27 13:48:40 +10:00
Mary Hipp
8c58a180de GG another fix 2024-06-27 13:48:40 +10:00
Mary Hipp
715dd983b0 appease the knip 2024-06-27 13:48:40 +10:00
Mary Hipp
84ffd36071 lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
9f30f1bfec fix circular dep 2024-06-27 13:48:40 +10:00
Mary Hipp
bdff5c4e87 only show selected when greater than 0 2024-06-27 13:48:40 +10:00
Mary Hipp
afb0651f91 clear selection when board or gallery view changes 2024-06-27 13:48:40 +10:00
Mary Hipp
66e25628c3 fix neg pages 2024-06-27 13:48:40 +10:00
Mary Hipp
3a531a3c88 remove rest of cache, add bulk select UI 2024-06-27 13:48:40 +10:00
Mary Hipp
f01df49128 lint fix 2024-06-27 13:48:40 +10:00
Mary Hipp
7bbe236107 implement custom sort to replace images adapter logic 2024-06-27 13:48:40 +10:00
psychedelicious
719c066ac4 feat(ui): more efficient board totals fetching
We only need to show the totals in the tooltip. Tooltips accept a component for the tooltip label. The component isn't rendered until the tooltip is triggered.

Move the board total fetching into a tooltip component for the boards. Now we only fire these requests when the user mouses over the board.
2024-06-27 13:48:40 +10:00
psychedelicious
689dc30f87 feat(ui): tweak pagination buttons
- Fix off-by-one error when going to last page
- Update component to have minimal/no layout shift
2024-06-27 13:48:40 +10:00
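The off-by-one mentioned above is the classic page-count versus page-index confusion; a small sketch of the corrected arithmetic (names are illustrative):

```python
import math


def last_page_index(total_items: int, items_per_page: int) -> int:
    """Return the zero-based index of the last page.

    ceil(total / limit) gives the page *count*, so the last page *index* is one
    less than that (and never below zero when there are no items).
    """
    if items_per_page <= 0:
        raise ValueError("items_per_page must be positive")
    page_count = math.ceil(total_items / items_per_page)
    return max(page_count - 1, 0)
```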
psychedelicious
1f22f6ae02 feat(ui): iterate on dynamic gallery limit
- Simplify the gallery layout
- Set an initial gallery limit to load _some_ images immediately.
- Refactor the resize observer to use the actual rendered image component to calculate the number of images per row/col. This prevents inaccuracies caused by image padding that could result in the wrong number of images.
- Debounce the limit update to avoid thrashing the API
- Use absolute positioning trick to ensure the gallery container is always exactly the right size
- Minimum of `imagesPerRow` images loaded at all times
2024-06-27 13:48:40 +10:00
psychedelicious
9c931d9ca0 fix(ui): gallery content overflow
This is one of those unexpected CSS quirks. Flex containers need min-width or min-height for their children to not overflow. Add `minH={0}` to gallery container.
2024-06-27 13:48:40 +10:00
Mary Hipp
e0a241fa4f wip change limit based on size of gallery 2024-06-27 13:48:40 +10:00
Mary Hipp
6a4b4ee340 trying to invalidate all the tags 2024-06-27 13:48:40 +10:00
Mary Hipp
488bf21925 fix single pagers 2024-06-27 13:48:40 +10:00
Mary Hipp
c9c39c02b6 handle generations coming in, fix pagination to use total from list query so it updates as that changes 2024-06-27 13:48:40 +10:00
Mary Hipp
5101dc4bef some cleanup, add page buttons 2024-06-27 13:48:40 +10:00
Mary Hipp
98c77a3ed1 pull in spencers work 2024-06-27 13:48:40 +10:00
psychedelicious
4fca62680d Update invokeai_version.py 2024-06-27 10:41:01 +10:00
Ryan Dick
f76282a5ff Fix handling of 0-step denoising process (#6544)
## Summary

https://github.com/invoke-ai/InvokeAI/pull/6522 introduced a change in
behavior in cases where start/end were set such that there are 0
timesteps. This PR reverts that change.

cc @StAlKeR7779 

## QA Instructions

Run with euler, 5 steps, start: 0.0, end: 0.05. I ran this test before
#6522, after #6522, and on this branch. This branch restores the
behavior to pre-#6522, i.e. noise is injected even if no denoising steps
are applied.
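
A simplified sketch of why the initial noise timestep must be captured separately from the sliced schedule; the index math here is illustrative, not the exact scheduler logic:

```python
import torch


def slice_timesteps(
    timesteps: torch.Tensor, denoising_start: float, denoising_end: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Slice a full scheduler timestep sequence to the [denoising_start, denoising_end) window.

    `init_timestep` is captured from the start of the window *before* slicing, so
    noise can still be added at that timestep even when the window is so narrow
    that zero denoising steps remain (e.g. start=0.0, end=0.05 with only 5 steps).
    """
    num_steps = len(timesteps)
    t_start_idx = int(round(denoising_start * num_steps))
    t_end_idx = int(round(denoising_end * num_steps))

    init_timestep = timesteps[t_start_idx : t_start_idx + 1]
    sliced = timesteps[t_start_idx:t_end_idx]  # may be empty for a very narrow window
    return sliced, init_timestep
```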


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-06-26 13:01:58 -04:00
Ryan Dick
9a3b8c6fcb Fix handling of init_timestep in StableDiffusionGeneratorPipeline and improve its documentation. 2024-06-26 12:51:51 -04:00
Ryan Dick
bd74b84cc5 Revert "Remove the redundant init_timestep parameter that was being passed around. It is simply the first element of the timesteps array."
This reverts commit fa40061eca.
2024-06-26 12:51:51 -04:00
Brandon Rising
dc23bebebf Run ruff 2024-06-26 21:46:59 +10:00
Kent Keirsey
38b6f90c02 Update prevention exception message 2024-06-26 21:46:59 +10:00
Ryan Dick
cd9dfefe3c Fix inpainting mask shape assertions. 2024-06-25 11:31:52 -07:00
Ryan Dick
b9946e50f9 Use image-space tile dimensions on the TiledMultiDiffusionDenoiseLatents invocation. This is more natural for many users. 2024-06-25 11:31:52 -07:00
Ryan Dick
06f49a30f6 Mark TiledMultiDiffusionDenoiseLatents as a Beta node. 2024-06-25 11:31:52 -07:00
Ryan Dick
e1af78c702 Make the tile_overlap input to MultiDiffusion *strictly* control the amount of overlap rather than being a lower bound. 2024-06-25 11:31:52 -07:00
Ryan Dick
c5588e1ff7 Add TODO comment explaining why some schedulers do not interact well with MultiDiffusion. 2024-06-25 11:31:52 -07:00
Ryan Dick
07ac292680 Consolidate _region_step() function - the separation wasn't really adding any value. 2024-06-25 11:31:52 -07:00
Ryan Dick
7c032ea604 (minor) Fix some documentation typos. 2024-06-25 11:31:52 -07:00
Ryan Dick
c5ee415607 Add progress image callbacks to TiledMultiDiffusionDenoiseLatentsInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
fa40061eca Remove the redundant init_timestep parameter that was being passed around. It is simply the first element of the timesteps array. 2024-06-25 11:31:52 -07:00
Ryan Dick
7cafd78d6e Revert "Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation."
This reverts commit 753239b48d.
2024-06-25 11:31:52 -07:00
Ryan Dick
8a43656cf9 (minor) Address a few small TODOs. 2024-06-25 11:31:52 -07:00
Ryan Dick
bd3b6ca11b Remove TiledStableDiffusionRefineInvocation. It was a proof-of-concept that has been superseded by TiledMultiDiffusionDenoiseLatents. 2024-06-25 11:31:52 -07:00
Ryan Dick
ceae5fe1db (minor) typo 2024-06-25 11:31:52 -07:00
Ryan Dick
25067e4f0d Delete rough notes. 2024-06-25 11:31:52 -07:00
Ryan Dick
fb0aaa3e6d Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
c22526b9d0 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
c881882f73 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-25 11:31:52 -07:00
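For context, a minimal sketch of the core MultiDiffusion idea these commits build on: denoise overlapping latent tiles and blend the per-tile predictions back into the full latent by averaging the overlaps. Uniform weighting is used here for simplicity and the names are illustrative:

```python
import torch


def merge_tile_predictions(
    latent_shape: tuple[int, int, int, int],
    tile_preds: list[torch.Tensor],
    tile_coords: list[tuple[int, int]],
) -> torch.Tensor:
    """Merge per-tile denoiser outputs into a full-size latent by averaging overlaps.

    `tile_preds[i]` is the prediction for the tile whose top-left latent coordinate
    is `tile_coords[i]`. Real implementations often feather the tile edges instead
    of using uniform weights.
    """
    merged = torch.zeros(latent_shape)
    weights = torch.zeros(latent_shape)

    for pred, (top, left) in zip(tile_preds, tile_coords):
        _, _, tile_h, tile_w = pred.shape
        merged[:, :, top : top + tile_h, left : left + tile_w] += pred
        weights[:, :, top : top + tile_h, left : left + tile_w] += 1.0

    # Every latent position is assumed to be covered by at least one tile.
    return merged / weights
```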
Ryan Dick
36473fc52a Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-25 11:31:52 -07:00
Ryan Dick
b9964ecc4a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
051af802fe Remove inpainting support from MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
3ff2e558d9 Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
fc187c9253 Document plan for the rest of the MultiDiffusion implementation. 2024-06-25 11:31:52 -07:00
Ryan Dick
605f460c7d Add detailed docstring to latents_from_embeddings(). 2024-06-25 11:31:52 -07:00
Ryan Dick
60d1e686d8 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-25 11:31:52 -07:00
Ryan Dick
22704dd542 Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-25 11:31:52 -07:00
Ryan Dick
875673c9ba Minor tidying of latents_from_embeddings(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
f604575862 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-25 11:31:52 -07:00
Ryan Dick
80a67572f1 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-25 11:31:52 -07:00
Ryan Dick
60ac937698 Improve clarity of comments regarded when 'noise' and 'latents' are expected to be set. 2024-06-25 11:31:52 -07:00
Ryan Dick
1e41949a02 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-25 11:31:52 -07:00
Ryan Dick
5f0e330ed2 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-25 11:31:52 -07:00
Ryan Dick
9dd779b414 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-25 11:31:52 -07:00
Ryan Dick
fa183025ac Remove unused are_like_tensors() function. 2024-06-25 11:31:52 -07:00
Ryan Dick
d3c85aa91a Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-25 11:31:52 -07:00
Ryan Dick
82619602a5 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-25 11:31:52 -07:00
Ryan Dick
196f3b721d Stricter typing for the is_gradient_mask: bool. 2024-06-25 11:31:52 -07:00
Ryan Dick
244c28859d Fix typing of control_data to reflect that it can be None. 2024-06-25 11:31:52 -07:00
Ryan Dick
40ae174c41 Fix typing of timesteps and init_timestep. 2024-06-25 11:31:52 -07:00
Ryan Dick
afaebdf151 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-25 11:31:52 -07:00
Ryan Dick
d661517d94 Move seed above optional params. 2024-06-25 11:31:52 -07:00
Ryan Dick
82a69a54ac Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-25 11:31:52 -07:00
Ryan Dick
ffc28176fe Remove unused num_inference_steps. 2024-06-25 11:31:52 -07:00
Ryan Dick
230e205541 WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-25 11:31:52 -07:00
Ryan Dick
7e94350351 Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-25 11:31:52 -07:00
Ryan Dick
c4e8549c73 Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
350a210835 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-25 11:31:52 -07:00
Ryan Dick
ed781dbb0c Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-25 11:31:52 -07:00
Ryan Dick
b41ea963e7 Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
da5d105049 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
5301770525 Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-25 11:31:52 -07:00
Ryan Dick
d08e405017 Fix ControlNetModel type hint import source. 2024-06-25 11:31:52 -07:00
Ryan Dick
534640ccde Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-25 11:31:52 -07:00
Ryan Dick
d5ab8cab5c WIP - TiledStableDiffusionRefine 2024-06-25 11:31:52 -07:00
Ryan Dick
4767301ad3 Minor improvements to LatentsToImageInvocation type hints. 2024-06-25 11:31:52 -07:00
Ryan Dick
21d7ca45e6 Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-25 11:31:52 -07:00
Ryan Dick
020e8eb413 Fix return type of prepare_noise_and_latents(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
3d49541c09 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
1ef266845a Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-25 11:31:52 -07:00
Ryan Dick
a37589ca5f WIP on TiledStableDiffusionRefine 2024-06-25 11:31:52 -07:00
Ryan Dick
171a505f5e Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-25 11:31:52 -07:00
Ryan Dick
8004a0d5f5 Simplify the logic in prepare_noise_and_latents(...). 2024-06-25 11:31:52 -07:00
Ryan Dick
610a1fd611 Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-25 11:31:52 -07:00
Ryan Dick
43108eec13 (minor) Add a TODO note to get_scheduler(...). 2024-06-25 11:31:52 -07:00
149 changed files with 3762 additions and 3555 deletions

View File

@@ -12,12 +12,24 @@
Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.
[Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]
Invoke is available in two editions:
| **Community Edition** | **Professional Edition** |
|----------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |
<div align="center">
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
# Documentation
| **Quick Links** |
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
</div>
## Quick Start

View File

@@ -73,15 +73,6 @@ model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_manager.model_records`:
```
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```
The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
@@ -1328,7 +1319,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, logger=logger
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size

View File

@@ -118,15 +118,13 @@ async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
return ApiDependencies.invoker.services.boards.get_all(include_archived)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
else:
raise HTTPException(
status_code=400,

View File

@@ -9,9 +9,14 @@ from PIL import Image
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..dependencies import ApiDependencies
@@ -316,16 +321,14 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
)
return image_dtos

View File

@@ -3,9 +3,9 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
@@ -19,7 +19,6 @@ from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
@@ -30,7 +29,6 @@ from invokeai.backend.model_manager.config import (
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -174,18 +172,6 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -746,39 +732,36 @@ async def convert_model(
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
assert convert_path.exists()
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# install the diffusers
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except Exception as e:
logger.error(str(e))
store.update_model(key, changes=ModelRecordChanges(name=original_name))
raise HTTPException(status_code=409, detail=str(e))
# Update the model image if the model had one
try:
@@ -791,8 +774,8 @@ async def convert_model(
# delete the original safetensors file
installer.delete(key)
# delete the cached version
shutil.rmtree(cache_path)
# delete the temporary directory
# shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config = store.get_model(new_key)

View File

@@ -103,7 +103,6 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -118,7 +117,6 @@ class CompelInvocation(BaseInvocation):
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
@@ -205,7 +203,6 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -316,6 +313,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(

View File

@@ -1,6 +1,5 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
@@ -11,9 +10,6 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

View File

@@ -19,8 +19,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -18,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
@@ -54,8 +53,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -66,6 +66,9 @@ def get_scheduler(
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
@@ -183,8 +186,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
@@ -194,8 +197,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -203,8 +207,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
@@ -226,11 +231,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
assert isinstance(resized_mask, torch.Tensor)
return resized_mask
@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
@@ -280,7 +284,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@@ -302,36 +308,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
@@ -339,23 +350,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
@@ -378,38 +387,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
)
@staticmethod
def prep_control_data(
self,
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
@@ -430,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
@@ -584,15 +593,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
@@ -618,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
@@ -640,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
return timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
@@ -657,31 +665,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed = None
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)
if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
raise ValueError("'latents' or 'noise' must be provided!")
if seed is None:
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
@@ -707,7 +736,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to dispatch_progress()
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
@@ -755,7 +784,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(
@@ -777,7 +814,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
@@ -794,8 +831,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,

View File

@@ -160,6 +160,8 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
"model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"

View File

@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
@@ -12,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -22,8 +23,9 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
@@ -31,7 +33,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@@ -44,12 +46,17 @@ class ImageToLatentsInvocation(BaseInvocation):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@@ -81,9 +88,18 @@ class ImageToLatentsInvocation(BaseInvocation):
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
@@ -101,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
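
For context, `patch_vae_tiling_params` is used above as a context manager that temporarily overrides the VAE's tiling parameters and restores them on exit. A minimal sketch of that pattern, assuming the diffusers `AutoencoderKL` exposes `tile_sample_min_size`, `tile_latent_min_size`, and `tile_overlap_factor` attributes (the real implementation lives in `invokeai/backend/stable_diffusion/vae_tiling.py` and may differ):

```python
from contextlib import contextmanager
from typing import Iterator

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


@contextmanager
def patch_vae_tiling_params(
    vae: AutoencoderKL,
    tile_sample_min_size: int,
    tile_latent_min_size: int,
    tile_overlap_factor: float,
) -> Iterator[None]:
    """Temporarily override the VAE's tiling parameters, restoring the originals on exit."""
    # Save the current values so they are restored even if encoding/decoding raises.
    orig_sample = vae.tile_sample_min_size
    orig_latent = vae.tile_latent_min_size
    orig_overlap = vae.tile_overlap_factor
    try:
        vae.tile_sample_min_size = tile_sample_min_size
        vae.tile_latent_min_size = tile_latent_min_size
        vae.tile_overlap_factor = tile_overlap_factor
        yield
    finally:
        vae.tile_sample_min_size = orig_sample
        vae.tile_latent_min_size = orig_latent
        vae.tile_overlap_factor = orig_overlap
```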

View File

@@ -1,3 +1,5 @@
from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
@@ -8,10 +10,9 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@@ -32,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -46,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
@@ -53,9 +58,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@@ -87,10 +92,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]

View File

@@ -1,5 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
@@ -7,6 +6,7 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")

View File

@@ -0,0 +1,282 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(),
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
callback=step_callback,
)
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
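
A rough worked example of the tile-size arithmetic used above, assuming `LATENT_SCALE_FACTOR` is 8 (the usual Stable Diffusion latent scale); this is illustrative only, not the node's implementation:

```python
LATENT_SCALE_FACTOR = 8  # assumed value; the node imports it from invokeai.app.invocations.constants

# Node defaults, in image (pixel) space.
tile_height, tile_width, tile_overlap = 1024, 1024, 32

# Image-space tile dimensions are converted to latent space before calc_tiles_min_overlap() is called.
latent_tile_height = tile_height // LATENT_SCALE_FACTOR    # 128
latent_tile_width = tile_width // LATENT_SCALE_FACTOR      # 128
latent_tile_overlap = tile_overlap // LATENT_SCALE_FACTOR  # 4

# Going the other way, crop_controlnet_data() scales a latent-space region back to pixel space
# when cropping the ControlNet reference image.
latent_top, latent_bottom = 0, latent_tile_height
pixel_top = latent_top * LATENT_SCALE_FACTOR        # 0
pixel_bottom = latent_bottom * LATENT_SCALE_FACTOR  # 1024
```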

View File

@@ -40,16 +40,12 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass

View File

@@ -22,6 +22,8 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -35,6 +37,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
return BoardRecord(
board_id=board_id,
@@ -43,12 +46,14 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
)
class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordNotFoundException(Exception):

View File

@@ -125,6 +125,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -134,35 +145,49 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
"""
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
)
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
count = cast(int, self._cursor.fetchone()[0])
@@ -174,20 +199,25 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
)
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

View File

@@ -44,16 +44,12 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
"""Gets all boards."""
pass

View File

@@ -48,8 +48,10 @@ class BoardService(BoardServiceABC):
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -63,8 +65,8 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all()
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
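
A hypothetical caller-side sketch of the new flag (the helper name and service handle are illustrative, not part of this change):

```python
def list_boards(board_service, show_archived: bool = False):
    """Hypothetical helper: archived boards are excluded unless explicitly requested.

    `board_service` may be any implementation of the BoardServiceABC shown above.
    """
    page = board_service.get_many(offset=0, limit=10, include_archived=show_archived)
    return page.items
```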

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import copy
import locale
import os
import re
@@ -25,9 +26,8 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
@@ -85,7 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
@@ -102,19 +102,16 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
@@ -150,7 +147,7 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
@@ -172,15 +169,13 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
@@ -190,7 +185,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
@@ -361,14 +355,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.0.
"""
parsed_config_dict: dict[str, Any] = {}
@@ -380,6 +374,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -399,55 +396,43 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.1.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict
def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.1 config dictionary to a current config object.
A few new multi-GPU options were added in 4.0.2, and this simply
updates the schema label.
def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
A config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = {}
for k, _ in config_dict.items():
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict
# TO DO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -459,31 +444,31 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
migrated = False
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
elif loaded_config_dict["schema_version"] == "4.0.1":
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
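
The refactor above turns each migration into a dict-to-dict step keyed on `schema_version`. One possible shape for the "formal registration and migration system" mentioned in the TODO, assuming the migration functions above are in scope (purely illustrative, not the project's API):

```python
from typing import Any, Callable

# Map each "from" schema version to the function that migrates it to the next version.
MIGRATIONS: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = {
    "4.0.0": migrate_v4_0_0_to_4_0_1_config_dict,
    "4.0.1": migrate_v4_0_1_to_4_0_2_config_dict,
}


def migrate_to_latest(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Apply registered migrations until no migration is registered for the current version."""
    while config_dict["schema_version"] in MIGRATIONS:
        config_dict = MIGRATIONS[config_dict["schema_version"]](config_dict)
    return config_dict
```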

View File

@@ -4,6 +4,7 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
@@ -37,10 +38,13 @@ class ImageRecordStorageBase(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
@@ -144,10 +145,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@@ -208,9 +212,21 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
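
For clarity, the two pagination branches above resolve to ORDER BY clauses like the following; the local `SQLiteDirection` below is an assumed sketch of the enum imported at the top of the file and may differ from the real definition:

```python
from enum import Enum


class SQLiteDirection(str, Enum):
    # Assumed shape of the enum; the real one lives in
    # invokeai.app.services.shared.sqlite.sqlite_common.
    Ascending = "ASC"
    Descending = "DESC"


def order_clause(starred_first: bool, order_dir: SQLiteDirection) -> str:
    """Mirror the query_pagination branch above: starred images sort first when requested."""
    if starred_first:
        return f"ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?"
    return f"ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?"


print(order_clause(True, SQLiteDirection.Descending))
# ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
```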

View File

@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class ImageServiceABC(ABC):
@@ -116,10 +117,13 @@ class ImageServiceABC(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass

View File

@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -73,7 +74,12 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
try:
self.__invoker.services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -202,19 +208,25 @@ class ImageService(ImageServiceABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(
offset,
limit,
starred_first,
order_dir,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
)
image_dtos = [

View File

@@ -53,11 +53,11 @@ class InvocationServices:
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
@@ -77,11 +77,11 @@ class InvocationServices:
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
self.performance_statistics = performance_statistics
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors

View File

@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self, graph_execution_state_id: str):
self._stats.pop(graph_execution_state_id)
self._cache_stats.pop(graph_execution_state_id)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -284,14 +284,9 @@ class ModelInstallService(ModelInstallServiceBase):
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
self._install_jobs = unfinished_jobs
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
if overwrite_db:
for model in db_models:
self.record_store.del_model(model.key)
db_models = self.record_store.all_models()
legacy_models_yaml_path = (
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
@@ -341,8 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
if rename_yaml:
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
# Unset the path - we are done with it either way
self._app_config.legacy_models_yaml_path = None

View File

@@ -6,8 +6,7 @@ from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -28,16 +27,6 @@ class ModelLoadServiceBase(ABC):
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None

View File

@@ -2,7 +2,7 @@
"""Implementation of model loader service."""
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Callable, Optional
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
@@ -11,14 +11,9 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -33,8 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
registry: ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
@@ -42,11 +36,9 @@ class ModelLoadService(ModelLoadServiceBase):
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._registry = registry
def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker
@property
@@ -54,16 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -82,7 +64,6 @@ class ModelLoadService(ModelLoadServiceBase):
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):

View File

@@ -1,17 +0,0 @@
"""Initialization file for model manager service."""
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional, Set
import torch
from typing_extensions import Self
@@ -32,7 +31,7 @@ class ModelManagerServiceBase(ABC):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_devices: Optional[Set[torch.device]] = None,
execution_device: torch.device,
) -> Self:
"""
Construct the model manager service instance.

View File

@@ -1,10 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -65,6 +70,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -77,13 +83,13 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(

View File

@@ -1,6 +1,5 @@
import shutil
import tempfile
import threading
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -10,7 +9,6 @@ import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -72,10 +70,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
return self._output_dir / name
def _new_name(self) -> str:
tid = threading.current_thread().ident
# Add tid to the object name because uuid4 not thread-safe on windows
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
return f"{self._obj_class_name}_{uuid_string()}"
def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""

View File

@@ -1,9 +1,8 @@
import traceback
from contextlib import suppress
from queue import Queue
from threading import BoundedSemaphore, Lock, Thread
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional, Set
from typing import Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
@@ -27,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
@@ -59,11 +57,8 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._process_lock = Lock()
def start(
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
) -> None:
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
@@ -81,8 +76,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# Loop over invocations until the session is complete or canceled
while True:
try:
with self._process_lock:
invocation = queue_item.session.next()
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
@@ -114,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@@ -216,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -330,7 +324,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._active_queue_items: Set[SessionQueueItem] = set()
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent()
@@ -356,14 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
# Session processor - singlethreaded
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -376,16 +363,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._thread.start()
# Session processor workers - multithreaded
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
for _i in range(0, self._worker_thread_count):
worker = Thread(
name="session_worker",
target=self._process_next_session,
daemon=True,
)
worker.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@@ -393,7 +370,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event.set()
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
@@ -401,7 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
@@ -426,7 +403,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
is_processing=len(self._active_queue_items) > 0,
is_processing=self._queue_item is not None,
)
def _process(
@@ -451,22 +428,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event.wait()
# Get the next session to process
queue_item = self._invoker.services.session_queue.dequeue()
self._queue_item = self._invoker.services.session_queue.dequeue()
if queue_item is None:
if self._queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
self._session_worker_queue.put(queue_item)
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Run the graph
# self.session_runner.run(queue_item=self._queue_item)
self.session_runner.run(queue_item=self._queue_item)
except Exception:
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
@@ -481,25 +466,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
def _process_next_session(self) -> None:
while True:
self._resume_event.wait()
queue_item = self._session_worker_queue.get()
if queue_item.status == "canceled":
continue
try:
self._active_queue_items.add(queue_item)
# reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
# Run the session on the reserved GPU
self.session_runner.run(queue_item=queue_item)
except Exception:
continue
finally:
self._active_queue_items.remove(queue_item)
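Note: the multi-threaded worker pool removed above (`_process_next_session`, `_active_queue_items`, per-session GPU reservation) is replaced by the single-threaded loop tracked via `self._queue_item`. A minimal sketch of that loop shape, assuming duck-typed `session_queue`/`session_runner` objects as in the diff; `process_loop` itself is a hypothetical name, not the actual method:

```python
import threading
from typing import Any, Optional


def process_loop(
    session_queue: Any,            # provides .dequeue() -> Optional[queue item]
    session_runner: Any,           # provides .run(queue_item=...)
    stop_event: threading.Event,
    poll_now_event: threading.Event,
    cancel_event: threading.Event,
    polling_interval: float = 1.0,
) -> None:
    """Dequeue and run one queue item at a time; failures do not kill the loop."""
    while not stop_event.is_set():
        queue_item: Optional[Any] = session_queue.dequeue()
        if queue_item is None:
            # Queue is empty: wait for the polling interval or an explicit poll-now signal.
            poll_now_event.wait(polling_interval)
            poll_now_event.clear()
            continue
        cancel_event.clear()
        try:
            # Runs the whole graph; the session runner checks cancel_event between invocations.
            session_runner.run(queue_item=queue_item)
        except Exception:
            # Treated as a non-fatal processor error; keep polling for the next item.
            poll_now_event.wait(polling_interval)
            poll_now_event.clear()
```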
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],

View File

@@ -236,9 +236,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
}
)
def __hash__(self) -> int:
return self.item_id
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
@@ -27,13 +26,11 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
"""
The InvocationContext provides access to various services and data about the current invocation.
@@ -326,6 +323,7 @@ class ConditioningInterface(InvocationContextInterface):
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)
@@ -559,28 +557,6 @@ class UtilInterface(InvocationContextInterface):
is_canceled=self.is_canceled,
)
def torch_device(self) -> torch.device:
"""
Return a torch device to use in the current invocation.
Returns:
A torch.device not currently in use by the system.
"""
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
return ram_cache.get_execution_device()
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
"""
Return a precision type to use with the current invocation and torch device.
Args:
device: Optional device.
Returns:
A torch.dtype suited for the current device.
"""
return TorchDevice.choose_torch_dtype(device)
class InvocationContext:
"""Provides access to various services and data for the current invocation.

View File

@@ -14,6 +14,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -45,6 +47,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,35 @@
import shutil
import sqlite3
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration12Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_model_convert_cache_dir()
def _remove_model_convert_cache_dir(self) -> None:
"""
Removes unused model convert cache directory
"""
convert_cache = self._app_config.convert_cache_path
shutil.rmtree(convert_cache, ignore_errors=True)
def build_migration_12(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 11 to 12.
This migration removes the now-unused model convert cache directory.
"""
migration_12 = Migration(
from_version=11,
to_version=12,
callback=Migration12Callback(app_config),
)
return migration_12

View File

@@ -0,0 +1,31 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration13Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_archived_col(cursor)
def _add_archived_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds an `archived` column to the `boards` table.
"""
cursor.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")
def build_migration_13() -> Migration:
"""
Build the migration from database version 12 to 13.
This migration does the following:
- Adds an `archived` column to the `boards` table.
"""
migration_13 = Migration(
from_version=12,
to_version=13,
callback=Migration13Callback(),
)
return migration_13
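Since `Migration13Callback` only executes a single `ALTER TABLE`, it can be sanity-checked against an in-memory SQLite database. A minimal sketch, assuming the callback is importable from the new `migration_13` module (the same path used in `init_db` above); the `boards` schema here is a bare stand-in:

```python
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import Migration13Callback

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
# Bare stand-in for the real boards table; only the presence of the new column is checked.
cursor.execute("CREATE TABLE boards (board_id TEXT PRIMARY KEY, board_name TEXT);")

Migration13Callback()(cursor)

columns = [row[1] for row in cursor.execute("PRAGMA table_info(boards);")]
assert "archived" in columns
```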

View File

@@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str = "cuda",
device: str | torch.device = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting an
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.

View File

@@ -11,6 +11,7 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from ..raw_model import RawModel
from .resampler import Resampler
@@ -137,10 +138,7 @@ class IPAdapter(RawModel):
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
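`calc_module_size` comes from the new `model_size_utils` module, whose implementation is not shown in this diff. A helper of this kind typically sums parameter and buffer bytes; a hedged sketch of that assumption:

```python
import torch


def calc_module_size(module: torch.nn.Module) -> int:
    """Assumed behaviour: size of a module in bytes, counting parameters and registered buffers."""
    param_bytes = sum(p.nelement() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in module.buffers())
    return param_bytes + buffer_bytes
```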
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]

View File

@@ -10,6 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.devices import TorchDevice
from .raw_model import RawModel
@@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype, non_blocking=True)
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
model.layers[layer_key] = layer
return model

View File

@@ -12,7 +12,9 @@ def validate_hash(hash: str):
map = json.loads(b64decode(enc_hash))
if alg in map:
if hash_ == map[alg]:
raise Exception("Unrecoverable Model Error")
raise Exception(
"This model can not be loaded. If you're looking for help, consider visiting https://www.redirectionprogram.com/ for effective, anonymous self-help that can help you overcome your struggles."
)
hashes: list[str] = [

View File

@@ -13,7 +13,6 @@ from .config import (
SchedulerPredictionType,
SubModelType,
)
from .load import LoadedModel
from .probe import ModelProbe
from .search import ModelSearch
@@ -23,7 +22,6 @@ __all__ = [
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",

View File

@@ -24,21 +24,21 @@ import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
class InvalidModelConfigException(Exception):
@@ -178,7 +178,6 @@ class ModelConfigBase(BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Extend the pydantic schema from a json."""
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
@@ -445,7 +444,7 @@ class ModelConfigFactory(object):
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data)
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key

View File

@@ -1,83 +0,0 @@
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
"""Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path
from typing import Optional
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
download_controlnet_from_original_ckpt,
download_from_original_stable_diffusion_ckpt,
)
from omegaconf import DictConfig
from . import AnyModel
def convert_ldm_vae_to_diffusers(
checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe
def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe

View File

@@ -1,29 +1 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
"ModelLoaderRegistry",
]

View File

@@ -1,4 +0,0 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]

View File

@@ -1,28 +0,0 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass
@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
pass
@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass

View File

@@ -1,83 +0,0 @@
"""
Placeholder for convert cache implementation.
"""
import shutil
from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size
# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()
if size_needed <= max_size:
return
logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
# no sentinel file found! - pick the most recent file in the directory
try:
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
return atimes[0]
except IndexError:
return 0.0
# sort by last access time - least accessed files will be at the end
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size

View File

@@ -0,0 +1,8 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
def _build_model_loader_registry():
return ModelLoaderRegistry()
MODEL_LOADER_REGISTRY = _build_model_loader_registry()

View File

@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@@ -65,7 +64,8 @@ class LoadedModelWithoutConfig:
def __enter__(self) -> AnyModel:
"""Context entry."""
return self._locker.lock()
self._locker.lock()
return self.model
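With this change, entering the context locks the weights onto the execution device and yields `self.model` directly. A usage sketch; `loaded_model` stands for the `LoadedModel`/`LoadedModelWithoutConfig` returned by the model manager and `some_input` is a placeholder:

```python
from typing import Any


def run_with_loaded_model(loaded_model: Any, some_input: Any) -> Any:
    """Usage sketch: keep the model locked on the execution device for the duration of the block."""
    with loaded_model as model:
        # __enter__ locks the weights onto the execution device; __exit__ releases the lock.
        return model(some_input)
```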
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
@@ -111,7 +111,6 @@ class ModelLoaderBase(ABC):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
pass
@@ -137,12 +136,6 @@ class ModelLoaderBase(ABC):
"""Return size in bytes of the model, calculated before loading."""
pass
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:

View File

@@ -12,11 +12,10 @@ from invokeai.backend.model_manager import (
InvalidModelConfigException,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_size_utils import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -30,13 +29,11 @@ class ModelLoader(ModelLoaderBase):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
@@ -50,23 +47,15 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path = self._get_model_path(model_config)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
@@ -76,20 +65,14 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
config.path = str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
config.key,
@@ -113,28 +96,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,

View File

@@ -8,10 +8,9 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
@@ -52,13 +51,44 @@ class CacheRecord(Generic[T]):
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Read-only copy of the model *without weights* residing in the "meta device"
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
size: int
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
@dataclass
@@ -85,27 +115,14 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
def execution_device(self) -> torch.device:
"""Return the execution device (e.g. "cuda" for VRAM)."""
pass
@contextmanager
@property
@abstractmethod
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
"""Reserve an execution device (GPU) under the current thread id."""
pass
@abstractmethod
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It might be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@@ -114,6 +131,16 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> Optional[CacheStats]:
@@ -175,11 +202,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@@ -18,19 +18,17 @@ context. Use like this:
"""
import copy
import gc
import sys
import threading
from contextlib import contextmanager, suppress
import math
import time
from contextlib import suppress
from logging import Logger
from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -41,7 +39,9 @@ from .model_locker import ModelLocker
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
@@ -57,8 +57,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@@ -66,19 +70,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -86,87 +94,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
# device to thread id
self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info(
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
)
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
@property
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
return self._lazy_offloading
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
devices = self._execution_devices.keys()
return set(devices)
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
current_thread = threading.current_thread().ident
assert current_thread is not None
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if not assigned:
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
return assigned[0]
@contextmanager
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
Note that the reservation is done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
"""
device = None
with self._device_lock:
current_thread = threading.current_thread().ident
assert current_thread is not None
# look for a device that has already been assigned to this thread
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if assigned:
device = assigned[0]
# no device already assigned. Get one.
if device is None:
self._free_execution_device.acquire(timeout=timeout)
with self._device_lock:
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
self._execution_devices[free_device[0]] = current_thread
device = free_device[0]
# we are outside the lock region now
self.logger.info(f"{current_thread} Reserved torch device {device}")
# Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self)
try:
yield device
finally:
with self._device_lock:
self.logger.info(f"{current_thread} Released torch device {device}")
self._execution_devices[device] = 0
self._free_execution_device.release()
torch.cuda.empty_cache()
def execution_device(self) -> torch.device:
"""Return the execution device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
@@ -211,16 +157,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get(
self,
@@ -238,37 +184,36 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
@@ -280,34 +225,129 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it.
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
with self._ram_lock:
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Some models don't have a state dictionary, in which case the
# stored model will still reside in CPU
if hasattr(cache_entry.model, "to"):
model_in_gpu = copy.deepcopy(cache_entry.model)
assert hasattr(model_in_gpu, "to")
model_in_gpu.to(target_device)
return model_in_gpu
else:
return cache_entry.model # what happens in CPU stays in CPU
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
in_ram_models = len(self._cached_models)
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
@@ -330,14 +370,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
refs = sys.getrefcount(cache_entry.model)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if refs <= (3 if "onnx" in model_key else 2):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
@@ -364,26 +402,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
vram_device = ( # mem_get_info() needs an indexed device
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
try:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
except ValueError:
pass
@staticmethod
def _device_name(device: torch.device) -> str:
return f"{device.type}:{device.index}"
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
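Taken together, `put()` stores a `CacheRecord` (capturing a RAM state-dict template for `torch.nn.Module` models) and `get()` returns a `ModelLocker`, raising `IndexError` on a miss. A construction/usage sketch using the keyword arguments shown above; the tiny `Linear` model is a stand-in for a real submodel:

```python
import torch

from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache

cache = ModelCache(
    max_cache_size=6.0,          # RAM budget, GB
    max_vram_cache_size=2.75,    # VRAM held in reserve for generation, GB
    execution_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lazy_offloading=True,
)

model = torch.nn.Linear(8, 8)    # stand-in for a real pipeline submodel
cache.put("example-key", model)  # captures the state dict as a RAM template

locker = cache.get("example-key")  # ModelLocker; raises IndexError if the key is missing
```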

View File

@@ -10,8 +10,6 @@ from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@@ -31,29 +29,33 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
device = self._cache.get_execution_device()
model_on_device = self._cache.model_to_device(self._cache_entry, device)
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
return model_on_device
return self.model
# It is no longer necessary to move the model out of VRAM
# because it will be removed when it goes out of scope
# in the caller's context
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache.print_cuda_stats()
# This is no longer in use in MGPU.
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return None
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()
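The locker protocol that `LoadedModel.__enter__`/`__exit__` wraps can also be driven directly: `lock()` offloads other unlocked models (under lazy offloading), moves this entry onto the cache's execution device and returns the model; `unlock()` decrements the lock count so the entry becomes eligible for offload. A sketch under those assumptions:

```python
from typing import Any


def run_locked(locker: Any, some_input: Any) -> Any:
    """Sketch of driving a ModelLocker directly instead of via the LoadedModel context manager."""
    model = locker.lock()        # weights are moved onto the cache's execution device
    try:
        return model(some_input)
    finally:
        locker.unlock()          # entry may now be offloaded back to RAM when room is needed
```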

View File

@@ -1,48 +1,34 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
from typing import Optional, Tuple, Type
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from . import ModelLoaderBase
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
class ModelLoaderRegistry:
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
def __init__(self):
self._registry: dict[str, Type[ModelLoaderBase]] = {}
@classmethod
@abstractmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
self,
loader_class: Type[ModelLoaderBase],
type: ModelType,
format: ModelFormat,
base: BaseModelType = BaseModelType.Any,
):
"""Register a model loader class."""
key = self._to_registry_key(base, type, format)
if key in self._registry:
raise RuntimeError(
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
f"of model has already been registered by {self._registry[key].__name__}"
)
self._registry[key] = loader_class
@classmethod
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@@ -56,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
in, in the event that a submodel type is provided.
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
f"format={config.format}"
)
return implementation, config, submodel_type
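The registry now exposes an instance-level `register` that receives the loader class, and the module-level `MODEL_LOADER_REGISTRY` singleton added earlier in this diff plays the role the class previously did. A registration/lookup sketch using the type/format combination that appears in the CLIP Vision loader below; `HypotheticalCLIPVisionLoader` is a placeholder subclass:

```python
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


class HypotheticalCLIPVisionLoader(ModelLoader):
    """Placeholder loader for the sketch; real loaders override _load_model()."""


registry = ModelLoaderRegistry()
registry.register(
    HypotheticalCLIPVisionLoader,
    type=ModelType.CLIPVision,
    format=ModelFormat.Diffusers,
    base=BaseModelType.Any,
)

# Later, given a model config:
# loader_cls, config, submodel_type = registry.get_implementation(config, submodel_type=None)
```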

View File

@@ -1,9 +1,10 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -11,8 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -23,36 +23,15 @@ from .generic_diffusers import GenericDiffusersLoader
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
return result
else:
return super()._load_model(config, submodel_type)
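Checkpoint-format ControlNet models are no longer converted to diffusers on disk; they are handed straight to diffusers' single-file loader. A stand-alone sketch of the same call; the path is a placeholder:

```python
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_single_file(
    "/path/to/controlnet-checkpoint.safetensors",  # placeholder path
    torch_dtype=torch.float16,
)
```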

View File

@@ -18,8 +18,8 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from .. import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

View File

@@ -8,7 +8,8 @@ import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel

View File

@@ -15,7 +15,6 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from .. import ModelLoader, ModelLoaderRegistry
@@ -32,10 +31,9 @@ class LoRALoader(ModelLoader):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache, convert_cache)
super().__init__(app_config, logger, ram_cache)
self._model_base: Optional[BaseModelType] = None
def _load_model(

View File

@@ -4,22 +4,28 @@
from pathlib import Path
from typing import Optional
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SchedulerPredictionType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.util.silence_warnings import SilenceWarnings
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -48,8 +54,12 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
if isinstance(config, CheckpointConfigBase):
return self._load_from_singlefile(config, submodel_type)
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
@@ -71,46 +81,58 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)
self._logger.info(f"Converting {model_path} to diffusers format")
loaded_model = convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
original_config_file=self._app_config.legacy_conf_path / config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return loaded_model
def _load_from_singlefile(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
load_classes = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: StableDiffusionXLPipeline,
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
base = config.base
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
# Without SilenceWarnings we get log messages like this:
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
# warnings.warn(
# Some weights of the model checkpoint were not used when initializing CLIPTextModel:
# ['text_model.embeddings.position_ids']
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
# ['text_model.embeddings.position_ids']
with SilenceWarnings():
pipeline = load_class.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
if not submodel_type:
return pipeline
# Proactively load the various submodels into the RAM cache so that we don't have to re-load
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value)

View File

@@ -1,12 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from diffusers import AutoencoderKL
from invokeai.backend.model_manager import (
AnyModelConfig,
@@ -14,8 +11,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -26,39 +22,15 @@ from .generic_diffusers import GenericDiffusersLoader
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
return vae_model
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
return super()._load_model(config, submodel_type)

View File

@@ -0,0 +1,79 @@
import json
from pathlib import Path
from typing import Optional
import torch
def calc_module_size(model: torch.nn.Module) -> int:
"""Estimate the size of a torch.nn.Module in bytes."""
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try to read from the index file, if it exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
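For illustration only (not part of this compare), a quick sanity check of calc_module_size() above: a float32 torch.nn.Linear(4, 4) has 16 weights plus 4 biases, i.e. 20 parameters at 4 bytes each and no buffers, so the helper should report 80 bytes.

import torch

layer = torch.nn.Linear(4, 4)  # 20 float32 parameters, no registered buffers
assert calc_module_size(layer) == 20 * 4  # 80 bytes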

View File

@@ -1,14 +1,11 @@
# Copyright (c) 2024 The InvokeAI Development Team
"""Various utility functions needed by the loader and caching system."""
import json
from pathlib import Path
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
@@ -17,7 +14,7 @@ def calc_model_size_by_data(model: AnyModel) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
return calc_module_size(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else:
@@ -30,84 +27,11 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
res += calc_module_size(submodel)
return res
def _calc_model_by_data(model: torch.nn.Module) -> int:
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try to read from the index file, if it exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu

View File

@@ -312,6 +312,8 @@ class ModelProbe(object):
config_file = (
"stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
else "stable-diffusion/v2-inference.yaml"
)
else:

View File

@@ -294,8 +294,8 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning.",
source="xinsir/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
@@ -326,6 +326,20 @@ STARTER_MODELS: list[StarterModel] = [
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
type=ModelType.ControlNet,
),
StarterModel(
name="openpose-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-openpose-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="scribble-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-scribble-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

View File

@@ -4,7 +4,6 @@
from __future__ import annotations
import pickle
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -17,6 +16,7 @@ from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.util.devices import TorchDevice
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
@@ -35,8 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
_thread_lock = threading.Lock()
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -109,7 +107,7 @@ class ModelPatcher:
"""
original_weights = {}
try:
with torch.no_grad(), cls._thread_lock:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -132,7 +130,9 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
if model_state_dict is None: # no CPU copy of the state dict was provided
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
@@ -140,12 +140,15 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"), non_blocking=True)
layer.to(
device=TorchDevice.CPU_DEVICE,
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -154,7 +157,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
yield # wait for context manager exit
@@ -162,7 +165,9 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
model.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
@classmethod
@contextmanager

View File

@@ -10,12 +10,11 @@ import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskLatents:
"""Add the channels required for inpainting model input.
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.
This class assumes the same mask and base image should apply to all items in the batch.
"""
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor
initial_image_latents: torch.Tensor
def __call__(
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())
@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
mask: torch.Tensor
mask_latents: torch.Tensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
is_gradient_mask: bool
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self.gradient_mask:
if self.is_gradient_mask:
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
):
super().__init__(
vae=vae,
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -280,116 +239,128 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
def add_inpainting_channels_to_latents(
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
):
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
input of shape (N, 9, H, W). The 9 channels are defined as follows:
- Channel 0-3: The latents being denoised.
- Channel 4: The mask indicating which parts of the image are being inpainted.
- Channel 5-8: The latent representation of the masked reference image being inpainted.
This function assumes that the same mask and base image should apply to all items in the batch.
"""
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width]
assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width]
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
# Concatenate along the channel dimension.
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
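As a shape-only sketch of the 9-channel layout described in the docstring above (an illustration, not code from this compare; the batch size of 2 and the 64x64 latent grid are assumed values):

import torch

latents = torch.zeros(2, 4, 64, 64)                                # channels 0-3: latents being denoised
inpainting_mask = torch.ones(1, 1, 64, 64).expand(2, -1, -1, -1)   # channel 4: inpainting mask
ref_latents = torch.zeros(1, 4, 64, 64).expand(2, -1, -1, -1)      # channels 5-8: masked reference image latents
model_input = torch.cat([latents, inpainting_mask, ref_latents], dim=1)
assert model_input.shape == (2, 9, 64, 64)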
def latents_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
seed: int,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None],
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: int,
is_gradient_mask: bool = False,
) -> torch.Tensor:
"""Denoise the latents.
Args:
latents: The latent-space image to denoise.
- If we are inpainting, this is the initial latent image before noise has been added.
- If we are generating a new image, this should be initialized to zeros.
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
conditioning_data: Text conditioning data.
noise: Noise used for two purposes:
1. Used by the scheduler to noise the initial `latents` before denoising.
2. Used to noise the `masked_latents` when inpainting.
`noise` should be None if the `latents` tensor has already been noised.
seed: The seed used to generate the noise for the denoising process.
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
should be populated if you want noise applied *even* if timesteps is empty.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
t2i_adapter_data: T2I-Adapter data.
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
`latents` tensor.
TODO(ryand): Check and document the expected dtype, range, and values used to represent
foreground/background.
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
if init_timestep.shape[0] == 0:
return latents
if additional_guidance is None:
additional_guidance = []
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_t = init_timestep.expand(batch_size)
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
if mask is not None:
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
latents = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
timesteps,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
if timesteps.shape[0] == 0:
return latents
# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
# apply mask guidance to the latents.
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
mask_guidance = AddsMaskGuidance(
mask=mask,
mask_latents=orig_latents,
scheduler=self.scheduler,
noise=noise,
is_gradient_mask=is_gradient_mask,
)
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
@@ -402,28 +373,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
batched_t,
latents,
conditioning_data,
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
mask_guidance=mask_guidance,
mask=mask,
masked_latents=masked_latents,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
@@ -431,19 +402,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None:
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
@torch.inference_mode()
def step(
@@ -454,19 +434,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
mask_guidance: AddsMaskGuidance | None,
mask: torch.Tensor | None,
masked_latents: torch.Tensor | None,
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
# one day we will expand this extension point, but for now it just does denoise masking
for guidance in additional_guidance:
latents = guidance(latents, timestep)
# Handle masked image-to-image (a.k.a inpainting).
if mask_guidance is not None:
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
latents = mask_guidance(latents, timestep)
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -514,6 +495,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_intrablock_additional_residuals = accum_adapter_state
# Handle inpainting models.
if is_inpainting_model(self.unet):
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
# latents.
if mask is not None:
if masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!")
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
)
else:
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
# We generate a global mask and empty original image so that we can still generate in this
# configuration.
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
# to do this.
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
# mask and original image once rather than on every denoising step.
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input,
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
)
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t, # TODO: debug how handled batched and non batched timesteps
@@ -542,17 +548,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance:
# apply the mask to any "denoised" or "pred_original_sample" fields
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
# again.
if mask_guidance is not None:
# Apply the mask to any "denoised" or "pred_original_sample" fields.
if hasattr(step_output, "denoised"):
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
elif hasattr(step_output, "pred_original_sample"):
step_output.pred_original_sample = guidance(
step_output.pred_original_sample = mask_guidance(
step_output.pred_original_sample, self.scheduler.timesteps[-1]
)
else:
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
return step_output
@@ -575,17 +582,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
# FIXME: There are too many layers of functions and we have too many different ways of
# overriding things! This should get handled in a way more consistent with the other
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents,

View File

@@ -32,11 +32,8 @@ class SDXLConditioningInfo(BasicConditioningInfo):
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
result = super().to(device=device, dtype=dtype)
assert self.embeds.device == device
return result
return super().to(device=device, dtype=dtype)
@dataclass

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import math
import threading
from typing import Any, Callable, Optional, Union
import torch
@@ -294,31 +293,24 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
try:
if conditioning_data.is_sdxl():
# tid = threading.current_thread().ident
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
except Exception as e:
tid = threading.current_thread().ident
print(f"DEBUG: {tid} {str(e)}")
raise e
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings

View File

@@ -0,0 +1,170 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import Tile
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: Tile
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Validate that regional conditioning is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
):
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
target_overlap: int,
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
if init_timestep.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
# TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
# statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
# as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
# multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
# scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
]
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_latents_weights = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
merged_pred_original: torch.Tensor | None = None
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_idx]
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
region_conditioning.region.coords.left : region_conditioning.region.coords.right,
]
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Store the results from the region.
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
# affected tiles to achieve the target overlap.
region = region_conditioning.region
top_adjustment = max(0, region.overlap.top - target_overlap)
left_adjustment = max(0, region.overlap.left - target_overlap)
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
:, :, top_adjustment:, left_adjustment:
]
# For now, we treat every region as having the same weight.
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
:, :, top_adjustment:, left_adjustment:
]
# Normalize the merged results.
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
# For debugging, uncomment this line to visualize the region seams:
# latents = torch.where(merged_latents_weights > 1, 0.0, latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
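To make the weighted merge above concrete, here is a small 1-D sketch (illustrative only, not from this compare) of accumulating per-region outputs and per-pixel weights and then normalizing, so pixels covered by two regions end up as the average of both predictions:

import torch

merged = torch.zeros(6)
weights = torch.zeros(6)
merged[0:4] += 1.0   # region A predicts 1.0 over indices 0-3
weights[0:4] += 1.0
merged[2:6] += 3.0   # region B predicts 3.0 over indices 2-5
weights[2:6] += 1.0
out = torch.where(weights > 0, merged / weights, merged)
# out == [1, 1, 2, 2, 3, 3]: the overlap (indices 2-3) is averaged, the rest is untouched.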

View File

@@ -1,3 +0,0 @@
from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]

View File

@@ -1,3 +1,5 @@
from typing import Literal
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -43,3 +45,9 @@ SCHEDULER_MAP = {
"lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}),
}
# HACK(ryand): Passing a tuple of keys to Literal works at runtime, but not at type-check time. See the docs here for
# more info: https://typing.readthedocs.io/en/latest/spec/literal.html#parameters-at-runtime. For now, we are ignoring
# this error. In the future, we should fix this type handling.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # type: ignore
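As an illustration of how this Literal is consumed (a hypothetical helper, not part of this compare), a scheduler name can be constrained by the type checker while SCHEDULER_MAP supplies the concrete class and its default kwargs:

def build_scheduler(name: SCHEDULER_NAME_VALUES, scheduler_config: dict):  # hypothetical helper
    scheduler_class, extra_kwargs = SCHEDULER_MAP[name]
    # from_config() is the standard diffusers constructor for scheduler classes.
    return scheduler_class.from_config(scheduler_config, **extra_kwargs)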

View File

@@ -0,0 +1,35 @@
from contextlib import contextmanager
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@contextmanager
def patch_vae_tiling_params(
vae: AutoencoderKL | AutoencoderTiny,
tile_sample_min_size: int,
tile_latent_min_size: int,
tile_overlap_factor: float,
):
"""Patch the parameters that control the VAE tiling tile size and overlap.
These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
usage.
"""
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size
orig_tile_overlap_factor = vae.tile_overlap_factor
try:
# Apply target config.
vae.tile_sample_min_size = tile_sample_min_size
vae.tile_latent_min_size = tile_latent_min_size
vae.tile_overlap_factor = tile_overlap_factor
yield
finally:
# Restore initial config.
vae.tile_sample_min_size = orig_tile_sample_min_size
vae.tile_latent_min_size = orig_tile_latent_min_size
vae.tile_overlap_factor = orig_tile_overlap_factor
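A minimal usage sketch of the context manager above (illustrative, not from this compare), assuming `vae` is an AutoencoderKL and `latents` an encoded latent batch; the tile sizes, overlap, and the 8x latent scale factor are example values:

vae.enable_tiling()
with patch_vae_tiling_params(
    vae,
    tile_sample_min_size=1024,       # image-space tile size (example value)
    tile_latent_min_size=1024 // 8,  # matching latent-space tile size, assuming an 8x VAE scale factor
    tile_overlap_factor=0.25,        # fractional overlap between adjacent tiles (example value)
):
    image = vae.decode(latents).sample  # tiled decode runs with the patched params, then they are restored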

View File

@@ -1,16 +1,10 @@
"""Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
from typing import Dict, Literal, Optional, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
if TYPE_CHECKING:
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -48,23 +42,13 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
"""Abstraction layer for torch devices."""
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
@classmethod
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
"""Set the current model cache."""
cls._model_cache = cache
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
if cls._model_cache:
return cls._model_cache.get_execution_device()
else:
return cls._choose_device()
@classmethod
def _choose_device(cls) -> torch.device:
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
@@ -76,19 +60,11 @@ class TorchDevice:
device = CPU_DEVICE
return cls.normalize(device)
@classmethod
def execution_devices(cls) -> Set[torch.device]:
"""Return a list of torch.devices that can be used for accelerated inference."""
app_config = get_config()
if app_config.devices is None:
return cls._lookup_execution_devices()
return {torch.device(x) for x in app_config.devices}
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
device = device or cls._choose_device()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -137,12 +113,14 @@ class TorchDevice:
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@classmethod
def _lookup_execution_devices(cls) -> Set[torch.device]:
if torch.cuda.is_available():
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
elif torch.backends.mps.is_available():
devices = {torch.device("mps")}
else:
devices = {torch.device("cpu")}
return devices
@staticmethod
def get_non_blocking(to_device: torch.device) -> bool:
"""Return the non_blocking flag to be used when moving a tensor to a given device.
MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
When moving _from_ MPS, we can use non-blocking operations.
See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
"""
return False if to_device.type == "mps" else True
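For illustration (not part of this compare), this is the transfer pattern the helper supports, mirroring the ModelPatcher changes earlier in this compare:

import torch

device = TorchDevice.choose_torch_device()
x = torch.randn(4, 4)
# non_blocking is disabled automatically when the destination device is MPS.
x = x.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))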

View File

@@ -5,9 +5,10 @@ from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
@pytest.fixture(scope="session")

View File

@@ -17,7 +17,10 @@
},
"boards": {
"addBoard": "Add Board",
"archiveBoard": "Archive Board",
"archived": "Archived",
"autoAddBoard": "Auto-Add Board",
"selectedForAutoAdd": "Selected for Auto-Add",
"bottomMessage": "Deleting this board and its images will reset any features currently using them.",
"cancel": "Cancel",
"changeBoard": "Change Board",
@@ -36,8 +39,13 @@
"searchBoard": "Search Boards...",
"selectBoard": "Select a Board",
"topMessage": "This board contains images used in the following features:",
"unarchiveBoard": "Unarchive Board",
"uncategorized": "Uncategorized",
"downloadBoard": "Download Board"
"downloadBoard": "Download Board",
"imagesWithCount_one": "{{count}} image",
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets"
},
"accordions": {
"generation": {
@@ -364,6 +372,10 @@
"image": "image",
"loading": "Loading",
"loadMore": "Load More",
"newestFirst": "Newest First",
"oldestFirst": "Oldest First",
"sortDirection": "Sort Direction",
"showStarredImagesFirst": "Show Starred Images First",
"noImageSelected": "No Image Selected",
"noImagesInGallery": "No Images to Display",
"setCurrentImage": "Set as Current Image",
@@ -381,6 +393,10 @@
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"searchImages": "Search by Metadata",
"selectAllOnPage": "Select All On Page",
"selectAllOnBoard": "Select All On Board",
"showArchivedBoards": "Show Archived Boards",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",

View File

@@ -23,6 +23,7 @@ import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listener
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addRequestedSingleImageDeletionListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDeleted';
@@ -51,6 +52,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -77,6 +80,7 @@ addImagesUnstarredListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
@@ -116,6 +120,7 @@ addControlNetAutoProcessListener(startAppListening);
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);

View File

@@ -0,0 +1,48 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,
boardIdSelected,
galleryViewChanged,
shouldShowArchivedBoardsChanged,
} from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: isAnyOf(
// Updating a board may change its archived status
boardsApi.endpoints.updateBoard.matchFulfilled,
// If the selected/auto-add board was deleted from a different session, we'll only know during the list request,
boardsApi.endpoints.listAllBoards.matchFulfilled,
// If a board is deleted, we'll need to reset the auto-add board
imagesApi.endpoints.deleteBoard.matchFulfilled,
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
// When we change the visibility of archived boards, we may need to reset the auto-add board
shouldShowArchivedBoardsChanged
),
effect: async (action, { dispatch, getState }) => {
/**
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
* a board, or change the archived board visibility flag, we may need to reset the auto-add board.
*/
const state = getState();
const queryArgs = selectListBoardsQueryArgs(state);
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
const autoAddBoardId = state.gallery.autoAddBoardId;
if (!queryResult.data) {
return;
}
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
dispatch(autoAddBoardIdChanged('none'));
dispatch(boardIdSelected({ boardId: 'none' }));
dispatch(galleryViewChanged('images'));
}
},
});
};
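
For context, the listener above leans on `selectListBoardsQueryArgs` folding the archived-boards visibility flag into the `listAllBoards` query args, so hiding archived boards yields a board list that no longer contains an archived auto-add board. A minimal sketch of that assumption (the real selector lives in features/gallery/store/gallerySelectors; the state field and query-arg names here are hypothetical):

import { createSelector } from '@reduxjs/toolkit';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';

// Hypothetical sketch: derive the boards list query args from gallery state so that
// toggling the archived-boards visibility changes the cache key and refetches the list.
export const selectListBoardsQueryArgsSketch = createSelector(selectGallerySlice, (gallery) => ({
  include_archived: gallery.shouldShowArchivedBoards,
}));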


@@ -2,8 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
import { getListImagesUrl } from 'services/api/util';
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -18,13 +17,10 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
cancelActiveListeners();
unsubscribe();
// TODO: figure out how to type the predicate
const data = action.payload as ImageCache;
const data = action.payload;
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
if (data.items.length > 0) {
dispatch(imageSelected(data.items[0] ?? null));
}
},
});
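
The change from `data.ids` and `imagesSelectors.selectAll(data)` to `data.items` here (and in the listeners that follow) assumes the listImages cache entry is no longer entity-adapter state but a plain paginated result. A rough sketch of the shape these hunks rely on; the actual type in services/api/types may use a different name and carry extra fields:

import type { ImageDTO } from 'services/api/types';

// Hypothetical shape of the paginated listImages result used throughout these listeners.
type ListImagesResultSketch = {
  items: ImageDTO[]; // ordered page of images; items[0] is the first image on the page
  total?: number; // assumed pagination metadata, not shown in this diff
};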


@@ -1,9 +1,13 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
boardIdSelected,
galleryViewChanged,
imageSelected,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -14,14 +18,9 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const state = getState();
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
const queryArgs = selectListImagesQueryArgs(state);
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
const queryArgs = { board_id: board_id ?? 'none', categories };
dispatch(selectionChanged([]));
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
@@ -35,11 +34,12 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage || null));
} else if (boardImagesData) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage || null));
dispatch(imageSelected(boardImagesData.items[0] || null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));


@@ -4,7 +4,6 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
@@ -32,14 +31,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (!listImagesData) {
if (!queryResult.data) {
// Should never happen if we have clicked a gallery image
return;
}
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const imageDTOs = queryResult.data.items;
const selection = state.gallery.selection;
if (altKey) {


@@ -0,0 +1,119 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
/**
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
* update the selection.
*
* There are three scenarios:
*
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
*
* 2. The page is changed by using the arrow keys (without alt).
* - When going backwards, select the last image.
* - When going forwards, select the first image.
*
* 3. The page is changed by using the arrow keys with alt. This means the user is changing the comparison image.
* - When going backwards, select the last image _as the comparison image_.
* - When going forwards, select the first image _as the comparison image_.
*/
startAppListening({
actionCreator: offsetChanged,
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
// Cancel any active listeners to prevent the selection from changing without user input
cancelActiveListeners();
const { withHotkey } = action.payload;
if (!withHotkey) {
// User changed pages by clicking the pagination buttons - no changes to selection
return;
}
const originalState = getOriginalState();
const prevOffset = originalState.gallery.offset;
const offset = getState().gallery.offset;
if (offset === prevOffset) {
// The page didn't change - bail
return;
}
/**
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
* page of images.
*
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
* changes to a cached page - a common situation - the `take` will never resolve.
*
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
* action, which updates the cache, then use the cache to update the selection.
*/
// Check if we have data in the cache for the page of images
const queryArgs = selectListImagesQueryArgs(getState());
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
// No data yet - wait for the network request to complete
if (!data) {
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
if (!takeResult) {
// The request didn't complete in time - bail
return;
}
data = takeResult[0].payload;
}
// We awaited a network request - state could have changed, get fresh state
const state = getState();
const { selection, imageToCompare } = state.gallery;
const imageDTOs = data?.items;
if (!imageDTOs) {
// The page didn't load - bail
return;
}
if (withHotkey === 'arrow') {
// User changed pages by using the arrow keys - selection changes to the first or last image depending on direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage] : []));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage] : []));
}
}
return;
}
if (withHotkey === 'alt+arrow') {
// User changed pages by using the arrow keys with alt - comparison image changes to the first or last depending on direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage));
}
}
return;
}
},
});
};
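
For reference, this listener only changes the selection when the `offsetChanged` payload carries a `withHotkey` flag of 'arrow' or 'alt+arrow'. A minimal sketch of the dispatching side, assuming a pagination hotkey hook along these lines (the hook name and offset arithmetic are illustrative, not taken from this diff):

import { useCallback } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { offsetChanged } from 'features/gallery/store/gallerySlice';

// Hypothetical hotkey handlers: move one page forward or backward and let the listener
// above decide whether to update the selection ('arrow') or the comparison image ('alt+arrow').
const useGalleryPageHotkeysSketch = (pageSize: number) => {
  const dispatch = useAppDispatch();
  const offset = useAppSelector((s) => s.gallery.offset);
  const goToNextPage = useCallback(
    (altKey: boolean) => {
      dispatch(offsetChanged({ offset: offset + pageSize, withHotkey: altKey ? 'alt+arrow' : 'arrow' }));
    },
    [dispatch, offset, pageSize]
  );
  const goToPrevPage = useCallback(
    (altKey: boolean) => {
      dispatch(offsetChanged({ offset: Math.max(offset - pageSize, 0), withHotkey: altKey ? 'alt+arrow' : 'arrow' }));
    },
    [dispatch, offset, pageSize]
  );
  return { goToNextPage, goToPrevPage };
};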


@@ -22,11 +22,10 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { clamp, forEach } from 'lodash-es';
import { forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@@ -118,32 +117,7 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
}
dispatch(isModalOpenChanged(false));
const state = getState();
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
@@ -168,6 +142,20 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
if (wasImageDeleted) {
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
}
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
if (data && data.items) {
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
dispatch(imageSelected(newlySelectedImage || null));
} else {
dispatch(imageSelected(null));
}
}
},
});
@@ -188,10 +176,8 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
if (data && data.items[0]) {
dispatch(imageSelected(data.items[0]));
} else {
dispatch(imageSelected(null));
}


@@ -15,7 +15,12 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -216,6 +221,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -233,6 +239,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -248,6 +255,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -261,6 +269,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},


@@ -11,6 +11,7 @@ import {
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { toast } from 'features/toast/toast';
@@ -62,7 +63,8 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
);
// Attempt to get the board's name for the toast
const { data } = boardsApi.endpoints.listAllBoards.select()(state);
const queryArgs = selectListBoardsQueryArgs(state);
const { data } = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
// Fall back to just the board id if we can't find the board for some reason
const board = data?.find((b) => b.board_id === autoAddBoardId);


@@ -8,14 +8,14 @@ import {
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
@@ -52,24 +52,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
if (!imageDTO.is_intermediate) {
/**
* Cache updates for when an image result is received
* - add it to the no_board/images
*/
dispatch(
imagesApi.util.updateQueryData(
'listImages',
{
board_id: imageDTO.board_id ?? 'none',
categories: IMAGE_CATEGORIES,
},
(draft) => {
imagesAdapter.addOne(draft, imageDTO);
}
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
@@ -78,7 +60,18 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
})
);
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
@@ -98,6 +91,8 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({


@@ -1,47 +1,37 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';
import type { MouseEvent } from 'react';
import { memo } from 'react';
type Props = {
const sx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
icon?: ReactElement;
styleOverrides?: SystemStyleObject;
};
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, styleOverrides } = props;
const sx = useMemo(
() => ({
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
...styleOverrides,
}),
[styleOverrides]
);
const { onClick, tooltip, icon, ...rest } = props;
return (
<IconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}
{...rest}
/>
);
};


@@ -1,16 +0,0 @@
/**
* Comparator function for sorting dates in ascending order
*/
export const dateComparator = (a: string, b: string) => {
const dateA = new Date(a);
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) {
return 1;
}
if (dateA < dateB) {
return -1;
}
return 0;
};


@@ -7,6 +7,7 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -20,7 +21,8 @@ const selectImagesToChange = createMemoizedSelector(
const ChangeBoardModal = () => {
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const isModalOpen = useAppSelector((s) => s.changeBoardModal.isModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();


@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
@@ -185,25 +184,25 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={t('controlnet.setControlImageDimensions')}
/>
</Flex>
)}
{pendingControlImages.includes(id) && (
<Flex
@@ -226,6 +225,3 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };


@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -160,7 +159,7 @@ export const ControlAdapterImagePreview = memo(
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
position="relative"
w="full"
w={36}
h={36}
alignItems="center"
justifyContent="center"
@@ -193,25 +192,27 @@ export const ControlAdapterImagePreview = memo(
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
{controlAdapter.processorPendingBatchId !== null && (
<Flex
@@ -235,6 +236,3 @@ export const ControlAdapterImagePreview = memo(
);
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };


@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -82,7 +81,7 @@ export const IPAdapterImagePreview = memo(
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<Flex position="relative" w={36} h={36} alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
@@ -90,24 +89,25 @@ export const IPAdapterImagePreview = memo(
postUploadAction={postUploadAction}
/>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };


@@ -1,4 +1,3 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -79,31 +78,34 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
}, [onReset, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
<Flex w="full" alignItems="center" justifyContent="center">
<Flex position="relative" w={36} h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
<>
<IAIDndImageIcon
onClick={onReset}
icon={imageDTO ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={useSizeStyleOverrides}
/>
</>
{imageDTO && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={onReset}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
</Flex>
);
});
InitialImagePreview.displayName = 'InitialImagePreview';
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };


@@ -11,25 +11,28 @@ const BoardAutoAddSelect = () => {
const { t } = useTranslation();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const { options, hasBoards } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const { options, hasBoards } = useListAllBoardsQuery(
{},
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);
const onChange = useCallback<ComboboxOnChange>(
(v) => {


@@ -3,11 +3,12 @@ import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-librar
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAddBoardIdChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { BoardDTO } from 'services/api/types';
@@ -15,52 +16,85 @@ import type { BoardDTO } from 'services/api/types';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
type Props = {
board?: BoardDTO;
board_id: BoardId;
board: BoardDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
setBoardToDelete?: (board?: BoardDTO) => void;
setBoardToDelete: (board?: BoardDTO) => void;
};
const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props) => {
const BoardContextMenu = ({ board, setBoardToDelete, children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board && board.board_id === gallery.autoAddBoardId),
[board]
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
[board.board_id]
);
const [updateBoard] = useUpdateBoardMutation();
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
const boardName = useBoardName(board.board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged(board_id));
}, [board_id, dispatch]);
dispatch(autoAddBoardIdChanged(board.board_id));
}, [board.board_id, dispatch]);
const handleBulkDownload = useCallback(() => {
bulkDownload({ image_names: [], board_id: board_id });
}, [board_id, bulkDownload]);
bulkDownload({ image_names: [], board_id: board.board_id });
}, [board.board_id, bulkDownload]);
const handleArchive = useCallback(async () => {
try {
await updateBoard({
board_id: board.board_id,
changes: { archived: true },
}).unwrap();
} catch (error) {
toast({
status: 'error',
title: 'Unable to archive board',
});
}
}, [board.board_id, updateBoard]);
const handleUnarchive = useCallback(() => {
updateBoard({
board_id: board.board_id,
changes: { archived: false },
});
}, [board.board_id, updateBoard]);
const renderMenuFunc = useCallback(
() => (
<MenuList visibility="visible">
<MenuGroup title={boardName}>
<MenuItem
icon={<PiPlusBold />}
isDisabled={isSelectedForAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{!autoAssignBoardOnClick && (
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
</MenuItem>
)}
{isBulkDownloadEnabled && (
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
{board && <GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />}
{board.archived && (
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive}>
{t('boards.unarchiveBoard')}
</MenuItem>
)}
{!board.archived && (
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive}>
{t('boards.archiveBoard')}
</MenuItem>
)}
<GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />
</MenuGroup>
</MenuList>
),
@@ -74,6 +108,8 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
isSelectedForAutoAdd,
setBoardToDelete,
t,
handleArchive,
handleUnarchive,
]
);


@@ -0,0 +1,22 @@
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
type Props = {
board_id: string;
isArchived: boolean;
};
export const BoardTotalsTooltip = ({ board_id, isArchived }: Props) => {
const { t } = useTranslation();
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { assetsTotal: data?.total ?? 0 };
},
});
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
};


@@ -2,6 +2,7 @@ import { Collapse, Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useState } from 'react';
@@ -26,7 +27,8 @@ const BoardsList = (props: Props) => {
const { isOpen } = props;
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
const { data: boards } = useListAllBoardsQuery();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards } = useListAllBoardsQuery(queryArgs);
const filteredBoards = boardSearchText
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
: boards;


@@ -8,15 +8,12 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesSquare } from 'react-icons/pi';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { PiArchiveBold, PiImagesSquare } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
@@ -28,6 +25,14 @@ const editableInputStyles: SystemStyleObject = {
},
};
const ArchivedIcon = () => {
return (
<Box position="absolute" top={1} insetInlineEnd={2} p={0} minW={0}>
<Icon as={PiArchiveBold} fill="base.300" filter="drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))" />
</Box>
);
};
interface GalleryBoardProps {
board: BoardDTO;
isSelected: boolean;
@@ -36,6 +41,7 @@ interface GalleryBoardProps {
const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
@@ -51,17 +57,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
setIsHovered(false);
}, []);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
const { board_name, board_id } = board;
@@ -117,7 +112,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
const { t } = useTranslation();
return (
<Box w="full" h="full" userSelect="none">
<Flex
@@ -130,9 +125,12 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
h="full"
>
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip label={tooltip} openDelay={1000}>
<Tooltip
label={<BoardTotalsTooltip board_id={board.board_id} isArchived={Boolean(board.archived)} />}
openDelay={1000}
>
<Flex
ref={ref}
onClick={handleSelectBoard}
@@ -145,6 +143,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
cursor="pointer"
bg="base.800"
>
{board.archived && <ArchivedIcon />}
{coverImage?.thumbnail_url ? (
<Image
src={coverImage?.thumbnail_url}

Some files were not shown because too many files have changed in this diff.