Compare commits

..

168 Commits

Author SHA1 Message Date
psychedelicious
b7d439c295 fix(mm): model names with periods borked
When we provide a config object during a model install, the config can override individual fields that would otherwise be derived programmatically. We use this to install starter models w/ a given name, description, etc.

This logic used `pathlib` to append a suffix to the model's name. When we provide a model name that has a period in it, `pathlib` treats everything after the last period as an existing suffix and replaces it with the new one. The result is then used to determine the output path of the model.

As a result, some starter model paths are incorrect. For example, `IP Adapter SD1.5 Image Encoder` gets installed to `sd-1/clip_vision/IP Adapter SD1`.
2024-10-15 17:00:08 +10:00
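
For reference, a minimal standalone Python sketch (not part of the diff) of the `pathlib` behavior described in the commit above:

from pathlib import Path

name = "IP Adapter SD1.5 Image Encoder"
# `with_suffix` treats everything after the last period as an existing suffix
# (".5 Image Encoder" here) and replaces it, truncating the model name.
print(Path(name).suffix)                       # -> .5 Image Encoder
print(Path(name).with_suffix(".safetensors"))  # -> IP Adapter SD1.safetensors
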
Brandon Rising
3da8076a2b fix: Pin onnx versions to builds that don't require rare dlls 2024-10-12 10:36:51 -04:00
Mary Hipp
80360a8abb fix(api): update enum usage to work for python 3.11 2024-10-12 10:21:26 -04:00
Mary Hipp
acfeb4a276 undo changes that made category optional 2024-10-11 17:37:57 -04:00
Mary Hipp
b33dbfc95f prefix share link with window location 2024-10-11 17:25:58 -04:00
Mary Hipp
f9bc29203b ruff 2024-10-11 17:23:34 -04:00
Mary Hipp
cbe7717409 make sure combobox is not searchable 2024-10-11 17:23:34 -04:00
Mary Hipp
d6add93901 lint 2024-10-11 17:23:34 -04:00
Mary Hipp
ea45dce9dc (ui) add board sorting UI to board settings popover 2024-10-11 17:23:34 -04:00
Mary Hipp
8d44363d49 (ui): update boards list queries to only use sort params for list, and make sure archived boards are included in most places we are searching 2024-10-11 17:23:34 -04:00
Mary Hipp
9933cdb6b7 (api) fix missing sort params being drilled down, add case insensitivity to name sorting 2024-10-11 17:23:34 -04:00
Mary Hipp
e3e9d1f27c (ui) break out boards settings from gallery/image settings 2024-10-11 17:23:34 -04:00
psychedelicious
bb59ad438a docs(ui): add comments to ImageContextMenu 2024-10-11 09:36:23 -04:00
psychedelicious
e38f5b1576 fix(ui): safari doesn't have find on iterators 2024-10-11 09:36:23 -04:00
psychedelicious
1bb49b698f perf(ui): efficient gallery image hover state 2024-10-11 09:36:23 -04:00
psychedelicious
fa1fbd89fe tidy(ui): remove extraneous prop extraction 2024-10-11 09:36:23 -04:00
psychedelicious
190ef6732c perf(ui): properly memoize gallery image icon components 2024-10-11 09:36:23 -04:00
psychedelicious
947cd4694b perf(ui): use single event for all image context menus
Image elements register their target ref in a map, which is used to look up the image that was clicked on. Substantial perf improvement.
2024-10-11 09:36:23 -04:00
psychedelicious
ee32d0666d perf(ui): memoize gallery page buttons 2024-10-11 09:36:23 -04:00
psychedelicious
bc8ad9ccbf perf(ui): remove another extraneous useCallback 2024-10-11 09:36:23 -04:00
psychedelicious
e96b290fa9 perf(ui): remove extraneous useCallbacks 2024-10-11 09:36:23 -04:00
psychedelicious
b9f83eae6a perf(ui): do not call upload hook unless upload is needed 2024-10-11 09:36:23 -04:00
psychedelicious
9868e23235 feat(ui): use singleton context menu
This improves render perf for the image component by 10-20%.
2024-10-11 09:36:23 -04:00
psychedelicious
0060cae17c build(ui): set package mode target to ES2015 2024-10-11 07:54:44 -04:00
psychedelicious
56f0845552 tidy(ui): consistent naming for selector builder util 2024-10-11 07:51:55 -04:00
psychedelicious
da3f85dd8b fix(ui): edge case where entity isn't visible until interacting with canvas
To trigger the edge case:
- Have an empty layer and non-empty layer
- Select the non-empty layer
- Refresh the page
- Select the empty layer without doing any other action
- You may be unable to draw on the layer
- Zoom in/out slightly
- You can now draw on it

The problem was that we did not sync visibility when a layer is selected, leaving the layer hidden. This indirectly disabled interactions.

The fix is to listen for changes to the layer's selected status and sync visibility when that changes.
2024-10-11 07:51:55 -04:00
psychedelicious
7185363f17 fix(ui): edge case where added control adapter counts could be off
We were:
- Incrementing `addedControlNets` or `addedT2IAdapters`
- Attempting to add it, but maybe failing and skipping

Need to swap the order of operations to prevent misreporting of added cnet/t2i.

I don't think this would ever actually cause problems.
2024-10-11 10:37:30 +11:00
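
A schematic sketch in plain Python (the actual UI code is TypeScript; the helper names here are illustrative) of the order-of-operations change described above: count an adapter only after it has actually been added.

def _try_add_adapter(adapter: dict) -> bool:
    """Stand-in for the real add logic, which may fail and skip an adapter."""
    return adapter.get("model") is not None

def count_added_control_nets(adapters: list[dict]) -> int:
    added_control_nets = 0
    for adapter in adapters:
        # Attempt the add first; only increment the count if it succeeded.
        if _try_add_adapter(adapter):
            added_control_nets += 1
    return added_control_nets

print(count_added_control_nets([{"model": "canny"}, {"model": None}]))  # -> 1
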
Ryan Dick
ac08c31fbc Remove unnecessary hasattr checks for scaled_dot_product_attention. We pin the torch version, so there should be no concern that this function does not exist. 2024-10-10 19:23:45 -04:00
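
For reference, the direct call that the removed hasattr guard used to protect (a standalone sketch; the tensor shapes are arbitrary example values, and the function is available in the pinned torch versions):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
# No hasattr(F, "scaled_dot_product_attention") check is needed before calling it.
out = F.scaled_dot_product_attention(q, k, v)
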
Ryan Dick
ea54a2655a Add a workaround for broken sliced attention on MPS with torch 2.4.1. 2024-10-10 19:23:45 -04:00
psychedelicious
cc83dede9f chore: bump version to v5.2.0rc1 2024-10-11 10:11:47 +11:00
Riccardo Giovanetti
8464fd2ced translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1462 of 1483 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-10-11 09:41:45 +11:00
Васянатор
c3316368d9 translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1479 of 1479 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-10-11 09:41:45 +11:00
Riku
8b2d5ab28a translationBot(ui): update translation (German)
Currently translated at 70.3% (1048 of 1490 strings)

translationBot(ui): update translation (German)

Currently translated at 69.4% (1027 of 1479 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-10-11 09:41:45 +11:00
psychedelicious
3f6acdc2d3 fix(ui): use non-icon version of delete menu item on canvas context menu 2024-10-10 18:23:32 -04:00
psychedelicious
4aa20a95b2 feat(ui): consolidate img2img canvas flow
Make the `New Canvas From Image` button do what the `New Img2Img From Image` button does.
2024-10-11 09:03:44 +11:00
Ryan Dick
2d82e69a33 Add support for FLUX ControlNet models (XLabs and InstantX) (#7070)
## Summary

Add support for FLUX ControlNet models (XLabs and InstantX).

## QA Instructions

- [x] SD1 and SDXL ControlNets, since the ModelLoaderRegistry calls were
changed.
- [x] Single Xlabs controlnet
- [x] Single InstantX union controlnet
- [x] Single InstantX controlnet
- [x] Single Shakker Labs Union controlnet
- [x] Multiple controlnets
- [x] Weight, start, end params all work as expected
- [x] Can be used with image-to-image and inpainting.
- [x] Clear error message if no VAE is passed when using InstantX
controlnet.
- [x] Install InstantX ControlNet in diffusers format from HF repo
(`InstantX/FLUX.1-dev-Controlnet-Union`)
- [x] Test all FLUX ControlNet starter models
## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-10-10 12:37:09 -04:00
Ryan Dick
683f9a70e7 Restore instantx_control_mode field on FLUX ControlNet invocation. 2024-10-10 15:25:30 +00:00
Ryan Dick
bb6d073828 Use the Shakker-Labs ControlNet union model as the only FLUX ControlNet starter model. 2024-10-10 13:59:59 +00:00
Kent Keirsey
7f7d8e5177 Merge branch 'ryan/flux-controlnet-xlabs-instantx' of https://github.com/invoke-ai/InvokeAI into ryan/flux-controlnet-xlabs-instantx 2024-10-10 08:06:25 -04:00
Ryan Dick
f37c5011f4 Reduce peak memory utilization when preparing FLUX controlnet inputs. 2024-10-10 07:59:29 -04:00
Ryan Dick
bb947c6162 Make FLUX controlnet node API more like SD API and get it working with linear UI. 2024-10-10 07:59:29 -04:00
Ryan Dick
a654dad20f Remove instantx_control_mode from FLUX ControlNet node. 2024-10-10 07:59:29 -04:00
Mary Hipp
2bd44662f3 possibly a working FLUX controlnet graph 2024-10-10 07:59:29 -04:00
Ryan Dick
e7f9086006 Fix bug with InstantX input image range. 2024-10-10 07:59:29 -04:00
Mary Hipp
5141be8009 hide Control Mode for FLUX control net layer 2024-10-10 07:59:29 -04:00
Mary Hipp
eacdfc660b ui: enable controlnet controls when FLUX is main model, update schema 2024-10-10 07:59:29 -04:00
maryhipp
5fd3c39431 update preprocessor logic to be more resilient 2024-10-10 07:59:29 -04:00
maryhipp
7daf3b7d4a update starter models to include FLUX controlnets 2024-10-10 07:59:29 -04:00
Ryan Dick
908f65698d Fix support for InstantX non-union models (with no single blocks). 2024-10-10 07:59:29 -04:00
Ryan Dick
63c4ac58e9 Support installing InstantX ControlNet models from diffusers directory format. 2024-10-10 07:59:29 -04:00
Ryan Dick
8c125681ea Skip tests that are failing on MacOS CI runners (for now). 2024-10-10 07:59:29 -04:00
Ryan Dick
118f0ba3bf Revert "Try to fix test failures affecting MacOS CI runners."
This reverts commit 216b36c75d.
2024-10-10 07:59:29 -04:00
Ryan Dick
b3b7d084d0 Try to fix test failures affecting MacOS CI runners. 2024-10-10 07:59:29 -04:00
Ryan Dick
812940eb95 (minor) Add comment about future memory optimization. 2024-10-10 07:59:29 -04:00
Ryan Dick
0559480dd6 Shift the controlnet-type-specific logic into the specific ControlNet extensions and make the FLUX model controlnet-type-agnostic. 2024-10-10 07:59:29 -04:00
Ryan Dick
d99e7dd4e4 Add instantx_control_mode param to FLUX ControlNet invocation. 2024-10-10 07:59:29 -04:00
Ryan Dick
e854181417 Create a dedicated FLUX ControlNet invocation. 2024-10-10 07:59:29 -04:00
Ryan Dick
de414c09fd Bugfixes to get InstantX ControlNet working. 2024-10-10 07:59:29 -04:00
Ryan Dick
ce4624f72b Update ControlNetCheckpointProbe.get_base_type() to work with InstantX. 2024-10-10 07:59:29 -04:00
Ryan Dick
47c7df3476 Fix circular imports related to XLabsControlNetFluxOutput and InstantXControlNetFluxOutput. 2024-10-10 07:59:29 -04:00
Ryan Dick
4289b5e6c3 Add instantx controlnet logic to FLUX model forward(). 2024-10-10 07:59:29 -04:00
Ryan Dick
c8d1d14662 Work on integrating InstantX into denoise process. 2024-10-10 07:59:29 -04:00
Ryan Dick
44c588d778 Rename DiffusersControlNetFlux -> InstantXControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
d75ac56d00 Create flux/extensions directory. 2024-10-10 07:59:29 -04:00
Ryan Dick
714dd5f0be Update FluxControlnetModel to work with both XLabs and InstantX. 2024-10-10 07:59:29 -04:00
Ryan Dick
2f4d3cb5e6 Add unit test to test the full flow of loading an InstantX ControlNet from a state dict. 2024-10-10 07:59:29 -04:00
Ryan Dick
b76555bda9 Add unit test for infer_instantx_num_control_modes_from_state_dict(). 2024-10-10 07:59:29 -04:00
Ryan Dick
1cdd501a0a Add unit test for infer_flux_params_from_state_dict(...). 2024-10-10 07:59:29 -04:00
Ryan Dick
1125218bc5 Update FLUX ControlNet unit test state dicts to include shapes. 2024-10-10 07:59:29 -04:00
Ryan Dick
683504bfb5 Add scripts/extract_sd_keys_and_shapes.py 2024-10-10 07:59:29 -04:00
Ryan Dick
03cf953398 First pass of utility function to infer the FluxParams from a state dict. 2024-10-10 07:59:29 -04:00
Ryan Dick
24c115663d Add unit test for convert_diffusers_instantx_state_dict_to_bfl_format(...) and fix a few bugs. 2024-10-10 07:59:29 -04:00
Ryan Dick
a9e7ecad49 Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_format(...). 2024-10-10 07:59:29 -04:00
Ryan Dick
76f4766324 WIP - implement convert_diffusers_instantx_state_dict_to_bfl_format(...). 2024-10-10 07:59:29 -04:00
Ryan Dick
3dfc242f77 (minor) rename other_forward() -> forward() 2024-10-10 07:59:29 -04:00
Ryan Dick
1e43389cb4 Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from
state dict.
2024-10-10 07:59:29 -04:00
Ryan Dick
cb33de34f7 Migrate DiffusersControlNetFlux from diffusers-style to BFL-style. 2024-10-10 07:59:29 -04:00
Ryan Dick
7562ea48dc Improve typing of zero_module(). 2024-10-10 07:59:29 -04:00
Ryan Dick
83f4700f5a Use top-level torch import for all torch stuff. 2024-10-10 07:59:29 -04:00
Ryan Dick
704e7479b2 Remove DiffusersControlNetFlux.from_transformer(...). 2024-10-10 07:59:29 -04:00
Ryan Dick
5f44559f30 Fixup typing around DiffusersControlNetFluxOutput. 2024-10-10 07:59:29 -04:00
Ryan Dick
7a22819100 Remove gradient checkpointing from DiffusersControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
70495665c5 Remove FluxMultiControlNetModel 2024-10-10 07:59:29 -04:00
Ryan Dick
ca30acc5b4 Remove LoRA stuff from DiffusersControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
8121843d86 Remove logic for modifying attn processors from DiffusersControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
bc0ded0a23 Rename FluxControlNetModel -> DiffusersControlNetFlux 2024-10-10 07:59:29 -04:00
Ryan Dick
30f6034f88 Start updating imports for FluxControlNetModel 2024-10-10 07:59:29 -04:00
Ryan Dick
7d56a8ce54 Copy model from 99f608218c/src/diffusers/models/controlnet_flux.py 2024-10-10 07:59:29 -04:00
Ryan Dick
e7dc439006 Rename ControlNetFlux -> XLabsControlNetFlux 2024-10-10 07:59:29 -04:00
Ryan Dick
bce5a93eb1 Add InstantX FLUX ControlNet state dict for unit testing. 2024-10-10 07:59:29 -04:00
Ryan Dick
93e98a1f63 Add support for FLUX controlnet weight, begin_step_percent and end_step_percent. 2024-10-10 07:59:29 -04:00
Ryan Dick
0f93deab3b First pass at integrating FLUX ControlNets into the FLUX Denoise invocation. 2024-10-10 07:59:29 -04:00
Ryan Dick
3f3aba8b10 Add FLUX XLabs ControlNet model probing. 2024-10-10 07:59:29 -04:00
Ryan Dick
0b84f567f1 Fix type errors and improve docs for ControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
69c0d7dcc9 Remove gradient checkpointing from ControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
5307248fcf Remove ControlNetFlux logic related to attn processor overrides. 2024-10-10 07:59:29 -04:00
Ryan Dick
2efaea8f79 Remove duplicate FluxParams class. 2024-10-10 07:59:29 -04:00
Ryan Dick
c1dfd9b7d9 Fix FLUX module imports for ControlNetFlux. 2024-10-10 07:59:29 -04:00
Ryan Dick
c594ef89d2 Copy ControlNetFlux model from 47495425db/src/flux/controlnet.py. 2024-10-10 07:59:29 -04:00
Ryan Dick
563db67b80 Add XLabs FLUX controlnet state dict key file to be used for development/testing. 2024-10-10 07:59:29 -04:00
psychedelicious
236c065edd fix(ui): respect grid size when fitting layer to box 2024-10-10 07:43:46 -04:00
psychedelicious
1f5d744d01 fix(ui): disable canvas-related image context menu items when canvas is busy 2024-10-10 07:43:46 -04:00
psychedelicious
b36c6af0ae feat(ui): add new img2img canvas from image functionality
This replicates the img2img flow:
- Reset the canvas
- Resize the bbox to the image's aspect ratio at the optimal size for the selected model
- Add the image as a raster layer
- Resize the layer to fit the bbox using the 'fill' strategy

After this completes, the user can immediately click Invoke and it will do img2img.
2024-10-10 07:43:46 -04:00
psychedelicious
4e431a9d5f feat(ui): add a mutex to CanvasEntityTransformer to prevent concurrent operations 2024-10-10 07:43:46 -04:00
psychedelicious
48a8232285 feat(ui): add entity adapter init callbacks
If an entity needs to do something after init, it can use this system. For example, if a layer should be transformed immediately after initializing, it can use an init callback.
2024-10-10 07:43:46 -04:00
psychedelicious
94007fef5b tidy(ui): remove unused reducer 2024-10-10 07:43:46 -04:00
psychedelicious
9e6fb3bd3f feat(ui): add hooks for new layer/canvas from image & use them 2024-10-10 07:43:46 -04:00
Ryan Dick
4aace24f1f Reduce peak memory utilization when preparing FLUX controlnet inputs. 2024-10-10 00:18:46 +00:00
Ryan Dick
b1567fe0e4 Make FLUX controlnet node API more like SD API and get it working with linear UI. 2024-10-09 23:38:31 +00:00
Ryan Dick
3953e60a4f Remove instantx_control_mode from FLUX ControlNet node. 2024-10-09 22:00:54 +00:00
Mary Hipp
63a2e17f6b possibly a working FLUX controlnet graph 2024-10-09 15:42:02 -04:00
Ryan Dick
8b1ef4b902 Fix bug with InstantX input image range. 2024-10-09 19:38:30 +00:00
Mary Hipp
5f2279c984 hide Control Mode for FLUX control net layer 2024-10-09 15:31:44 -04:00
Mary Hipp
e82d67849c ui: enable controlnet controls when FLUX is main model, update schema 2024-10-09 15:05:29 -04:00
maryhipp
3977ffaa3e update preprocessor logic to be more resilient 2024-10-09 14:57:14 -04:00
maryhipp
9a8a858fe4 update starter models to include FLUX controlnets 2024-10-09 14:57:14 -04:00
Ryan Dick
859944f848 Fix support for InstantX non-union models (with no single blocks). 2024-10-09 18:51:53 +00:00
Ryan Dick
8d1a45863c Support installing InstantX ControlNet models from diffusers directory format. 2024-10-09 17:04:10 +00:00
Ryan Dick
6798bbab26 Skip tests that are failing on MacOS CI runners (for now). 2024-10-09 16:34:42 +00:00
Ryan Dick
2c92e8a495 Revert "Try to fix test failures affecting MacOS CI runners."
This reverts commit 216b36c75d.
2024-10-09 16:30:40 +00:00
Ryan Dick
216b36c75d Try to fix test failures affecting MacOS CI runners. 2024-10-09 16:21:52 +00:00
Ryan Dick
8bf8742984 (minor) Add comment about future memory optimization. 2024-10-09 16:16:04 +00:00
Ryan Dick
c78eeb1645 Shift the controlnet-type-specific logic into the specific ControlNet extensions and make the FLUX model controlnet-type-agnostic. 2024-10-09 16:12:09 +00:00
Ryan Dick
cd88723a80 Add instantx_control_mode param to FLUX ControlNet invocation. 2024-10-09 14:17:42 +00:00
Ryan Dick
dea6cbd599 Create a dedicated FLUX ControlNet invocation. 2024-10-09 14:17:42 +00:00
Ryan Dick
0dd9f1f772 Bugfixes to get InstantX ControlNet working. 2024-10-09 14:17:42 +00:00
Ryan Dick
5d11c30ce6 Update ControlNetCheckpointProbe.get_base_type() to work with InstantX. 2024-10-09 14:17:42 +00:00
Ryan Dick
a783539cd2 Fix circular imports related to XLabsControlNetFluxOutput and InstantXControlNetFluxOutput. 2024-10-09 14:17:42 +00:00
Ryan Dick
2f8f30b497 Add instantx controlnet logic to FLUX model forward(). 2024-10-09 14:17:42 +00:00
Ryan Dick
f878e5e74e Work on integrating InstantX into denoise process. 2024-10-09 14:17:42 +00:00
Ryan Dick
bfc460a5c6 Rename DiffusersControlNetFlux -> InstantXControlNetFlux. 2024-10-09 14:17:42 +00:00
Ryan Dick
a24581ede2 Create flux/extensions directory. 2024-10-09 14:17:42 +00:00
Ryan Dick
56731766ca Update FluxControlnetModel to work with both XLabs and InstantX. 2024-10-09 14:17:42 +00:00
Ryan Dick
80bc4ebee3 Add unit test to test the full flow of loading an InstantX ControlNet from a state dict. 2024-10-09 14:17:42 +00:00
Ryan Dick
745b6dbd5d Add unit test for infer_instantx_num_control_modes_from_state_dict(). 2024-10-09 14:17:42 +00:00
Ryan Dick
c7628945c4 Add unit test for infer_flux_params_from_state_dict(...). 2024-10-09 14:17:42 +00:00
Ryan Dick
728927ecff Update FLUX ControlNet unit test state dicts to include shapes. 2024-10-09 14:17:42 +00:00
Ryan Dick
1a7eece695 Add scripts/extract_sd_keys_and_shapes.py 2024-10-09 14:17:42 +00:00
Ryan Dick
2cd14dd066 First pass of utility function to infer the FluxParams from a state dict. 2024-10-09 14:17:42 +00:00
Ryan Dick
5872f05342 Add unit test for convert_diffusers_instantx_state_dict_to_bfl_format(...) and fix a few bugs. 2024-10-09 14:17:42 +00:00
Ryan Dick
4ad135c6ae Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_format(...). 2024-10-09 14:17:42 +00:00
Ryan Dick
c72c2770fe WIP - implement convert_diffusers_instantx_state_dict_to_bfl_format(...). 2024-10-09 14:17:42 +00:00
Ryan Dick
e733a1f30e (minor) rename other_forward() -> forward() 2024-10-09 14:17:42 +00:00
Ryan Dick
4be3a33744 Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from
state dict.
2024-10-09 14:17:42 +00:00
Ryan Dick
1751c380db Migrate DiffusersControlNetFlux from diffusers-style to BFL-style. 2024-10-09 14:17:42 +00:00
Ryan Dick
16cda33025 Improve typing of zero_module(). 2024-10-09 14:17:42 +00:00
Ryan Dick
8308e7d186 Use top-level torch import for all torch stuff. 2024-10-09 14:17:42 +00:00
Ryan Dick
c0aab56d08 Remove DiffusersControlNetFlux.from_transformer(...). 2024-10-09 14:17:42 +00:00
Ryan Dick
1795f4f8a2 Fixup typing around DiffusersControlNetFluxOutput. 2024-10-09 14:17:42 +00:00
Ryan Dick
5bfd2ec6b7 Remove gradient checkpointing from DiffusersControlNetFlux. 2024-10-09 14:17:42 +00:00
Ryan Dick
a35b229a9d Remove FluxMultiControlNetModel 2024-10-09 14:17:42 +00:00
Ryan Dick
e93da5d4b2 Remove LoRA stuff from DiffusersControlNetFlux. 2024-10-09 14:17:42 +00:00
Ryan Dick
a17ea9bfad Remove logic for modifying attn processors from DiffusersControlNetFlux. 2024-10-09 14:17:42 +00:00
Ryan Dick
3578010ba4 Rename FluxControlNetModel -> DiffusersControlNetFlux 2024-10-09 14:17:42 +00:00
Ryan Dick
459cf52043 Start updating imports for FluxControlNetModel 2024-10-09 14:17:42 +00:00
Ryan Dick
9bcb93f575 Copy model from 99f608218c/src/diffusers/models/controlnet_flux.py 2024-10-09 14:17:42 +00:00
Ryan Dick
d1a0e99701 Rename ControlNetFlux -> XLabsControlNetFlux 2024-10-09 14:17:42 +00:00
Ryan Dick
92b1515d9d Add InstantX FLUX ControlNet state dict for unit testing. 2024-10-09 14:17:42 +00:00
Ryan Dick
36515e1e2a Add support for FLUX controlnet weight, begin_step_percent and end_step_percent. 2024-10-09 14:17:42 +00:00
Ryan Dick
c81bb761ed First pass at integrating FLUX ControlNets into the FLUX Denoise invocation. 2024-10-09 14:17:42 +00:00
Ryan Dick
1d4a58e52b Add FLUX XLabs ControlNet model probing. 2024-10-09 14:17:42 +00:00
Ryan Dick
62d12e6468 Fix type errors and improve docs for ControlNetFlux. 2024-10-09 14:17:41 +00:00
Ryan Dick
9541156ce5 Remove gradient checkpointing from ControlNetFlux. 2024-10-09 14:17:41 +00:00
Ryan Dick
eb5b6625ea Remove ControlNetFlux logic related to attn processor overrides. 2024-10-09 14:17:41 +00:00
Ryan Dick
9758e5a622 Remove duplicate FluxParams class. 2024-10-09 14:17:41 +00:00
Ryan Dick
58eba8bdbd Fix FLUX module imports for ControlNetFlux. 2024-10-09 14:17:41 +00:00
Ryan Dick
2821ba8967 Copy ControlNetFlux model from 47495425db/src/flux/controlnet.py. 2024-10-09 14:17:41 +00:00
Ryan Dick
2cc72b19bc Add XLabs FLUX controlnet state dict key file to be used for development/testing. 2024-10-09 14:17:41 +00:00
83 changed files with 3491 additions and 502 deletions

View File

@@ -5,9 +5,10 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@@ -115,6 +116,8 @@ async def delete_board(
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
@@ -122,9 +125,9 @@ async def list_boards(
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all(include_archived)
return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
else:
raise HTTPException(
status_code=400,
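
A hedged usage sketch of the new sort parameters on the boards list endpoint. The base URL/mount path and the serialized direction value ("ASC") are assumptions not shown in this diff:

import requests

BASE_URL = "http://127.0.0.1:9090/api/v1/boards/"  # assumed local install address

resp = requests.get(
    BASE_URL,
    params={
        "order_by": "board_name",  # BoardRecordOrderBy.Name (see the enum later in this diff)
        "direction": "ASC",        # assumed serialization of SQLiteDirection.Ascending
        "all": True,
        "include_archived": False,
    },
)
resp.raise_for_status()
print(resp.json())
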

View File

@@ -88,7 +88,7 @@ async def list_workflows(
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
category: Optional[WorkflowCategory] = Query(default=None, description="The category of workflow to get"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""

View File

@@ -192,6 +192,7 @@ class FieldDescriptions:
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
class ImageField(BaseModel):
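
A hypothetical lookup mirroring the mapping in the instantx_control_mode description above (illustration only; this dict and its keys are not identifiers from the codebase):

INSTANTX_UNION_CONTROL_MODES = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}

# Per the description, a negative value (e.g. -1) is treated as 'None' (no mode conditioning).
mode = INSTANTX_UNION_CONTROL_MODES["depth"]  # -> 2
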

View File

@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
class FluxControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
"""FLUX ControlNet info"""
control: FluxControlNetField = OutputField(description=FieldDescriptions.control)
@invocation(
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
# Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
return FluxControlNetOutput(
control=FluxControlNetField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)
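
A small standalone sketch of how the -1 default behaves downstream, mirroring the handling in the FLUX Denoise invocation later in this diff:

import torch

instantx_control_mode = -1  # the node's default; negative values mean "no mode conditioning"
instantx_control_mode_tensor: torch.Tensor | None = None
if instantx_control_mode is not None and instantx_control_mode >= 0:
    instantx_control_mode_tensor = torch.tensor(instantx_control_mode, dtype=torch.long).reshape([-1, 1])
# Here the tensor stays None, so a union ControlNet falls back to its zero mode embedding.
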

View File

@@ -16,11 +16,16 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -44,7 +49,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.0.0",
version="3.1.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +179,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -192,12 +204,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with (
transformer_info.model_on_device() as (cached_weights, transformer),
ExitStack() as exit_stack,
):
assert isinstance(transformer, Flux)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
@@ -242,6 +263,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -288,6 +310,104 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[FluxControlNetField]
if self.control is None:
controlnets = []
elif isinstance(self.control, FluxControlNetField):
controlnets = [self.control]
elif isinstance(self.control, list):
controlnets = self.control
else:
raise ValueError(f"Unsupported controlnet type: {type(self.control)}")
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
elif isinstance(controlnet_model.model, XLabsControlNetFlux):
controlnet_conds.append(
XLabsControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
XLabsControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
elif isinstance(model, InstantXControlNetFlux):
instantx_control_mode: torch.Tensor | None = None
if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])
controlnet_extensions.append(
InstantXControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord, BoardRecordOrderBy
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardRecordStorageBase(ABC):
@@ -39,12 +40,19 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
"""Gets all board records."""
pass

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -60,6 +62,13 @@ class BoardChanges(BaseModel, extra="forbid"):
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):
"""The order by options for board records"""
CreatedAt = "created_at"
Name = "board_name"
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""

View File

@@ -8,10 +8,12 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecord,
BoardRecordDeleteException,
BoardRecordNotFoundException,
BoardRecordOrderBy,
BoardRecordSaveException,
deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
@@ -144,7 +146,12 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
@@ -154,17 +161,16 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
@@ -198,23 +204,32 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
try:
self._lock.acquire()
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
archived_filter = "WHERE archived = 0"
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
final_query = base_query.format(archived_filter=archived_filter)
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
self._cursor.execute(final_query)
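
To make the string interpolation above concrete, a standalone sketch of the query produced by base_query.format(...). The "DESC" value for SQLiteDirection.Descending is an assumption; the enum's string values are not shown in this diff:

base_query = """
    SELECT *
    FROM boards
    {archived_filter}
    ORDER BY {order_by} {direction}
    LIMIT ? OFFSET ?;
    """

final_query = base_query.format(
    archived_filter="WHERE archived = 0",  # include_archived=False
    order_by="created_at",                 # BoardRecordOrderBy.CreatedAt.value
    direction="DESC",                      # assumed value of SQLiteDirection.Descending
)
print(final_query)
# LIMIT and OFFSET remain bound as ? parameters; only validated enum values are interpolated.
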

View File

@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardServiceABC(ABC):
@@ -43,12 +44,19 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardDTO]:
"""Gets all boards."""
pass

View File

@@ -1,8 +1,9 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardService(BoardServiceABC):
@@ -47,9 +48,16 @@ class BoardService(BoardServiceABC):
self.__invoker.services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_records = self.__invoker.services.board_records.get_many(
order_by, direction, offset, limit, include_archived
)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -63,8 +71,10 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(order_by, direction, include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

View File

@@ -184,7 +184,8 @@ class ModelInstallService(ModelInstallServiceBase):
) # type: ignore
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
if model_path.suffix:
preferred_name = f"{preferred_name}.{model_path.suffix}"
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)

View File

@@ -41,9 +41,9 @@ class WorkflowRecordsStorageBase(ABC):
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
page: int,
per_page: Optional[int],
category: Optional[WorkflowCategory],
query: Optional[str],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""

View File

@@ -127,9 +127,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
page: int = 0,
per_page: Optional[int] = None,
category: Optional[WorkflowCategory] = None,
query: Optional[str] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
try:
@@ -137,6 +137,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library"
main_query = """
SELECT
@@ -148,26 +149,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = []
count_params: list[int | str] = []
if category:
assert category in WorkflowCategory
main_query += " WHERE category = ?"
count_query += " WHERE category = ?"
main_params.append(category.value)
count_params.append(category.value)
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
if "WHERE" in main_query:
main_query += " AND (name LIKE ? OR description LIKE ?)"
count_query += " AND (name LIKE ? OR description LIKE ?)"
else:
main_query += " WHERE name LIKE ? OR description LIKE ?"
count_query += " WHERE name LIKE ? OR description LIKE ?"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
import torch
@dataclass
class ControlNetFluxOutput:
single_block_residuals: list[torch.Tensor] | None
double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.single_block_residuals is not None:
for i in range(len(self.single_block_residuals)):
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
if self.double_block_residuals is not None:
for i in range(len(self.double_block_residuals)):
self.double_block_residuals[i] = self.double_block_residuals[i] * weight
def add_tensor_lists_elementwise(
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
"""Add two tensor lists elementwise that could be None."""
if list1 is None and list2 is None:
return None
if list1 is None:
return list2
if list2 is None:
return list1
new_list: list[torch.Tensor] = []
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
new_list.append(list1_tensor + list2_tensor)
return new_list
def add_controlnet_flux_outputs(
controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
return ControlNetFluxOutput(
single_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
),
double_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
),
)
def sum_controlnet_flux_outputs(
controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
for controlnet_output in controlnet_outputs:
controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)
return controlnet_output_sum
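
A minimal usage sketch, assuming the ControlNetFluxOutput helpers defined above are in scope (their module path is not shown in this diff):

import torch

out_a = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4, 8)],
)
out_b = ControlNetFluxOutput(
    single_block_residuals=[torch.ones(1, 4, 8)],
    double_block_residuals=[torch.ones(1, 4, 8)],
)

out_a.apply_weight(0.5)  # scales every residual tensor in place
total = sum_controlnet_flux_outputs([out_a, out_b])
# total.double_block_residuals[0] is 1.5 everywhere (0.5 + 1.0); the single-block
# residuals come from out_b alone, since out_a has None there.
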

View File

@@ -0,0 +1,180 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim
class InstantXControlNetFlux(torch.nn.Module):
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
"""
Args:
params (FluxParams): The parameters for the FLUX model.
num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
'union controlnet' model and expects a mode conditioning input at runtime.
"""
super().__init__()
# The following modules mirror the base FLUX transformer model.
# -------------------------------------------------------------
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
# The following modules are specific to the ControlNet model.
# -----------------------------------------------------------
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.double_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.is_union = False
if num_control_modes is not None:
self.is_union = True
self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))
def forward(
self,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor | None,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> InstantXControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
# Add controlnet_cond embedding.
img = img + self.controlnet_x_embedder(controlnet_cond)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
# If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
if self.is_union:
if controlnet_mode is None:
# We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
# We've chosen to use a zero-embedding in this case.
zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
else:
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
else:
assert controlnet_mode is None
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
double_block_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
double_block_samples.append(img)
img = torch.cat((txt, img), 1)
single_block_samples: list[torch.Tensor] = []
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
single_block_samples.append(img[:, txt.shape[1] :])
# ControlNet Block
controlnet_double_block_samples: list[torch.Tensor] = []
for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
double_block_sample = controlnet_block(double_block_sample)
controlnet_double_block_samples.append(double_block_sample)
controlnet_single_block_samples: list[torch.Tensor] = []
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=True
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples.append(single_block_sample)
return InstantXControlNetFluxOutput(
controlnet_block_samples=controlnet_double_block_samples or None,
controlnet_single_block_samples=controlnet_single_block_samples or None,
)
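
A hypothetical dict-to-dict sketch of the diffusers-to-BFL parameter mapping from the NOTE above; the numeric values are made up for illustration only:

diffusers_cfg = {
    "in_channels": 64,
    "num_layers": 5,
    "num_single_layers": 10,
    "attention_head_dim": 128,
    "num_attention_heads": 24,
    "joint_attention_dim": 4096,
    "pooled_projection_dim": 768,
    "guidance_embeds": True,
    "axes_dims_rope": [16, 56, 56],
}

bfl_params = {
    "in_channels": diffusers_cfg["in_channels"],
    "depth": diffusers_cfg["num_layers"],
    "depth_single_blocks": diffusers_cfg["num_single_layers"],
    "num_heads": diffusers_cfg["num_attention_heads"],
    # attention_head_dim corresponds to hidden_size // num_heads, so:
    "hidden_size": diffusers_cfg["attention_head_dim"] * diffusers_cfg["num_attention_heads"],
    "context_in_dim": diffusers_cfg["joint_attention_dim"],
    "vec_in_dim": diffusers_cfg["pooled_projection_dim"],
    "guidance_embed": diffusers_cfg["guidance_embeds"],
    "axes_dim": diffusers_cfg["axes_dims_rope"],
}
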

View File

@@ -0,0 +1,295 @@
from typing import Any, Dict
import torch
from invokeai.backend.flux.model import FluxParams
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"input_hint_block.0.bias",
"input_hint_block.0.weight",
"pos_embed_input.bias",
"pos_embed_input.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_x_embedder.bias",
"controlnet_x_embedder.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
"""Fuse weights along dimension 0.
Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
"""
# TODO(ryand): Double check dim=0 is correct.
return torch.cat(t, dim=0)
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], double_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a double block from diffusers format to BFL format."""
to_prefix = f"double_blocks.{double_block_index}"
from_prefix = f"transformer_blocks.{double_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
return new_sd
# txt_attn.qkv
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
)
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
)
# img_attn.qkv
new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
)
new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# img_attn
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
# img_mlp
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
# img_mod
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
# txt_attn
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
# txt_mlp
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
# txt_mod
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], single_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a single block from diffusers format to BFL format."""
to_prefix = f"single_blocks.{single_block_index}"
from_prefix = f"single_transformer_blocks.{single_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.to_q.bias" not in sd:
return new_sd
# linear1 (qkv)
new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
sd.pop(f"{from_prefix}.proj_mlp.bias"),
)
new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
sd.pop(f"{from_prefix}.proj_mlp.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# linear2
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
# modulation
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
# norm
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.norm_q.weight": "norm.query_norm.scale",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
InstantXControlNetFlux model.
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
original state dict to the format expected by InstantXControlNetFlux.
"""
# Shallow copy sd so that we can pop keys from it without modifying the original.
sd = sd.copy()
new_sd: dict[str, torch.Tensor] = {}
# Handle basic 1-to-1 key conversions.
basic_key_map = {
# Base model keys.
# ----------------
# txt_in keys.
"context_embedder.bias": "txt_in.bias",
"context_embedder.weight": "txt_in.weight",
# guidance_in MLPEmbedder keys.
"time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
"time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
"time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
"time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
# vector_in MLPEmbedder keys.
"time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
"time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
"time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
"time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
# time_in MLPEmbedder keys.
"time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
"time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
"time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
"time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
# img_in keys.
"x_embedder.bias": "img_in.bias",
"x_embedder.weight": "img_in.weight",
}
for old_key, new_key in basic_key_map.items():
v = sd.pop(old_key, None)
if v is not None:
new_sd[new_key] = v
# Handle the double_blocks.
block_index = 0
while True:
converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_double_block_sd) == 0:
break
new_sd.update(converted_double_block_sd)
block_index += 1
# Handle the single_blocks.
block_index = 0
while True:
converted_single_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_single_block_sd) == 0:
break
new_sd.update(converted_single_block_sd)
block_index += 1
# Transfer controlnet keys as-is.
for k in list(sd.keys()):
if k.startswith("controlnet_"):
new_sd[k] = sd.pop(k)
# Assert that all keys have been handled.
assert len(sd) == 0
return new_sd
def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams:
"""Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this
information is all contained in the config.json file that accompanies the model. However, being able to infer the
params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file.
"""
hidden_size = sd["img_in.weight"].shape[0]
mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
# mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible float precision
# issues. In practice, mlp_ratio is usually 4.
mlp_ratio = mlp_hidden_dim // hidden_size
head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0]
num_heads = hidden_size // head_dim
# Count the number of double blocks.
double_block_index = 0
while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd:
double_block_index += 1
# Count the number of single blocks.
single_block_index = 0
while f"single_blocks.{single_block_index}.linear1.weight" in sd:
single_block_index += 1
return FluxParams(
in_channels=sd["img_in.weight"].shape[1],
vec_in_dim=sd["vector_in.in_layer.weight"].shape[1],
context_in_dim=sd["txt_in.weight"].shape[1],
hidden_size=hidden_size,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
depth=double_block_index,
depth_single_blocks=single_block_index,
# axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
axes_dim=[16, 56, 56],
# theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
theta=10_000,
qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd,
guidance_embed="guidance_in.in_layer.weight" in sd,
)
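For intuition, the shapes below are what one would expect for FLUX.1-dev-sized weights (illustrative values, not read from a real checkpoint):

# img_in.weight: (3072, 64)                               -> hidden_size = 3072, in_channels = 64
# double_blocks.0.img_mlp.0.weight: (12288, 3072)         -> mlp_ratio = 12288 // 3072 = 4
# double_blocks.0.img_attn.norm.query_norm.scale: (128,)  -> head_dim = 128, num_heads = 3072 // 128 = 24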
def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None:
"""Infer the number of ControlNet Union modes from the shape of a InstantX ControlNet state dict.
Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes.
"""
mode_embedder_key = "controlnet_mode_embedder.weight"
if mode_embedder_key not in sd:
return None
return sd[mode_embedder_key].shape[0]
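A small illustrative check of this logic (the 7-mode Union layout and 3072 hidden size are assumed values):

sd_union = {"controlnet_mode_embedder.weight": torch.zeros(7, 3072)}
assert infer_instantx_num_control_modes_from_state_dict(sd_union) == 7
assert infer_instantx_num_control_modes_from_state_dict({}) is None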

View File

@@ -0,0 +1,130 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.
The architecture is very similar to the base FLUX model, with the following differences:
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
"""
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
)
self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = torch.nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(controlnet_depth)
]
)
# Add ControlNet blocks.
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = torch.nn.Sequential(
torch.nn.Conv2d(3, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
controlnet_cond: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> XLabsControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# Project the packed image token sequence into the hidden size.
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples.append(img)
controlnet_block_res_samples: list[torch.Tensor] = []
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples.append(block_res_sample)
return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)
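A minimal instantiation sketch, mirroring how the checkpoint loader builds this model (it assumes the flux-dev parameter preset is appropriate, as the loader itself does). With the default controlnet_depth of 2, only two DoubleStreamBlocks are created, versus the 19 in the base transformer:

from invokeai.backend.flux.util import params

controlnet = XLabsControlNetFlux(params["flux-dev"], controlnet_depth=2)
# The loader wraps this in accelerate.init_empty_weights() and then loads the state dict with assign=True.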

View File

@@ -0,0 +1,12 @@
from typing import TypeVar
import torch
T = TypeVar("T", bound=torch.nn.Module)
def zero_module(module: T) -> T:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
torch.nn.init.zeros_(p)
return module
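A small check of why the zero initialization matters here: a zero-initialized Linear maps every input to zeros, so the ControlNet blocks contribute no residual until they have been trained (sketch, reusing torch from above):

lin = zero_module(torch.nn.Linear(8, 8))
assert torch.all(lin(torch.randn(2, 8)) == 0)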

View File

@@ -3,7 +3,10 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -21,6 +24,7 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -38,6 +42,30 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind that the
# controlnet_residuals data structure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred = model(
img=img,
img_ids=img_ids,
@@ -46,6 +74,8 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -0,0 +1,45 @@
import math
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
class BaseControlNetExtension(ABC):
def __init__(
self,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return 0.0
if isinstance(self._weight, list):
return self._weight[timestep_index]
return self._weight
@abstractmethod
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput: ...
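A worked example of the step gating in _get_weight, with assumed values: begin_step_percent=0.0, end_step_percent=0.5, and 30 total timesteps. When a per-step weight list is supplied instead of a float, it is simply indexed by timestep_index inside that window:

# Illustrative gating arithmetic (values assumed):
first_step = math.floor(0.0 * 30)  # 0
last_step = math.ceil(0.5 * 30)    # 15
# Steps 0..15 use the configured weight; steps 16..29 return 0.0.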

View File

@@ -0,0 +1,194 @@
import math
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel
class InstantXControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: InstantXControlNetFlux,
controlnet_cond: torch.Tensor,
instantx_control_mode: torch.Tensor | None,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
self._controlnet_cond = controlnet_cond
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
# The control mode for InstantX ControlNet union models.
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
# Expected shape: (batch_size, 1), Expected dtype: torch.long
# If None, a zero-embedding will be used.
self._instantx_control_mode = instantx_control_mode
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
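The control mode input is just an integer index wrapped in a (batch_size, 1) long tensor. For example, selecting the depth mode (index 2 in the Union-Pro mode list) for a single image might look like this sketch:

instantx_control_mode = torch.tensor([[2]], dtype=torch.long)  # hypothetical: depth mode for batch size 1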
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: InstantXControlNetFlux,
controlnet_image: Image,
instantx_control_mode: torch.Tensor | None,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return cls(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _instantx_output_to_controlnet_output(
self, instantx_output: InstantXControlNetFluxOutput
) -> ControlNetFluxOutput:
# The `interval_control` logic here is based on
# https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
double_block_samples = instantx_output.controlnet_block_samples
if double_block_samples:
interval_control = self._flux_transformer_num_double_blocks / len(double_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(double_block_samples[i // interval_control])
# Handle single block residuals.
single_block_residuals: list[torch.Tensor] = []
single_block_samples = instantx_output.controlnet_single_block_samples
if single_block_samples:
interval_control = self._flux_transformer_num_single_blocks / len(single_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_single_blocks):
single_block_residuals.append(single_block_samples[i // interval_control])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals or None,
single_block_residuals=single_block_residuals or None,
)
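A worked example of the interval mapping above, with assumed counts: 19 base double blocks and 2 ControlNet double-block samples give interval_control = ceil(19 / 2) = 10, so base blocks 0-9 reuse sample 0 and blocks 10-18 reuse sample 1. The same stretching applies to the 38 single blocks:

# Illustrative mapping (counts assumed):
interval_control = int(math.ceil(19 / 2))             # 10
mapping = [i // interval_control for i in range(19)]  # [0]*10 + [1]*9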
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
# Make sure inputs have correct device and dtype.
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
self._instantx_control_mode = (
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
)
instantx_output: InstantXControlNetFluxOutput = self._model(
controlnet_cond=self._controlnet_cond,
controlnet_mode=self._instantx_control_mode,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
controlnet_output.apply_weight(weight)
return controlnet_output

View File

@@ -0,0 +1,150 @@
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
class XLabsControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: XLabsControlNetFlux,
controlnet_cond: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
# TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: XLabsControlNetFlux,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return cls(
model=model,
controlnet_cond=controlnet_cond,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
# The modulo index logic used here is based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
if xlabs_double_block_residuals is not None:
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(xlabs_double_block_residuals[i % len(xlabs_double_block_residuals)])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals,
single_block_residuals=None,
)
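A worked example of the modulo mapping above, with assumed counts: 19 base double blocks and 2 XLabs residuals means residual index i % 2 alternates 0, 1, 0, 1, ... across the 19 blocks, rather than being stretched into contiguous runs as in the InstantX case:

# Illustrative mapping (counts assumed):
mapping = [i % 2 for i in range(19)]  # [0, 1, 0, 1, ...]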
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
xlabs_output: XLabsControlNetFluxOutput = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
controlnet_output.apply_weight(weight)
return controlnet_output

View File

@@ -87,7 +87,9 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
guidance: Tensor | None,
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -105,12 +107,27 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
# Validate double_block_residuals shape.
if controlnet_double_block_residuals is not None:
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if controlnet_double_block_residuals is not None:
img += controlnet_double_block_residuals[block_index]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
# Validate single_block_residuals shape.
if controlnet_single_block_residuals is not None:
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

View File

@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
BaseModelType,
ControlNetCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

View File

@@ -10,6 +10,15 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
infer_instantx_num_control_modes_from_state_dict,
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
@@ -24,6 +33,8 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
@@ -293,3 +304,51 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class FluxControlnetModel(ModelLoader):
"""Class to load FLUX ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
model_path = Path(config.path)
elif isinstance(config, ControlNetDiffusersConfig):
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
else:
raise ValueError(f"Unexpected ControlNet model config type: {type(config)}")
sd = load_file(model_path)
# Detect the FLUX ControlNet model type from the state dict.
if is_state_dict_xlabs_controlnet(sd):
return self._load_xlabs_controlnet(sd)
elif is_state_dict_instantx_controlnet(sd):
return self._load_instantx_controlnet(sd)
else:
raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(params["flux-dev"])
model.load_state_dict(sd, assign=True)
return model
def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with accelerate.init_empty_weights():
model = InstantXControlNetFlux(flux_params, num_control_modes)
model.load_state_dict(sd, assign=True)
return model

View File

@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -116,6 +120,7 @@ class ModelProbe(object):
"CLIPModel": ModelType.CLIPEmbed,
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
}
@classmethod
@@ -255,7 +260,19 @@ class ModelProbe(object):
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
elif key.startswith(
(
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
)
):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter
@@ -438,6 +455,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
@@ -449,7 +467,8 @@ MODEL_NAME_TO_PREPROCESSOR = {
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
model_name_lower = model_name.lower()
if k in model_name_lower:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
@@ -623,6 +642,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
@@ -844,22 +868,19 @@ class ControlNetFolderProbe(FolderProbeBase):
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
if dimension == 768:
return BaseModelType.StableDiffusion1
if dimension == 1024:
return BaseModelType.StableDiffusion2
if dimension == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
class LoRAFolderProbe(FolderProbeBase):

View File

@@ -422,6 +422,13 @@ STARTER_MODELS: list[StarterModel] = [
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="FLUX.1-dev-Controlnet-Union-Pro",
base=BaseModelType.Flux,
source="Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
description="A unified ControlNet for FLUX.1-dev model that supports 7 control modes, including canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

View File

@@ -198,20 +198,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
else:
raise Exception("torch-sdp attention slicing not available")
# torch-sdp is the default in diffusers.
return
# the remainder of this code is called when attention_type=='auto'
# See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
# Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
# For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
# non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
# utilization.
if torch.backends.mps.is_available():
return
# The remainder of this code is called when attention_type=='auto'.
if self.unet.device.type == "cuda":
if is_xformers_available() and prefer_xformers:
self.enable_xformers_memory_efficient_attention()
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
# torch-sdp is the default in diffusers.
return
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free

View File

@@ -936,7 +936,8 @@
},
"paramScheduler": {
"paragraphs": [
"\"Planer\" definiert, wie iterativ Rauschen zu einem Bild hinzugefügt wird, oder wie ein Sample bei der Ausgabe eines Modells aktualisiert wird."
"Verwendeter Planer währende des Generierungsprozesses.",
"Jeder Planer definiert, wie einem Bild iterativ Rauschen hinzugefügt wird, oder wie ein Sample basierend auf der Ausgabe eines Modells aktualisiert wird."
],
"heading": "Planer"
},
@@ -962,6 +963,61 @@
},
"ipAdapterMethod": {
"heading": "Methode"
},
"refinerScheduler": {
"heading": "Planer",
"paragraphs": [
"Planer, der während der Veredelungsphase des Generierungsprozesses verwendet wird.",
"Ähnlich wie der Generierungsplaner."
]
},
"compositingCoherenceMode": {
"paragraphs": [
"Verwendete Methode zur Erstellung eines kohärenten Bildes mit dem neu generierten maskierten Bereich."
],
"heading": "Modus"
},
"compositingCoherencePass": {
"heading": "Kohärenzdurchlauf"
},
"controlNet": {
"heading": "ControlNet"
},
"compositingMaskAdjustments": {
"paragraphs": [
"Die Maske anpassen."
],
"heading": "Maskenanpassungen"
},
"compositingMaskBlur": {
"paragraphs": [
"Der Unschärferadius der Maske."
],
"heading": "Maskenunschärfe"
},
"compositingBlurMethod": {
"paragraphs": [
"Die auf den maskierten Bereich angewendete Unschärfemethode."
],
"heading": "Unschärfemethode"
},
"controlNetResizeMode": {
"heading": "Größenänderungsmodus"
},
"paramWidth": {
"heading": "Breite",
"paragraphs": [
"Breite des generierten Bildes. Muss ein Vielfaches von 8 sein."
]
},
"controlNetControlMode": {
"heading": "Kontrollmodus"
},
"controlNetProcessor": {
"heading": "Prozessor"
},
"patchmatchDownScaleSize": {
"heading": "Herunterskalieren"
}
},
"invocationCache": {
@@ -1080,7 +1136,8 @@
"workflowContact": "Kontaktdaten",
"workflowNotes": "Notizen",
"workflowTags": "Tags",
"workflowVersion": "Version"
"workflowVersion": "Version",
"saveToGallery": "In Galerie speichern"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",
@@ -1250,7 +1307,16 @@
"searchByName": "Nach Name suchen",
"promptTemplateCleared": "Promptvorlage gelöscht",
"preview": "Vorschau",
"positivePrompt": "Positiv-Prompt"
"positivePrompt": "Positiv-Prompt",
"active": "Aktiv",
"deleteTemplate2": "Sind Sie sicher, dass Sie diese Vorlage löschen möchten? Dies kann nicht rückgängig gemacht werden.",
"deleteTemplate": "Vorlage löschen",
"copyTemplate": "Vorlage kopieren",
"editTemplate": "Vorlage bearbeiten",
"deleteImage": "Bild löschen",
"defaultTemplates": "Standardvorlagen",
"nameColumn": "'name'",
"exportDownloaded": "Export heruntergeladen"
},
"newUserExperience": {
"gettingStartedSeries": "Wünschen Sie weitere Anleitungen? In unserer <LinkComponent>Einführungsserie</LinkComponent> finden Sie Tipps, wie Sie das Potenzial von Invoke Studio voll ausschöpfen können.",
@@ -1263,13 +1329,22 @@
"bbox": "Bbox"
},
"transform": {
"fitToBbox": "An Bbox anpassen"
"fitToBbox": "An Bbox anpassen",
"reset": "Zurücksetzen",
"apply": "Anwenden",
"cancel": "Abbrechen"
},
"pullBboxIntoLayerError": "Problem, Bbox in die Ebene zu ziehen",
"pullBboxIntoLayer": "Bbox in Ebene ziehen",
"HUD": {
"bbox": "Bbox",
"scaledBbox": "Skalierte Bbox"
"scaledBbox": "Skalierte Bbox",
"entityStatus": {
"isHidden": "{{title}} ist ausgeblendet",
"isDisabled": "{{title}} ist deaktiviert",
"isLocked": "{{title}} ist gesperrt",
"isEmpty": "{{title}} ist leer"
}
},
"fitBboxToLayers": "Bbox an Ebenen anpassen",
"pullBboxIntoReferenceImage": "Bbox ins Referenzbild ziehen",
@@ -1279,7 +1354,12 @@
"clipToBbox": "Pinselstriche auf Bbox beschränken",
"canvasContextMenu": {
"saveBboxToGallery": "Bbox in Galerie speichern",
"bboxGroup": "Aus Bbox erstellen"
"bboxGroup": "Aus Bbox erstellen",
"canvasGroup": "Leinwand",
"newGlobalReferenceImage": "Neues globales Referenzbild",
"newRegionalReferenceImage": "Neues regionales Referenzbild",
"newControlLayer": "Neue Kontroll-Ebene",
"newRasterLayer": "Neue Raster-Ebene"
},
"rectangle": "Rechteck",
"saveCanvasToGallery": "Leinwand in Galerie speichern",
@@ -1310,7 +1390,7 @@
"regional": "Regional",
"newGlobalReferenceImageOk": "Globales Referenzbild erstellt",
"savedToGalleryError": "Fehler beim Speichern in der Galerie",
"savedToGalleryOk": "In Galerie speichern",
"savedToGalleryOk": "In Galerie gespeichert",
"newGlobalReferenceImageError": "Problem beim Erstellen eines globalen Referenzbilds",
"newRegionalReferenceImageOk": "Regionales Referenzbild erstellt",
"duplicate": "Duplizieren",
@@ -1343,12 +1423,39 @@
"showProgressOnCanvas": "Fortschritt auf Leinwand anzeigen",
"controlMode": {
"balanced": "Ausgewogen"
}
},
"globalReferenceImages_withCount_hidden": "Globale Referenzbilder ({{count}} ausgeblendet)",
"sendToGallery": "An Galerie senden",
"stagingArea": {
"accept": "Annehmen",
"next": "Nächste",
"discardAll": "Alle verwerfen",
"discard": "Verwerfen",
"previous": "Vorherige"
},
"regionalGuidance_withCount_visible": "Regionale Führung ({{count}})",
"regionalGuidance_withCount_hidden": "Regionale Führung ({{count}} ausgeblendet)",
"settings": {
"snapToGrid": {
"on": "Ein",
"off": "Aus",
"label": "Am Raster ausrichten"
}
},
"layer_one": "Ebene",
"layer_other": "Ebenen",
"layer_withCount_one": "Ebene ({{count}})",
"layer_withCount_other": "Ebenen ({{count}})"
},
"upsell": {
"shareAccess": "Zugang teilen",
"professional": "Professionell",
"inviteTeammates": "Teamkollegen einladen",
"professionalUpsell": "Verfügbar in der Professional Edition von Invoke. Klicken Sie hier oder besuchen Sie invoke.com/pricing für weitere Details."
},
"upscaling": {
"creativity": "Kreativität",
"structure": "Struktur",
"scale": "Maßstab"
}
}

View File

@@ -285,6 +285,7 @@
"assetsTab": "Files youve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
"boardsSettings": "Boards Settings",
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
@@ -304,6 +305,7 @@
"go": "Go",
"image": "image",
"imagesTab": "Images youve created and saved within Invoke.",
"imagesSettings": "Gallery Images Settings",
"jump": "Jump",
"loading": "Loading",
"newestFirst": "Newest First",
@@ -1641,6 +1643,7 @@
"sendToCanvas": "Send To Canvas",
"newLayerFromImage": "New Layer from Image",
"newCanvasFromImage": "New Canvas from Image",
"newImg2ImgCanvasFromImage": "New Img2Img from Image",
"copyToClipboard": "Copy to Clipboard",
"sendToCanvasDesc": "Pressing Invoke stages your work in progress on the canvas.",
"viewProgressInViewer": "View progress and outputs in the <Btn>Image Viewer</Btn>.",

View File

@@ -1730,7 +1730,8 @@
"mlsd_detection": {
"score_threshold": "Soglia di punteggio",
"distance_threshold": "Soglia di distanza",
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD."
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD.",
"label": "Rilevamento segmenti di linea"
},
"content_shuffle": {
"label": "Mescola contenuto",

View File

@@ -158,7 +158,9 @@
"move": "Двигать",
"gallery": "Галерея",
"openViewer": "Открыть просмотрщик",
"closeViewer": "Закрыть просмотрщик"
"closeViewer": "Закрыть просмотрщик",
"imagesTab": "Изображения, созданные и сохраненные в Invoke.",
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах."
},
"hotkeys": {
"searchHotkeys": "Поиск горячих клавиш",
@@ -928,7 +930,10 @@
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию",
"saveToGallery": "Сохранить в галерею"
"saveToGallery": "Сохранить в галерею",
"noWorkflows": "Нет рабочих процессов",
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@@ -1553,7 +1558,10 @@
"autoLayout": "Автоматическое расположение",
"userWorkflows": "Пользовательские рабочие процессы",
"projectWorkflows": "Рабочие процессы проекта",
"defaultWorkflows": "Стандартные рабочие процессы"
"defaultWorkflows": "Стандартные рабочие процессы",
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
"uploadAndSaveWorkflow": "Загрузить в библиотеку"
},
"hrf": {
"enableHrf": "Включить исправление высокого разрешения",
@@ -1872,8 +1880,8 @@
"duplicate": "Дублировать",
"inpaintMasks_withCount_visible": "Маски перерисовки ({{count}})",
"layer_one": "Слой",
"layer_few": "",
"layer_many": "",
"layer_few": "Слоя",
"layer_many": "Слоев",
"prompt": "Запрос",
"negativePrompt": "Исключающий запрос",
"beginEndStepPercentShort": "Начало/конец %",
@@ -2035,7 +2043,7 @@
"whatsNewInInvoke": "Что нового в Invoke"
},
"newUserExperience": {
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Выберите шаблон запроса, чтобы улучшить результаты. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
"gettingStartedSeries": "Хотите получить больше рекомендаций? Ознакомьтесь с нашей серией <LinkComponent>Getting Started Series</LinkComponent> для получения советов по раскрытию всего потенциала Invoke Studio."
}
}

View File

@@ -20,6 +20,7 @@ import {
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
@@ -120,6 +121,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<GlobalImageHotkeys />
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
</ErrorBoundary>
);
};

View File

@@ -4,9 +4,9 @@ import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
@@ -17,7 +17,14 @@ const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const baseStyles: SystemStyleObject = {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
};
const sx: SystemStyleObject = {
...baseStyles,
'.gallery-image-container::before': {
content: '""',
display: 'inline-block',
@@ -102,59 +109,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
useThumbailFallback,
withHoverOverlay = false,
children,
onMouseOver,
onMouseOut,
dataTestId,
...rest
} = props;
const handleMouseOver = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOver) {
onMouseOver(e);
}
},
[onMouseOver]
);
const handleMouseOut = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOut) {
onMouseOut(e);
}
},
[onMouseOut]
);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
@@ -168,76 +126,126 @@ const IAIDndImage = (props: IAIDndImageProps) => {
[imageDTO]
);
const ref = useRef<HTMLDivElement>(null);
useImageContextMenu(imageDTO, ref);
return (
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
ref={ref}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
width="full"
height="full"
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : undefined}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
</ImageContextMenu>
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
);
};
export default memo(IAIDndImage);
const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);
UploadButton.displayName = 'UploadButton';

View File

@@ -9,7 +9,6 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -29,8 +28,7 @@ const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const isModalOpen = useAppSelector(selectIsModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();

View File

@@ -80,7 +80,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>

View File

@@ -56,7 +56,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -99,7 +99,6 @@ const PanelTabs = memo(() => {
<Box as="span" w="full">
{layersTabLabel}
</Box>
{dndCtx.active && <Box position="absolute" top={0} left={0} right={0} bottom={0} border="2px solid red" />}
</Tab>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut}>
{t('gallery.gallery')}

View File

@@ -16,6 +16,7 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
@@ -42,6 +43,7 @@ export const ControlLayerControlAdapter = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -117,7 +119,7 @@ export const ControlLayerControlAdapter = memo(() => {
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && (
{controlAdapter.type === 'controlnet' && !isFLUX && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}
onChange={onChangeControlMode}

View File

@@ -18,7 +18,7 @@ export const ControlLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -9,7 +9,7 @@ export const IPAdapterMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
);
});

View File

@@ -13,7 +13,7 @@ export const InpaintMaskMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -17,7 +17,7 @@ export const RasterLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -14,7 +14,7 @@ export const RegionalGuidanceMenuItems = memo(() => {
<Flex gap={2}>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</Flex>
<MenuDivider />
<RegionalGuidanceMenuItemsAddPromptsAndIPAdapter />

View File

@@ -1,3 +1,4 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { IconMenuItem } from 'common/components/IconMenuItem';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -7,7 +8,11 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsDelete = memo(() => {
type Props = {
asIcon?: boolean;
};
export const CanvasEntityMenuItemsDelete = memo(({ asIcon = false }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
@@ -17,15 +22,23 @@ export const CanvasEntityMenuItemsDelete = memo(() => {
dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
if (asIcon) {
return (
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
);
}
return (
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
<MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} isDestructive isDisabled={!isInteractable}>
{t('common.delete')}
</MenuItem>
);
});

View File

@@ -2,8 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import {
bboxChangedFromCanvas,
controlLayerAdded,
inpaintMaskAdded,
rasterLayerAdded,
@@ -14,19 +17,32 @@ import {
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import {
selectBboxModelBase,
selectBboxRect,
selectCanvasSlice,
selectEntityOrThrow,
} from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
imageDTOToImageObject,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
export const selectDefaultControlAdapter = createSelector(
@@ -90,6 +106,74 @@ export const useAddRasterLayer = () => {
return func;
};
export const useNewRasterLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
/**
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
* - Reset the canvas
* - Resize the bbox to the image's aspect ratio at the optimal size for the selected model
* - Add the image as a raster layer
* - Resize the layer to fit the bbox using the 'fill' strategy
*
* This allows the user to immediately generate a new image from the given image without any additional steps.
*/
export const useNewCanvasFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const base = useAppSelector(selectBboxModelBase);
const func = useCallback(
(imageDTO: ImageDTO) => {
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
// The overrides need to include the layer's ID so we can transform the layer once it is initialized
const overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;
CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
// Skip the callback if the adapter is not the one we are creating
if (adapter.id !== overrides.id) {
return false;
}
// Fit the layer to the bbox w/ fill strategy
await adapter.transformer.startTransform({ silent: true });
adapter.transformer.fitToBboxFill();
await adapter.transformer.applyTransform();
return true;
});
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[base, bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
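For reference, a minimal usage sketch of this hook from a menu item. The component name is illustrative and the toast/tab-switch wiring is omitted; the real call sites (ImageMenuItemNewCanvasFromImage, ImageMenuItemNewLayerFromImage) are updated later in this diff.
```ts
// Hypothetical menu item using the hook above; useNewCanvasFromImage, useCanvasIsBusy and
// useImageDTOContext are real imports from this diff - the component itself is illustrative.
import { MenuItem } from '@invoke-ai/ui-library';
import { useNewCanvasFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { memo, useCallback } from 'react';

export const ExampleNewCanvasFromImageMenuItem = memo(() => {
  const imageDTO = useImageDTOContext();
  const newCanvasFromImage = useNewCanvasFromImage();
  // Disable the action while the canvas is busy (e.g. mid-transform or mid-filter).
  const isBusy = useCanvasIsBusy();

  const onClick = useCallback(() => {
    // Resets the canvas, fits the bbox to the image, and adds it as a raster layer (see the hook above).
    newCanvasFromImage(imageDTO);
  }, [imageDTO, newCanvasFromImage]);

  return (
    <MenuItem onClickCapture={onClick} isDisabled={isBusy}>
      New Canvas From Image
    </MenuItem>
  );
});
ExampleNewCanvasFromImageMenuItem.displayName = 'ExampleNewCanvasFromImageMenuItem';
```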
export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const func = useCallback(() => {

View File

@@ -7,6 +7,7 @@ import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/ko
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getKonvaNodeDebugAttrs, getRectIntersection } from 'features/controlLayers/konva/util';
@@ -15,7 +16,8 @@ import {
selectIsolatedTransformingPreview,
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildEntityIsHiddenSelector,
buildSelectIsHidden,
buildSelectIsSelected,
selectBboxRect,
selectCanvasSlice,
selectEntity,
@@ -29,6 +31,11 @@ import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
// type assertion below in the `onInit` method, which calls these callbacks.
type InitCallback = (adapter: CanvasEntityAdapter) => Promise<boolean>;
export abstract class CanvasEntityAdapterBase<
T extends CanvasRenderableEntityState,
U extends string,
@@ -87,7 +94,79 @@ export abstract class CanvasEntityAdapterBase<
*/
abstract getHashableState: () => SerializableObject;
/**
* Callbacks that are executed when the module is initialized.
*/
private static initCallbacks = new Set<InitCallback>();
/**
* Register a callback to be run when an entity adapter is initialized.
*
* The callback is called for every adapter that is initialized with the adapter as its only argument. Use an early
* return to skip entities that are not of interest, returning `false` to keep the callback registered. Return `true`
* to unregister the callback after it is called.
*
* @param callback The callback to register.
*
* @example
* ```ts
* // A callback that is executed once for a specific entity:
* const myId = 'my_id';
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* if (adapter.id !== myId) {
* // These are not the droids you are looking for, move along
* return false;
* }
*
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed once for the next entity that is initialized:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed for every entity and is never removed:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* // Do something with the adapter
* return false;
* });
* ```
*/
static registerInitCallback = (callback: InitCallback) => {
const wrapped = async (adapter: CanvasEntityAdapter) => {
const result = await callback(adapter);
if (result) {
this.initCallbacks.delete(wrapped);
}
return result;
};
this.initCallbacks.add(wrapped);
};
/**
* Runs all init callbacks with the given entity adapter.
* @param adapter The adapter of the entity that was initialized.
*/
private static runInitCallbacks = (adapter: CanvasEntityAdapter) => {
for (const callback of this.initCallbacks) {
callback(adapter);
}
};
selectIsHidden: Selector<RootState, boolean>;
selectIsSelected: Selector<RootState, boolean>;
/**
* The Konva nodes that make up the entity adapter:
@@ -171,7 +250,8 @@ export abstract class CanvasEntityAdapterBase<
assert(state !== undefined, 'Missing entity state on creation');
this.state = state;
this.selectIsHidden = buildEntityIsHiddenSelector(this.entityIdentifier);
this.selectIsHidden = buildSelectIsHidden(this.entityIdentifier);
this.selectIsSelected = buildSelectIsSelected(this.entityIdentifier);
/**
* There are a number of reasons we may need to show or hide a layer:
@@ -180,6 +260,7 @@ export abstract class CanvasEntityAdapterBase<
* - Staging status changes and `isolatedStagingPreview` is enabled
* - Global filtering status changes and `isolatedFilteringPreview` is enabled
* - Global transforming status changes and `isolatedTransformingPreview` is enabled
* - The entity is selected or deselected (only selected and onscreen entities are rendered)
*/
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsHidden, this.syncVisibility));
this.subscriptions.add(
@@ -190,6 +271,7 @@ export abstract class CanvasEntityAdapterBase<
this.manager.stateApi.createStoreSubscription(selectIsolatedTransformingPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsSelected, this.syncVisibility));
/**
* The tool preview may need to be updated when the entity is locked or disabled. For example, when we disable the
@@ -228,21 +310,8 @@ export abstract class CanvasEntityAdapterBase<
syncIsOnscreen = () => {
const stageRect = this.manager.stage.getScaledStageRect();
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
y: entityRect.y + position.y,
width: entityRect.width,
height: entityRect.height,
};
const intersection = getRectIntersection(stageRect, entityRectRelativeToStage);
const isOnScreen = this.checkIntersection(stageRect);
const prevIsOnScreen = this.$isOnScreen.get();
const isOnScreen = intersection.width > 0 && intersection.height > 0;
this.$isOnScreen.set(isOnScreen);
if (prevIsOnScreen !== isOnScreen) {
this.log.trace(`Moved ${isOnScreen ? 'on-screen' : 'off-screen'}`);
@@ -252,10 +321,19 @@ export abstract class CanvasEntityAdapterBase<
syncIntersectsBbox = () => {
const bboxRect = this.manager.stateApi.getBbox().rect;
const intersectsBbox = this.checkIntersection(bboxRect);
const prevIntersectsBbox = this.$intersectsBbox.get();
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
};
checkIntersection = (rect: Rect): boolean => {
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return;
return false;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
@@ -263,14 +341,9 @@ export abstract class CanvasEntityAdapterBase<
width: entityRect.width,
height: entityRect.height,
};
const intersection = getRectIntersection(bboxRect, entityRectRelativeToStage);
const prevIntersectsBbox = this.$intersectsBbox.get();
const intersectsBbox = intersection.width > 0 && intersection.height > 0;
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
const intersection = getRectIntersection(rect, entityRectRelativeToStage);
const doesIntersect = intersection.width > 0 && intersection.height > 0;
return doesIntersect;
};
initialize = async () => {
@@ -299,6 +372,10 @@ export abstract class CanvasEntityAdapterBase<
await this.renderer.initialize();
this.syncZIndices();
this.syncVisibility();
// Call the init callbacks.
// TODO(psyche): Get rid of the cast - see note in type def for `InitCallback`.
CanvasEntityAdapterBase.runInitCallbacks(this as CanvasEntityAdapter);
};
syncZIndices = () => {

View File

@@ -1,3 +1,4 @@
import { Mutex } from 'async-mutex';
import { withResultAsync } from 'common/util/result';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
@@ -166,6 +167,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
*/
$silentTransform = atom(false);
/**
* A mutex to prevent concurrent operations.
*
* The mutex is locked during transformation and during rect calculations which are handled in a web worker.
*/
transformMutex = new Mutex();
konva: {
transformer: Konva.Transformer;
proxyRect: Konva.Rect;
@@ -424,6 +432,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -437,8 +446,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x: clamp(Math.round(rect.x + offsetX), rect.x, rect.x + rect.width),
y: clamp(Math.round(rect.y + offsetY), rect.y, rect.y + rect.height),
x: clamp(roundToMultiple(rect.x + offsetX, gridSize), rect.x, rect.x + rect.width),
y: clamp(roundToMultiple(rect.y + offsetY, gridSize), rect.y, rect.y + rect.height),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -455,6 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -468,8 +478,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x: Math.round(rect.x + offsetX),
y: Math.round(rect.y + offsetY),
x: roundToMultiple(rect.x + offsetX, gridSize),
y: roundToMultiple(rect.y + offsetY, gridSize),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -647,11 +657,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
* @param arg.silent Whether the transformation should be silent. If silent, the transform controls will not be shown,
* so you _must_ immediately call `applyTransform` or `stopTransform` to complete the transformation.
*/
startTransform = (arg?: { silent: boolean }) => {
startTransform = async (arg?: { silent: boolean }) => {
const transformingAdapter = this.manager.stateApi.$transformingAdapter.get();
if (transformingAdapter) {
assert(false, `Already transforming an entity: ${transformingAdapter.id}`);
}
// This will be released when the transformation is stopped
await this.transformMutex.acquire();
this.log.debug('Starting transform');
const { silent } = { silent: false, ...arg };
this.$silentTransform.set(silent);
@@ -704,6 +716,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.syncInteractionState();
this.manager.stateApi.$transformingAdapter.set(null);
this.$isProcessing.set(false);
this.transformMutex.release();
};
/**
@@ -807,7 +820,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
calculateRect = debounce(() => {
this.log.debug('Calculating bbox');
this.$isPendingRectCalculation.set(true);
const canvas = this.parent.getCanvas();
if (!this.parent.renderer.hasObjects()) {
@@ -817,6 +829,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}
@@ -829,6 +842,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}
@@ -857,11 +871,14 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
}
);
}, this.config.RECT_CALC_DEBOUNCE_MS);
requestRectCalculation = () => {
requestRectCalculation = async () => {
// This will be released when the rect calculation is complete
await this.transformMutex.acquire();
this.$isPendingRectCalculation.set(true);
this.syncInteractionState();
this.calculateRect();
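The locking discipline added in this file reduces to: acquire the mutex before starting a transform or requesting a rect calculation, and release it when the transform stops or the debounced calculation completes. A standalone sketch of that shape with async-mutex, assuming nothing beyond the acquire/release calls used above (the helper names are illustrative):
```ts
// Standalone sketch of the acquire/release discipline; not the transformer itself.
import { Mutex } from 'async-mutex';

const transformMutex = new Mutex();

// Mirrors startTransform / requestRectCalculation: wait until no other operation holds the lock.
const beginExclusiveOperation = async (): Promise<void> => {
  await transformMutex.acquire();
};

// Mirrors stopTransform / the tail of calculateRect: must run exactly once per acquire.
const endExclusiveOperation = (): void => {
  transformMutex.release();
};

const runExample = async (): Promise<void> => {
  await beginExclusiveOperation();
  try {
    // ... perform the transform or kick off the rect calculation here ...
  } finally {
    endExclusiveOperation();
  }
};

void runExample();
```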

View File

@@ -25,7 +25,6 @@ import {
getScaledBoundingBoxDimensions,
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import type { MainModelBase } from 'features/nodes/types/common';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
@@ -772,11 +771,6 @@ export const canvasSlice = createSlice({
syncScaledSize(state);
},
bboxModelBaseChanged: (state, action: PayloadAction<{ modelBase: MainModelBase }>) => {
const { modelBase } = action.payload;
state.bbox.modelBase = modelBase;
syncScaledSize(state);
},
bboxSyncedToOptimalDimension: (state) => {
const optimalDimension = getOptimalDimension(state.bbox.modelBase);

View File

@@ -308,7 +308,7 @@ const getSelectIsTypeHidden = (type: CanvasEntityType) => {
/**
* Builds a selector that selects if the entity is hidden.
*/
export const buildEntityIsHiddenSelector = (entityIdentifier: CanvasEntityIdentifier) => {
export const buildSelectIsHidden = (entityIdentifier: CanvasEntityIdentifier) => {
const selectIsTypeHidden = getSelectIsTypeHidden(entityIdentifier.type);
return createSelector(
[selectCanvasSlice, selectIsTypeHidden, selectIsStaging, selectIsolatedStagingPreview],
@@ -339,6 +339,16 @@ export const buildEntityIsHiddenSelector = (entityIdentifier: CanvasEntityIdenti
);
};
/**
* Builds a selector that selects if the entity is selected.
*/
export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier) => {
return createSelector(
selectSelectedEntityIdentifier,
(selectedEntityIdentifier) => selectedEntityIdentifier?.id === entityIdentifier.id
);
};
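A hypothetical React-side consumer of this builder; the hook name useIsEntitySelected is illustrative, and inside the canvas modules the selector is instead consumed via createStoreSubscription, as shown in the adapter diff above.
```ts
// Hypothetical hook: subscribe a component to an entity's selected state via the built selector.
import { useAppSelector } from 'app/store/storeHooks';
import { buildSelectIsSelected } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';

export const useIsEntitySelected = (entityIdentifier: CanvasEntityIdentifier): boolean => {
  // Memoize per entity so the createSelector cache is stable across renders.
  const selectIsSelected = useMemo(() => buildSelectIsSelected(entityIdentifier), [entityIdentifier]);
  return useAppSelector(selectIsSelected);
};
```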
export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);

View File

@@ -0,0 +1,86 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectBoardsListOrderBy, selectBoardsListOrderDir } from 'features/gallery/store/gallerySelectors';
import { boardsListOrderByChanged, boardsListOrderDirChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
const zOrderBy = z.enum(['created_at', 'board_name']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;
const zDirection = z.enum(['ASC', 'DESC']);
type Direction = z.infer<typeof zDirection>;
const isDirection = (v: unknown): v is Direction => zDirection.safeParse(v).success;
export const BoardsListSortControls = () => {
const { t } = useTranslation();
const orderBy = useAppSelector(selectBoardsListOrderBy);
const direction = useAppSelector(selectBoardsListOrderDir);
const ORDER_BY_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'created_at', label: t('workflows.created') },
{ value: 'board_name', label: t('workflows.name') },
],
[t]
);
const DIRECTION_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'ASC', label: t('workflows.ascending') },
{ value: 'DESC', label: t('workflows.descending') },
],
[t]
);
const dispatch = useAppDispatch();
const onChangeOrderBy = useCallback<ComboboxOnChange>(
(v) => {
if (!isOrderBy(v?.value) || v.value === orderBy) {
return;
}
dispatch(boardsListOrderByChanged(v.value));
},
[orderBy, dispatch]
);
const valueOrderBy = useMemo(() => {
return ORDER_BY_OPTIONS.find((o) => o.value === orderBy) || ORDER_BY_OPTIONS[0];
}, [orderBy, ORDER_BY_OPTIONS]);
const onChangeDirection = useCallback<ComboboxOnChange>(
(v) => {
if (!isDirection(v?.value) || v.value === direction) {
return;
}
dispatch(boardsListOrderDirChanged(v.value));
},
[direction, dispatch]
);
const valueDirection = useMemo(
() => DIRECTION_OPTIONS.find((o) => o.value === direction),
[direction, DIRECTION_OPTIONS]
);
return (
<Flex flexDir="column" gap={4}>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.orderBy')}</FormLabel>
<Combobox isSearchable={false} value={valueOrderBy} options={ORDER_BY_OPTIONS} onChange={onChangeOrderBy} />
</FormControl>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.direction')}</FormLabel>
<Combobox
isSearchable={false}
value={valueDirection}
options={DIRECTION_OPTIONS}
onChange={onChangeDirection}
/>
</FormControl>
</Flex>
);
};

View File

@@ -0,0 +1,53 @@
import {
Box,
Divider,
Flex,
IconButton,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
} from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
import { BoardsListSortControls } from './BoardsListSortControls';
const BoardsSettingsPopover = () => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.boardsSettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.boardsSettings')}
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<AutoAssignBoardCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Box py={2}>
<Divider />
</Box>
<BoardsListSortControls />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(BoardsSettingsPopover);

View File

@@ -23,6 +23,7 @@ import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold } from 'react-icons/pi';
import { useBoardName } from 'services/api/hooks/useBoardName';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
@@ -85,15 +86,18 @@ export const Gallery = () => {
{t('gallery.assets')}
</Tab>
</Tooltip>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
<Flex h="full" justifyContent="flex-end">
<GallerySettingsPopover />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
</Flex>
</TabList>
</Tabs>

View File

@@ -15,8 +15,8 @@ import { Panel, PanelGroup } from 'react-resizable-panels';
import BoardsListWrapper from './Boards/BoardsList/BoardsListWrapper';
import BoardsSearch from './Boards/BoardsList/BoardsSearch';
import BoardsSettingsPopover from './Boards/BoardsSettingsPopover';
import { Gallery } from './Gallery';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
@@ -64,7 +64,7 @@ const GalleryPanelContent = () => {
</Flex>
<GalleryHeader />
<Flex h="full" w="25%" justifyContent="flex-end">
<GallerySettingsPopover />
<BoardsSettingsPopover />
<IconButton
size="sm"
variant="link"

View File

@@ -1,10 +1,7 @@
import { Divider, Flex, IconButton, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AlwaysShowImageSizeCheckbox from 'features/gallery/components/GallerySettingsPopover/AlwaysShowImageSizeCheckbox';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import AutoSwitchCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoSwitchCheckbox';
import ImageMinimumWidthSlider from 'features/gallery/components/GallerySettingsPopover/ImageMinimumWidthSlider';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import ShowStarredFirstCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowStarredFirstCheckbox';
import SortDirectionCombobox from 'features/gallery/components/GallerySettingsPopover/SortDirectionCombobox';
import { memo } from 'react';
@@ -21,8 +18,9 @@ const GallerySettingsPopover = () => {
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.gallerySettings')}
aria-label={t('gallery.imagesSettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.imagesSettings')}
/>
</PopoverTrigger>
<PopoverContent>
@@ -30,10 +28,7 @@ const GallerySettingsPopover = () => {
<Flex direction="column" gap={2}>
<ImageMinimumWidthSlider />
<AutoSwitchCheckbox />
<AutoAssignBoardCheckbox />
<AlwaysShowImageSizeCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Divider pt={2} />
<ShowStarredFirstCheckbox />
<SortDirectionCombobox />

View File

@@ -1,42 +1,276 @@
import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuList } from '@invoke-ai/ui-library';
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Menu, MenuButton, MenuList, Portal, useGlobalMenuClose } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import MultipleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { selectSelectionCount } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback } from 'react';
import { map } from 'nanostores';
import type { RefObject } from 'react';
import { memo, useCallback, useEffect, useRef } from 'react';
import type { ImageDTO } from 'services/api/types';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
/**
* The delay in milliseconds before the context menu opens on long press.
*/
const LONGPRESS_DELAY_MS = 500;
/**
* The threshold in pixels that the pointer must move before the long press is cancelled.
*/
const LONGPRESS_MOVE_THRESHOLD_PX = 10;
type Props = {
imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children'];
/**
* The singleton state of the context menu.
*/
const $imageContextMenuState = map<{
isOpen: boolean;
imageDTO: ImageDTO | null;
position: { x: number; y: number };
}>({
isOpen: false,
imageDTO: null,
position: { x: -1, y: -1 },
});
/**
* Convenience function to close the context menu.
*/
const onClose = () => {
$imageContextMenuState.setKey('isOpen', false);
};
const ImageContextMenu = ({ imageDTO, children }: Props) => {
const selectionCount = useAppSelector(selectSelectionCount);
/**
* Map of elements to image DTOs. This is used to determine which image DTO to show the context menu for, depending on
* the target of the context menu or long press event.
*/
const elToImageMap = new Map<HTMLDivElement, ImageDTO>();
/**
* Given a target node, find the first registered parent element that contains the target node and return the imageDTO
* associated with it.
*/
const getImageDTOFromMap = (target: Node): ImageDTO | undefined => {
const entry = Array.from(elToImageMap.entries()).find((entry) => entry[0].contains(target));
return entry?.[1];
};
/**
* Register a context menu for an image DTO on a target element.
* @param imageDTO The image DTO to register the context menu for.
* @param targetRef The ref of the target element that should trigger the context menu.
*/
export const useImageContextMenu = (imageDTO: ImageDTO | undefined, targetRef: RefObject<HTMLDivElement>) => {
useEffect(() => {
if (!targetRef.current || !imageDTO) {
return;
}
const el = targetRef.current;
elToImageMap.set(el, imageDTO);
return () => {
elToImageMap.delete(el);
};
}, [imageDTO, targetRef]);
};
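A minimal consumer sketch; this mirrors how IAIDndImage registers itself earlier in this diff. The component name is illustrative and the import path is assumed to be the ImageContextMenu module itself.
```ts
// Hypothetical gallery cell registering itself as a context menu target for its image.
import { Flex, Image } from '@invoke-ai/ui-library';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { memo, useRef } from 'react';
import type { ImageDTO } from 'services/api/types';

export const ExampleImageCell = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
  const ref = useRef<HTMLDivElement>(null);
  // Adds this element to the singleton's element -> ImageDTO map; the entry is removed on unmount.
  useImageContextMenu(imageDTO, ref);

  return (
    <Flex ref={ref} position="relative">
      <Image src={imageDTO.thumbnail_url} draggable={false} />
    </Flex>
  );
});
ExampleImageCell.displayName = 'ExampleImageCell';
```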
/**
* Singleton component that renders the context menu for images.
*/
export const ImageContextMenu = memo(() => {
useAssertSingleton('ImageContextMenu');
const state = useStore($imageContextMenuState);
useGlobalMenuClose(onClose);
return (
<Portal>
<Menu isOpen={state.isOpen} gutter={0} placement="auto-end" onClose={onClose}>
<MenuButton
aria-hidden={true}
w={1}
h={1}
position="absolute"
left={state.position.x}
top={state.position.y}
cursor="default"
bg="transparent"
_hover={_hover}
pointerEvents="none"
/>
<MenuContent />
</Menu>
<ImageContextMenuEventLogical />
</Portal>
);
});
ImageContextMenu.displayName = 'ImageContextMenu';
const _hover: ChakraProps['_hover'] = { bg: 'transparent' };
/**
* A logical component that listens for context menu events and opens the context menu. It's separate from the
* ImageContextMenu component to avoid re-rendering the whole context menu on every context menu event.
*/
const ImageContextMenuEventLogical = memo(() => {
const lastPositionRef = useRef<{ x: number; y: number }>({ x: -1, y: -1 });
const longPressTimeoutRef = useRef(0);
const animationTimeoutRef = useRef(0);
const onContextMenu = useCallback((e: MouseEvent | PointerEvent) => {
if (e.shiftKey) {
// This is a shift + right click event, which should open the native context menu
onClose();
return;
}
const imageDTO = getImageDTOFromMap(e.target as Node);
const renderMenuFunc = useCallback(() => {
if (!imageDTO) {
return null;
// Can't find the image DTO, close the context menu
onClose();
return;
}
if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<MultipleSelectionMenuItems />
</MenuList>
);
// clear pending delayed open
window.clearTimeout(animationTimeoutRef.current);
e.preventDefault();
if (lastPositionRef.current.x !== e.pageX || lastPositionRef.current.y !== e.pageY) {
// If the mouse moved, close the menu, wait for the close animation, then reopen it at the new position
if ($imageContextMenuState.get().isOpen) {
onClose();
}
animationTimeoutRef.current = window.setTimeout(() => {
// Open the menu after the animation with the new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}, 100);
} else {
// else we can just open the menu at the current position w/ new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}
// Always sync the last position
lastPositionRef.current = { x: e.pageX, y: e.pageY };
}, []);
// Use a long press to open the context menu on touch devices
const onPointerDown = useCallback(
(e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
longPressTimeoutRef.current = window.setTimeout(() => {
onContextMenu(e);
}, LONGPRESS_DELAY_MS);
lastPositionRef.current = { x: e.pageX, y: e.pageY };
},
[onContextMenu]
);
const onPointerMove = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current === null) {
return;
}
// If the pointer has moved more than the threshold, cancel the long press
const lastPosition = lastPositionRef.current;
const distanceFromLastPosition = Math.hypot(e.pageX - lastPosition.x, e.pageY - lastPosition.y);
if (distanceFromLastPosition > LONGPRESS_MOVE_THRESHOLD_PX) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
const onPointerUp = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
const onPointerCancel = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
useEffect(() => {
const controller = new AbortController();
// Context menu events
window.addEventListener('contextmenu', onContextMenu, { signal: controller.signal });
// Long press events
window.addEventListener('pointerdown', onPointerDown, { signal: controller.signal });
window.addEventListener('pointerup', onPointerUp, { signal: controller.signal });
window.addEventListener('pointercancel', onPointerCancel, { signal: controller.signal });
window.addEventListener('pointermove', onPointerMove, { signal: controller.signal });
return () => {
controller.abort();
};
}, [onContextMenu, onPointerCancel, onPointerDown, onPointerMove, onPointerUp]);
useEffect(
() => () => {
// Clean up any timeouts when we unmount
window.clearTimeout(animationTimeoutRef.current);
window.clearTimeout(longPressTimeoutRef.current);
},
[]
);
return null;
});
ImageContextMenuEventLogical.displayName = 'ImageContextMenuEventLogical';
// The content of the context menu, which changes based on the selection count. Split out and memoized to avoid
// re-rendering the whole context menu too often.
const MenuContent = memo(() => {
const selectionCount = useAppSelector(selectSelectionCount);
const state = useStore($imageContextMenuState);
if (!state.imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<SingleSelectionMenuItems imageDTO={imageDTO} />
<MultipleSelectionMenuItems />
</MenuList>
);
}, [imageDTO, selectionCount]);
}
return <ContextMenu renderMenu={renderMenuFunc}>{children}</ContextMenu>;
};
return (
<MenuList visibility="visible">
<SingleSelectionMenuItems imageDTO={state.imageDTO} />
</MenuList>
);
});
export default memo(ImageContextMenu);
MenuContent.displayName = 'MenuContent';

View File

@@ -1,11 +1,7 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNewCanvasFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { toast } from 'features/toast/toast';
@@ -14,23 +10,16 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFileBold } from 'react-icons/pi';
const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);
export const ImageMenuItemNewCanvasFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newCanvasFromImage = useNewCanvasFromImage();
const isBusy = useCanvasIsBusy();
const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(canvasReset());
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
const onClick = useCallback(() => {
newCanvasFromImage(imageDTO);
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -38,10 +27,10 @@ export const ImageMenuItemNewCanvasFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
}, [dispatch, imageDTO, imageViewer, newCanvasFromImage, t]);
return (
<MenuItem icon={<PiFileBold />} onClickCapture={handleSendToCanvas}>
<MenuItem icon={<PiFileBold />} onClickCapture={onClick} isDisabled={isBusy}>
{t('controlLayers.newCanvasFromImage')}
</MenuItem>
);

View File

@@ -1,11 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useNewRasterLayerFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -14,23 +11,17 @@ import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);
export const ImageMenuItemNewLayerFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newRasterLayerFromImage = useNewRasterLayerFromImage();
const isBusy = useCanvasIsBusy();
const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
const onClick = useCallback(() => {
dispatch(sentImageToCanvas());
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
newRasterLayerFromImage(imageDTO);
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -38,10 +29,10 @@ export const ImageMenuItemNewLayerFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
}, [dispatch, imageDTO, imageViewer, newRasterLayerFromImage, t]);
return (
<MenuItem icon={<NewLayerIcon />} onClickCapture={handleSendToCanvas}>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClick} isDisabled={isBusy}>
{t('controlLayers.newLayerFromImage')}
</MenuItem>
);

View File

@@ -1,5 +1,6 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { ImageMenuItemChangeBoard } from 'features/gallery/components/ImageContextMenu/ImageMenuItemChangeBoard';
import { ImageMenuItemCopy } from 'features/gallery/components/ImageContextMenu/ImageMenuItemCopy';
import { ImageMenuItemDelete } from 'features/gallery/components/ImageContextMenu/ImageMenuItemDelete';
@@ -37,8 +38,10 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
<ImageMenuItemMetadataRecallActions />
<MenuDivider />
<ImageMenuItemSendToUpscale />
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
<CanvasManagerProviderGate>
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
</CanvasManagerProviderGate>
<MenuDivider />
<ImageMenuItemChangeBoard />
<ImageMenuItemStarUnstar />

View File

@@ -63,12 +63,9 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare?.image_name === imageDTO.image_name),
[imageDTO.image_name]
);
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);
const isSelectedForCompare = useAppSelector(selectIsSelectedForCompare);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
@@ -91,20 +88,6 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
}
}, [imageDTO, selectedBoardId, areMultiplesSelected]);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => {
@@ -121,25 +104,6 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
setIsHovered(false);
}, []);
const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred && isHovered) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, isHovered, customStarUi]);
const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
if (!imageDTO) {
@@ -155,6 +119,8 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
justifyContent="center"
alignItems="center"
aspectRatio="1/1"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<IAIDndImage
onClick={handleClick}
@@ -169,38 +135,8 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
isUploadDisabled={true}
thumbnail={true}
withHoverOverlay
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<>
{(isHovered || alwaysShowImageSizeBadge) && (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
)}
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
<HoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</IAIDndImage>
</Flex>
</Box>
@@ -209,7 +145,21 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
GalleryImageContent.displayName = 'GalleryImageContent';
const DeleteIcon = ({ imageDTO }: { imageDTO: ImageDTO }) => {
const HoverIcons = memo(({ imageDTO, isHovered }: { imageDTO: ImageDTO; isHovered: boolean }) => {
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);
return (
<>
{(isHovered || alwaysShowImageSizeBadge) && <SizeBadge imageDTO={imageDTO} />}
{(isHovered || imageDTO.starred) && <StarIcon imageDTO={imageDTO} />}
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
);
});
HoverIcons.displayName = 'HoverIcons';
const DeleteIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -238,9 +188,11 @@ const DeleteIcon = ({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineEnd={2}
/>
);
};
});
const OpenInViewerIconButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
DeleteIcon.displayName = 'DeleteIcon';
const OpenInViewerIconButton = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const imageViewer = useImageViewer();
const { t } = useTranslation();
@@ -258,4 +210,77 @@ const OpenInViewerIconButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineStart={2}
/>
);
};
});
OpenInViewerIconButton.displayName = 'OpenInViewerIconButton';
const StarIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const customStarUi = useStore($customStarUI);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);
const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, customStarUi]);
const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);
return (
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
);
});
StarIcon.displayName = 'StarIcon';
const SizeBadge = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
return (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
);
});
SizeBadge.displayName = 'SizeBadge';

View File

@@ -1,11 +1,11 @@
import { Button, Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import { JumpTo } from './JumpTo';
export const GalleryPagination = () => {
export const GalleryPagination = memo(() => {
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
useGalleryPagination();
@@ -47,7 +47,9 @@ export const GalleryPagination = () => {
<JumpTo />
</Flex>
);
};
});
GalleryPagination.displayName = 'GalleryPagination';
type PageButtonProps = {
page: number | typeof ELLIPSIS;
@@ -55,7 +57,7 @@ type PageButtonProps = {
goToPage: (page: number) => void;
};
const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
const PageButton = memo(({ page, currentPage, goToPage }: PageButtonProps) => {
if (page === ELLIPSIS) {
return (
<Button size="sm" variant="link" isDisabled>
@@ -68,4 +70,6 @@ const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
{page}
</Button>
);
};
});
PageButton.displayName = 'PageButton';

View File

@@ -11,11 +11,11 @@ import {
useDisclosure,
} from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
export const JumpTo = () => {
export const JumpTo = memo(() => {
const { t } = useTranslation();
const { goToPage, currentPage, pages } = useGalleryPagination();
const [newPage, setNewPage] = useState(currentPage);
@@ -64,7 +64,7 @@ export const JumpTo = () => {
}, [currentPage]);
return (
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen} isLazy lazyBehavior="unmount">
<PopoverTrigger>
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
{t('gallery.jump')}
@@ -94,4 +94,6 @@ export const JumpTo = () => {
</PopoverContent>
</Popover>
);
};
});
JumpTo.displayName = 'JumpTo';

View File

@@ -32,6 +32,8 @@ export const selectListImagesQueryArgs = createMemoizedSelector(
export const selectListBoardsQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListBoardsArgs => ({
order_by: gallery.boardsListOrderBy,
direction: gallery.boardsListOrderDir,
include_archived: gallery.shouldShowArchivedBoards ? true : undefined,
})
);
@@ -44,6 +46,9 @@ export const selectAutoAssignBoardOnClick = createSelector(
);
export const selectBoardSearchText = createSelector(selectGallerySlice, (gallery) => gallery.boardSearchText);
export const selectSearchTerm = createSelector(selectGallerySlice, (gallery) => gallery.searchTerm);
export const selectBoardsListOrderBy = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderBy);
export const selectBoardsListOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderDir);
export const selectSelectionCount = createSelector(selectGallerySlice, (gallery) => gallery.selection.length);
export const selectHasMultipleImagesSelected = createSelector(selectSelectionCount, (count) => count > 1);
export const selectGalleryImageMinimumWidth = createSelector(

View File

@@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { isEqual, uniqBy } from 'lodash-es';
import type { ImageDTO } from 'services/api/types';
import type { BoardRecordOrderBy, ImageDTO } from 'services/api/types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
@@ -25,6 +25,8 @@ const initialGalleryState: GalleryState = {
comparisonMode: 'slider',
comparisonFit: 'fill',
shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC',
};
export const gallerySlice = createSlice({
@@ -161,6 +163,12 @@ export const gallerySlice = createSlice({
state.searchTerm = action.payload;
state.offset = 0;
},
boardsListOrderByChanged: (state, action: PayloadAction<BoardRecordOrderBy>) => {
state.boardsListOrderBy = action.payload;
},
boardsListOrderDirChanged: (state, action: PayloadAction<OrderDir>) => {
state.boardsListOrderDir = action.payload;
},
},
});
@@ -186,6 +194,8 @@ export const {
starredFirstChanged,
shouldShowArchivedBoardsChanged,
searchTermChanged,
boardsListOrderByChanged,
boardsListOrderDirChanged,
} = gallerySlice.actions;
export const selectGallerySlice = (state: RootState) => state.gallery;

View File

@@ -1,4 +1,4 @@
import type { ImageCategory, ImageDTO } from 'services/api/types';
import type { BoardRecordOrderBy, ImageCategory, ImageDTO } from 'services/api/types';
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
@@ -28,4 +28,6 @@ export type GalleryState = {
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};

View File

@@ -1,7 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import type { BoardFieldInputInstance, BoardFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
@@ -14,26 +13,28 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { options, hasBoards } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const { options, hasBoards } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -43,7 +43,7 @@ export const ShareWorkflowModal = () => {
if (!workflowToShare || !projectUrl) {
return null;
}
return `${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
return `${window.location.origin}/${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
}, [projectUrl, workflowToShare]);
const handleCopy = useCallback(() => {

View File

@@ -36,8 +36,6 @@ export const addControlNets = async (
};
for (const layer of validControlLayers) {
result.addedControlNets++;
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -50,6 +48,7 @@ export const addControlNets = async (
const imageDTO = getImageDTOResult.value;
addControlNetToGraph(g, layer, imageDTO, collector);
result.addedControlNets++;
}
return result;
@@ -77,8 +76,6 @@ export const addT2IAdapters = async (
};
for (const layer of validControlLayers) {
result.addedT2IAdapters++;
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -91,6 +88,7 @@ export const addT2IAdapters = async (
const imageDTO = getImageDTOResult.value;
addT2IAdapterToGraph(g, layer, imageDTO, collector);
result.addedT2IAdapters++;
}
return result;
@@ -110,10 +108,10 @@ const addControlNetToGraph = (
const controlNet = g.addNode({
id: `control_net_${id}`,
type: 'controlnet',
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
control_mode: model.base === 'flux' ? undefined : controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,

View File

@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addControlNets } from './addControlAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -93,6 +95,7 @@ export const buildFLUXGraph = async (
> = l2i;
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', noise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
@@ -177,6 +180,24 @@ export const buildFLUXGraph = async (
);
}
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager,
canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', noise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}

View File

@@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
@@ -44,10 +43,9 @@ export const checkImageAccess = async (name: string): Promise<boolean> => {
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkBoardAccess = async (id: string): Promise<boolean> => {
const { dispatch, getState } = getStore();
const { dispatch } = getStore();
try {
const queryArgs = selectListBoardsQueryArgs(getState());
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate(queryArgs));
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate({ include_archived: true }));
req.unsubscribe();
const result = await req.unwrap();
return result.some((b) => b.board_id === id);

View File

@@ -1,19 +1,19 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { t } from 'i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
export const useBoardName = (board_id: BoardId) => {
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { boardName } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');
const { boardName } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');
return { boardName };
},
});
return { boardName };
},
}
);
return boardName;
};

File diff suppressed because one or more lines are too long

View File

@@ -241,3 +241,5 @@ export type PostUploadAction =
| RGIPAdapterImagePostUploadAction
| UpscaleInitialImageAction
| ReplaceLayerWithImagePostUploadAction;
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];

View File

@@ -24,6 +24,15 @@ export default defineConfig(({ mode }) => {
cssInjectedByJsPlugin(),
],
build: {
/**
* zone.js (via faro) requires a build target of at most ES2015 to prevent it from spamming unhandled promise rejections.
*
* See:
* - https://github.com/grafana/faro-web-sdk/issues/566
* - https://github.com/angular/angular/issues/51328
* - https://github.com/open-telemetry/opentelemetry-js/issues/3030
*/
target: 'ES2015',
cssCodeSplit: true,
lib: {
entry: path.resolve(__dirname, './src/index.ts'),

View File

@@ -1 +1 @@
__version__ = "5.1.1"
__version__ = "5.2.0rc1"

View File

@@ -43,8 +43,8 @@ dependencies = [
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"onnx>=1.15.0",
"onnxruntime>=1.16.3",
"onnx==1.16.1",
"onnxruntime==1.19.2",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",

View File

@@ -0,0 +1,30 @@
import argparse
import json

from safetensors.torch import load_file


def extract_sd_keys_and_shapes(safetensors_file: str):
    sd = load_file(safetensors_file)
    keys_to_shapes = {k: v.shape for k, v in sd.items()}

    out_file = "keys_and_shapes.json"
    with open(out_file, "w") as f:
        json.dump(keys_to_shapes, f, indent=4)

    print(f"Keys and shapes written to '{out_file}'.")


def main():
    parser = argparse.ArgumentParser(
        description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
        + "dummy state dicts for use in unit tests."
    )
    parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
    args = parser.parse_args()

    extract_sd_keys_and_shapes(args.safetensors_file)


if __name__ == "__main__":
    main()
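
As a quick illustration of how the exported file can be consumed (a minimal sketch, not part of the script above), the shapes in keys_and_shapes.json can be turned into a dummy state dict of zero tensors on the meta device, mirroring how the unit tests added in this diff build their fixtures; the file name used here is the script's default output name.

import json

import torch

# Load the shapes exported by extract_sd_keys_and_shapes() and materialize
# zero tensors on the meta device so no real weights are needed.
with open("keys_and_shapes.json") as f:
    keys_to_shapes = json.load(f)

with torch.device("meta"):
    dummy_sd = {k: torch.zeros(shape) for k, shape in keys_to_shapes.items()}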

View File

@@ -0,0 +1,374 @@
# State dict keys and shapes for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
instantx_sd_shapes = {
"context_embedder.bias": [3072],
"context_embedder.weight": [3072, 4096],
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"controlnet_blocks.2.bias": [3072],
"controlnet_blocks.2.weight": [3072, 3072],
"controlnet_blocks.3.bias": [3072],
"controlnet_blocks.3.weight": [3072, 3072],
"controlnet_blocks.4.bias": [3072],
"controlnet_blocks.4.weight": [3072, 3072],
"controlnet_mode_embedder.weight": [10, 3072],
"controlnet_single_blocks.0.bias": [3072],
"controlnet_single_blocks.0.weight": [3072, 3072],
"controlnet_single_blocks.1.bias": [3072],
"controlnet_single_blocks.1.weight": [3072, 3072],
"controlnet_single_blocks.2.bias": [3072],
"controlnet_single_blocks.2.weight": [3072, 3072],
"controlnet_single_blocks.3.bias": [3072],
"controlnet_single_blocks.3.weight": [3072, 3072],
"controlnet_single_blocks.4.bias": [3072],
"controlnet_single_blocks.4.weight": [3072, 3072],
"controlnet_single_blocks.5.bias": [3072],
"controlnet_single_blocks.5.weight": [3072, 3072],
"controlnet_single_blocks.6.bias": [3072],
"controlnet_single_blocks.6.weight": [3072, 3072],
"controlnet_single_blocks.7.bias": [3072],
"controlnet_single_blocks.7.weight": [3072, 3072],
"controlnet_single_blocks.8.bias": [3072],
"controlnet_single_blocks.8.weight": [3072, 3072],
"controlnet_single_blocks.9.bias": [3072],
"controlnet_single_blocks.9.weight": [3072, 3072],
"controlnet_x_embedder.bias": [3072],
"controlnet_x_embedder.weight": [3072, 64],
"single_transformer_blocks.0.attn.norm_k.weight": [128],
"single_transformer_blocks.0.attn.norm_q.weight": [128],
"single_transformer_blocks.0.attn.to_k.bias": [3072],
"single_transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_q.bias": [3072],
"single_transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_v.bias": [3072],
"single_transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.0.norm.linear.bias": [9216],
"single_transformer_blocks.0.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.0.proj_mlp.bias": [12288],
"single_transformer_blocks.0.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.0.proj_out.bias": [3072],
"single_transformer_blocks.0.proj_out.weight": [3072, 15360],
"single_transformer_blocks.1.attn.norm_k.weight": [128],
"single_transformer_blocks.1.attn.norm_q.weight": [128],
"single_transformer_blocks.1.attn.to_k.bias": [3072],
"single_transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_q.bias": [3072],
"single_transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_v.bias": [3072],
"single_transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.1.norm.linear.bias": [9216],
"single_transformer_blocks.1.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.1.proj_mlp.bias": [12288],
"single_transformer_blocks.1.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.1.proj_out.bias": [3072],
"single_transformer_blocks.1.proj_out.weight": [3072, 15360],
"single_transformer_blocks.2.attn.norm_k.weight": [128],
"single_transformer_blocks.2.attn.norm_q.weight": [128],
"single_transformer_blocks.2.attn.to_k.bias": [3072],
"single_transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_q.bias": [3072],
"single_transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_v.bias": [3072],
"single_transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.2.norm.linear.bias": [9216],
"single_transformer_blocks.2.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.2.proj_mlp.bias": [12288],
"single_transformer_blocks.2.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.2.proj_out.bias": [3072],
"single_transformer_blocks.2.proj_out.weight": [3072, 15360],
"single_transformer_blocks.3.attn.norm_k.weight": [128],
"single_transformer_blocks.3.attn.norm_q.weight": [128],
"single_transformer_blocks.3.attn.to_k.bias": [3072],
"single_transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_q.bias": [3072],
"single_transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_v.bias": [3072],
"single_transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.3.norm.linear.bias": [9216],
"single_transformer_blocks.3.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.3.proj_mlp.bias": [12288],
"single_transformer_blocks.3.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.3.proj_out.bias": [3072],
"single_transformer_blocks.3.proj_out.weight": [3072, 15360],
"single_transformer_blocks.4.attn.norm_k.weight": [128],
"single_transformer_blocks.4.attn.norm_q.weight": [128],
"single_transformer_blocks.4.attn.to_k.bias": [3072],
"single_transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_q.bias": [3072],
"single_transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_v.bias": [3072],
"single_transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.4.norm.linear.bias": [9216],
"single_transformer_blocks.4.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.4.proj_mlp.bias": [12288],
"single_transformer_blocks.4.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.4.proj_out.bias": [3072],
"single_transformer_blocks.4.proj_out.weight": [3072, 15360],
"single_transformer_blocks.5.attn.norm_k.weight": [128],
"single_transformer_blocks.5.attn.norm_q.weight": [128],
"single_transformer_blocks.5.attn.to_k.bias": [3072],
"single_transformer_blocks.5.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_q.bias": [3072],
"single_transformer_blocks.5.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_v.bias": [3072],
"single_transformer_blocks.5.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.5.norm.linear.bias": [9216],
"single_transformer_blocks.5.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.5.proj_mlp.bias": [12288],
"single_transformer_blocks.5.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.5.proj_out.bias": [3072],
"single_transformer_blocks.5.proj_out.weight": [3072, 15360],
"single_transformer_blocks.6.attn.norm_k.weight": [128],
"single_transformer_blocks.6.attn.norm_q.weight": [128],
"single_transformer_blocks.6.attn.to_k.bias": [3072],
"single_transformer_blocks.6.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_q.bias": [3072],
"single_transformer_blocks.6.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_v.bias": [3072],
"single_transformer_blocks.6.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.6.norm.linear.bias": [9216],
"single_transformer_blocks.6.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.6.proj_mlp.bias": [12288],
"single_transformer_blocks.6.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.6.proj_out.bias": [3072],
"single_transformer_blocks.6.proj_out.weight": [3072, 15360],
"single_transformer_blocks.7.attn.norm_k.weight": [128],
"single_transformer_blocks.7.attn.norm_q.weight": [128],
"single_transformer_blocks.7.attn.to_k.bias": [3072],
"single_transformer_blocks.7.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_q.bias": [3072],
"single_transformer_blocks.7.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_v.bias": [3072],
"single_transformer_blocks.7.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.7.norm.linear.bias": [9216],
"single_transformer_blocks.7.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.7.proj_mlp.bias": [12288],
"single_transformer_blocks.7.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.7.proj_out.bias": [3072],
"single_transformer_blocks.7.proj_out.weight": [3072, 15360],
"single_transformer_blocks.8.attn.norm_k.weight": [128],
"single_transformer_blocks.8.attn.norm_q.weight": [128],
"single_transformer_blocks.8.attn.to_k.bias": [3072],
"single_transformer_blocks.8.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_q.bias": [3072],
"single_transformer_blocks.8.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_v.bias": [3072],
"single_transformer_blocks.8.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.8.norm.linear.bias": [9216],
"single_transformer_blocks.8.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.8.proj_mlp.bias": [12288],
"single_transformer_blocks.8.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.8.proj_out.bias": [3072],
"single_transformer_blocks.8.proj_out.weight": [3072, 15360],
"single_transformer_blocks.9.attn.norm_k.weight": [128],
"single_transformer_blocks.9.attn.norm_q.weight": [128],
"single_transformer_blocks.9.attn.to_k.bias": [3072],
"single_transformer_blocks.9.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_q.bias": [3072],
"single_transformer_blocks.9.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_v.bias": [3072],
"single_transformer_blocks.9.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.9.norm.linear.bias": [9216],
"single_transformer_blocks.9.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.9.proj_mlp.bias": [12288],
"single_transformer_blocks.9.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.9.proj_out.bias": [3072],
"single_transformer_blocks.9.proj_out.weight": [3072, 15360],
"time_text_embed.guidance_embedder.linear_1.bias": [3072],
"time_text_embed.guidance_embedder.linear_1.weight": [3072, 256],
"time_text_embed.guidance_embedder.linear_2.bias": [3072],
"time_text_embed.guidance_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.text_embedder.linear_1.bias": [3072],
"time_text_embed.text_embedder.linear_1.weight": [3072, 768],
"time_text_embed.text_embedder.linear_2.bias": [3072],
"time_text_embed.text_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.timestep_embedder.linear_1.bias": [3072],
"time_text_embed.timestep_embedder.linear_1.weight": [3072, 256],
"time_text_embed.timestep_embedder.linear_2.bias": [3072],
"time_text_embed.timestep_embedder.linear_2.weight": [3072, 3072],
"transformer_blocks.0.attn.add_k_proj.bias": [3072],
"transformer_blocks.0.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_q_proj.bias": [3072],
"transformer_blocks.0.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_v_proj.bias": [3072],
"transformer_blocks.0.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.norm_added_k.weight": [128],
"transformer_blocks.0.attn.norm_added_q.weight": [128],
"transformer_blocks.0.attn.norm_k.weight": [128],
"transformer_blocks.0.attn.norm_q.weight": [128],
"transformer_blocks.0.attn.to_add_out.bias": [3072],
"transformer_blocks.0.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.0.attn.to_k.bias": [3072],
"transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"transformer_blocks.0.attn.to_out.0.bias": [3072],
"transformer_blocks.0.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.0.attn.to_q.bias": [3072],
"transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"transformer_blocks.0.attn.to_v.bias": [3072],
"transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"transformer_blocks.0.ff.net.0.proj.bias": [12288],
"transformer_blocks.0.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff.net.2.bias": [3072],
"transformer_blocks.0.ff.net.2.weight": [3072, 12288],
"transformer_blocks.0.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.0.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff_context.net.2.bias": [3072],
"transformer_blocks.0.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.0.norm1.linear.bias": [18432],
"transformer_blocks.0.norm1.linear.weight": [18432, 3072],
"transformer_blocks.0.norm1_context.linear.bias": [18432],
"transformer_blocks.0.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.1.attn.add_k_proj.bias": [3072],
"transformer_blocks.1.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_q_proj.bias": [3072],
"transformer_blocks.1.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_v_proj.bias": [3072],
"transformer_blocks.1.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.norm_added_k.weight": [128],
"transformer_blocks.1.attn.norm_added_q.weight": [128],
"transformer_blocks.1.attn.norm_k.weight": [128],
"transformer_blocks.1.attn.norm_q.weight": [128],
"transformer_blocks.1.attn.to_add_out.bias": [3072],
"transformer_blocks.1.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.1.attn.to_k.bias": [3072],
"transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"transformer_blocks.1.attn.to_out.0.bias": [3072],
"transformer_blocks.1.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.1.attn.to_q.bias": [3072],
"transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"transformer_blocks.1.attn.to_v.bias": [3072],
"transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"transformer_blocks.1.ff.net.0.proj.bias": [12288],
"transformer_blocks.1.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff.net.2.bias": [3072],
"transformer_blocks.1.ff.net.2.weight": [3072, 12288],
"transformer_blocks.1.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.1.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff_context.net.2.bias": [3072],
"transformer_blocks.1.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.1.norm1.linear.bias": [18432],
"transformer_blocks.1.norm1.linear.weight": [18432, 3072],
"transformer_blocks.1.norm1_context.linear.bias": [18432],
"transformer_blocks.1.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.2.attn.add_k_proj.bias": [3072],
"transformer_blocks.2.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_q_proj.bias": [3072],
"transformer_blocks.2.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_v_proj.bias": [3072],
"transformer_blocks.2.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.norm_added_k.weight": [128],
"transformer_blocks.2.attn.norm_added_q.weight": [128],
"transformer_blocks.2.attn.norm_k.weight": [128],
"transformer_blocks.2.attn.norm_q.weight": [128],
"transformer_blocks.2.attn.to_add_out.bias": [3072],
"transformer_blocks.2.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.2.attn.to_k.bias": [3072],
"transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"transformer_blocks.2.attn.to_out.0.bias": [3072],
"transformer_blocks.2.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.2.attn.to_q.bias": [3072],
"transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"transformer_blocks.2.attn.to_v.bias": [3072],
"transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"transformer_blocks.2.ff.net.0.proj.bias": [12288],
"transformer_blocks.2.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff.net.2.bias": [3072],
"transformer_blocks.2.ff.net.2.weight": [3072, 12288],
"transformer_blocks.2.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.2.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff_context.net.2.bias": [3072],
"transformer_blocks.2.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.2.norm1.linear.bias": [18432],
"transformer_blocks.2.norm1.linear.weight": [18432, 3072],
"transformer_blocks.2.norm1_context.linear.bias": [18432],
"transformer_blocks.2.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.3.attn.add_k_proj.bias": [3072],
"transformer_blocks.3.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_q_proj.bias": [3072],
"transformer_blocks.3.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_v_proj.bias": [3072],
"transformer_blocks.3.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.norm_added_k.weight": [128],
"transformer_blocks.3.attn.norm_added_q.weight": [128],
"transformer_blocks.3.attn.norm_k.weight": [128],
"transformer_blocks.3.attn.norm_q.weight": [128],
"transformer_blocks.3.attn.to_add_out.bias": [3072],
"transformer_blocks.3.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.3.attn.to_k.bias": [3072],
"transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"transformer_blocks.3.attn.to_out.0.bias": [3072],
"transformer_blocks.3.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.3.attn.to_q.bias": [3072],
"transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"transformer_blocks.3.attn.to_v.bias": [3072],
"transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"transformer_blocks.3.ff.net.0.proj.bias": [12288],
"transformer_blocks.3.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff.net.2.bias": [3072],
"transformer_blocks.3.ff.net.2.weight": [3072, 12288],
"transformer_blocks.3.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.3.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff_context.net.2.bias": [3072],
"transformer_blocks.3.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.3.norm1.linear.bias": [18432],
"transformer_blocks.3.norm1.linear.weight": [18432, 3072],
"transformer_blocks.3.norm1_context.linear.bias": [18432],
"transformer_blocks.3.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.4.attn.add_k_proj.bias": [3072],
"transformer_blocks.4.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_q_proj.bias": [3072],
"transformer_blocks.4.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_v_proj.bias": [3072],
"transformer_blocks.4.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.norm_added_k.weight": [128],
"transformer_blocks.4.attn.norm_added_q.weight": [128],
"transformer_blocks.4.attn.norm_k.weight": [128],
"transformer_blocks.4.attn.norm_q.weight": [128],
"transformer_blocks.4.attn.to_add_out.bias": [3072],
"transformer_blocks.4.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.4.attn.to_k.bias": [3072],
"transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"transformer_blocks.4.attn.to_out.0.bias": [3072],
"transformer_blocks.4.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.4.attn.to_q.bias": [3072],
"transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"transformer_blocks.4.attn.to_v.bias": [3072],
"transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"transformer_blocks.4.ff.net.0.proj.bias": [12288],
"transformer_blocks.4.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff.net.2.bias": [3072],
"transformer_blocks.4.ff.net.2.weight": [3072, 12288],
"transformer_blocks.4.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.4.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff_context.net.2.bias": [3072],
"transformer_blocks.4.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.4.norm1.linear.bias": [18432],
"transformer_blocks.4.norm1.linear.weight": [18432, 3072],
"transformer_blocks.4.norm1_context.linear.bias": [18432],
"transformer_blocks.4.norm1_context.linear.weight": [18432, 3072],
"x_embedder.bias": [3072],
"x_embedder.weight": [3072, 64],
}
# InstantX FLUX ControlNet config for unit tests.
# Copied from https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
instantx_config = {
"_class_name": "FluxControlNetModel",
"_diffusers_version": "0.30.0.dev0",
"_name_or_path": "/mnt/wangqixun/",
"attention_head_dim": 128,
"axes_dims_rope": [16, 56, 56],
"guidance_embeds": True,
"in_channels": 64,
"joint_attention_dim": 4096,
"num_attention_heads": 24,
"num_layers": 5,
"num_mode": 10,
"num_single_layers": 10,
"patch_size": 1,
"pooled_projection_dim": 768,
}
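
As a quick sanity check (an illustrative snippet, not part of the fixture file, assuming the two dictionaries above are in scope), the dummy shapes are consistent with this config: the hidden size implied by the state dict equals attention_head_dim * num_attention_heads, and the mode-embedding and context-embedding shapes match num_mode and joint_attention_dim.

# Illustrative consistency checks between the dummy shapes and the config above.
hidden_size = instantx_sd_shapes["x_embedder.bias"][0]  # 3072
assert hidden_size == instantx_config["attention_head_dim"] * instantx_config["num_attention_heads"]  # 128 * 24
assert instantx_sd_shapes["context_embedder.weight"][1] == instantx_config["joint_attention_dim"]  # 4096
assert instantx_sd_shapes["controlnet_mode_embedder.weight"][0] == instantx_config["num_mode"]  # 10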

View File

@@ -0,0 +1,108 @@
import sys
import pytest
import torch
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
infer_instantx_num_control_modes_from_state_dict,
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_config, instantx_sd_shapes
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes
@pytest.mark.parametrize(
["sd_shapes", "expected"],
[
(xlabs_sd_shapes, True),
(instantx_sd_shapes, False),
(["foo"], False),
],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
sd = {k: None for k in sd_shapes}
assert is_state_dict_xlabs_controlnet(sd) == expected
@pytest.mark.parametrize(
["sd_keys", "expected"],
[
(instantx_sd_shapes, True),
(xlabs_sd_shapes, False),
(["foo"], False),
],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
sd = {k: None for k in sd_keys}
assert is_state_dict_instantx_controlnet(sd) == expected
def test_convert_diffusers_instantx_state_dict_to_bfl_format():
"""Smoke test convert_diffusers_instantx_state_dict_to_bfl_format() to ensure that it handles all of the keys."""
sd = {k: torch.zeros(1) for k in instantx_sd_shapes}
bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
assert bfl_sd is not None
# TODO(ryand): Figure out why some tests in this file are failing on the macOS CI runners. It seems to be related to
# using the meta device. I can't reproduce the issue on my local macOS system.
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_flux_params_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
assert flux_params.in_channels == instantx_config["in_channels"]
assert flux_params.vec_in_dim == instantx_config["pooled_projection_dim"]
assert flux_params.context_in_dim == instantx_config["joint_attention_dim"]
assert flux_params.hidden_size // flux_params.num_heads == instantx_config["attention_head_dim"]
assert flux_params.num_heads == instantx_config["num_attention_heads"]
assert flux_params.mlp_ratio == 4
assert flux_params.depth == instantx_config["num_layers"]
assert flux_params.depth_single_blocks == instantx_config["num_single_layers"]
assert flux_params.axes_dim == instantx_config["axes_dims_rope"]
assert flux_params.theta == 10000
assert flux_params.qkv_bias
assert flux_params.guidance_embed == instantx_config["guidance_embeds"]
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_instantx_num_control_modes_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
assert num_control_modes == instantx_config["num_mode"]
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_load_instantx_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with torch.device("meta"):
model = InstantXControlNetFlux(flux_params, num_control_modes)
model_sd = model.state_dict()
assert set(model_sd.keys()) == set(sd.keys())
for key, tensor in model_sd.items():
assert isinstance(tensor, torch.Tensor)
assert tensor.shape == sd[key].shape

View File

@@ -0,0 +1,91 @@
# State dict keys and shapes for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
xlabs_sd_shapes = {
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"double_blocks.0.img_attn.norm.key_norm.scale": [128],
"double_blocks.0.img_attn.norm.query_norm.scale": [128],
"double_blocks.0.img_attn.proj.bias": [3072],
"double_blocks.0.img_attn.proj.weight": [3072, 3072],
"double_blocks.0.img_attn.qkv.bias": [9216],
"double_blocks.0.img_attn.qkv.weight": [9216, 3072],
"double_blocks.0.img_mlp.0.bias": [12288],
"double_blocks.0.img_mlp.0.weight": [12288, 3072],
"double_blocks.0.img_mlp.2.bias": [3072],
"double_blocks.0.img_mlp.2.weight": [3072, 12288],
"double_blocks.0.img_mod.lin.bias": [18432],
"double_blocks.0.img_mod.lin.weight": [18432, 3072],
"double_blocks.0.txt_attn.norm.key_norm.scale": [128],
"double_blocks.0.txt_attn.norm.query_norm.scale": [128],
"double_blocks.0.txt_attn.proj.bias": [3072],
"double_blocks.0.txt_attn.proj.weight": [3072, 3072],
"double_blocks.0.txt_attn.qkv.bias": [9216],
"double_blocks.0.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.0.txt_mlp.0.bias": [12288],
"double_blocks.0.txt_mlp.0.weight": [12288, 3072],
"double_blocks.0.txt_mlp.2.bias": [3072],
"double_blocks.0.txt_mlp.2.weight": [3072, 12288],
"double_blocks.0.txt_mod.lin.bias": [18432],
"double_blocks.0.txt_mod.lin.weight": [18432, 3072],
"double_blocks.1.img_attn.norm.key_norm.scale": [128],
"double_blocks.1.img_attn.norm.query_norm.scale": [128],
"double_blocks.1.img_attn.proj.bias": [3072],
"double_blocks.1.img_attn.proj.weight": [3072, 3072],
"double_blocks.1.img_attn.qkv.bias": [9216],
"double_blocks.1.img_attn.qkv.weight": [9216, 3072],
"double_blocks.1.img_mlp.0.bias": [12288],
"double_blocks.1.img_mlp.0.weight": [12288, 3072],
"double_blocks.1.img_mlp.2.bias": [3072],
"double_blocks.1.img_mlp.2.weight": [3072, 12288],
"double_blocks.1.img_mod.lin.bias": [18432],
"double_blocks.1.img_mod.lin.weight": [18432, 3072],
"double_blocks.1.txt_attn.norm.key_norm.scale": [128],
"double_blocks.1.txt_attn.norm.query_norm.scale": [128],
"double_blocks.1.txt_attn.proj.bias": [3072],
"double_blocks.1.txt_attn.proj.weight": [3072, 3072],
"double_blocks.1.txt_attn.qkv.bias": [9216],
"double_blocks.1.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.1.txt_mlp.0.bias": [12288],
"double_blocks.1.txt_mlp.0.weight": [12288, 3072],
"double_blocks.1.txt_mlp.2.bias": [3072],
"double_blocks.1.txt_mlp.2.weight": [3072, 12288],
"double_blocks.1.txt_mod.lin.bias": [18432],
"double_blocks.1.txt_mod.lin.weight": [18432, 3072],
"guidance_in.in_layer.bias": [3072],
"guidance_in.in_layer.weight": [3072, 256],
"guidance_in.out_layer.bias": [3072],
"guidance_in.out_layer.weight": [3072, 3072],
"img_in.bias": [3072],
"img_in.weight": [3072, 64],
"input_hint_block.0.bias": [16],
"input_hint_block.0.weight": [16, 3, 3, 3],
"input_hint_block.10.bias": [16],
"input_hint_block.10.weight": [16, 16, 3, 3],
"input_hint_block.12.bias": [16],
"input_hint_block.12.weight": [16, 16, 3, 3],
"input_hint_block.14.bias": [16],
"input_hint_block.14.weight": [16, 16, 3, 3],
"input_hint_block.2.bias": [16],
"input_hint_block.2.weight": [16, 16, 3, 3],
"input_hint_block.4.bias": [16],
"input_hint_block.4.weight": [16, 16, 3, 3],
"input_hint_block.6.bias": [16],
"input_hint_block.6.weight": [16, 16, 3, 3],
"input_hint_block.8.bias": [16],
"input_hint_block.8.weight": [16, 16, 3, 3],
"pos_embed_input.bias": [3072],
"pos_embed_input.weight": [3072, 64],
"time_in.in_layer.bias": [3072],
"time_in.in_layer.weight": [3072, 256],
"time_in.out_layer.bias": [3072],
"time_in.out_layer.weight": [3072, 3072],
"txt_in.bias": [3072],
"txt_in.weight": [3072, 4096],
"vector_in.in_layer.bias": [3072],
"vector_in.in_layer.weight": [3072, 768],
"vector_in.out_layer.bias": [3072],
"vector_in.out_layer.weight": [3072, 3072],
}
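
Taken together, the two fixture files make the structural difference between the formats easy to see: the XLabs checkpoint uses BFL-style "double_blocks." keys, while the InstantX checkpoint uses diffusers-style "transformer_blocks." / "single_transformer_blocks." keys. A minimal prefix-based sketch of that distinction is shown below; the helper names are hypothetical and this is not the detection logic in invokeai.backend.flux.controlnet.state_dict_utils, which the tests exercise directly.

# Hypothetical helpers (illustration only) distinguishing the fixtures by key prefix.
def looks_like_xlabs_controlnet(sd: dict) -> bool:
    return any(k.startswith("double_blocks.") for k in sd)


def looks_like_instantx_controlnet(sd: dict) -> bool:
    return any(k.startswith("transformer_blocks.") for k in sd)


assert looks_like_xlabs_controlnet(xlabs_sd_shapes)
assert not looks_like_instantx_controlnet(xlabs_sd_shapes)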