Compare commits

..

135 Commits

Author SHA1 Message Date
Brandon Rising
596dc6b4e3 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-10-25 00:20:28 -04:00
Brandon Rising
b4f93a0ff5 Initial wave of frontend updates for sd-3 node inputs 2024-10-24 22:22:32 -04:00
Brandon Rising
f4be52abb4 define submodels on sd3 models during probe 2024-10-24 15:18:42 -04:00
Ryan Dick
4e029331ba Add tqdm progress bar for SD3. 2024-10-24 16:04:37 +00:00
Ryan Dick
40b4de5f77 Bug fixes to get SD3 text-to-image workflow running. 2024-10-24 15:55:17 +00:00
Ryan Dick
c930807881 Temporary hack for testing SD3 model loader. 2024-10-24 15:34:12 +00:00
Ryan Dick
bec8b27429 Fix Sd3TextEncoderInvocation output type. 2024-10-24 15:14:46 +00:00
Ryan Dick
ef4f466ccf Initial draft of SD3DenoiseInvocation. 2024-10-24 14:43:48 +00:00
Ryan Dick
3c869ee5ab Add first draft of Sd3TextEncoderInvocation. 2024-10-24 01:19:40 +00:00
Ryan Dick
16dc30fd5b Add Sd3ModelLoaderInvocation. 2024-10-24 00:17:19 +00:00
Ryan Dick
0c14192819 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-10-24 00:03:35 +00:00
Ryan Dick
36dadba45b Get diffusers SD3 model probing working. 2024-10-23 19:55:26 +00:00
Ryan Dick
f2a9c01d0e (minor) Remove unused dict. 2024-10-23 19:03:33 +00:00
Ryan Dick
1ca57ade4d Fix huggingface_hub.errors imports after version bump. 2024-10-23 18:29:24 +00:00
Ryan Dick
85c0e0db1e Fix changed import for FromOriginalControlNetMixin after diffusers bump. 2024-10-23 18:25:12 +00:00
Ryan Dick
59a2388585 Bump diffusers, accelerate, and huggingface-hub. 2024-10-23 18:09:35 +00:00
psychedelicious
3583d03b70 feat(ui): improve subs and cleanup in filterer module
- Subscribe when starting the filterer
- Remember to abort the abortcontroller when destroying
- Unsubscribe when destroying
2024-10-23 08:21:12 -04:00
psychedelicious
bc954b9996 feat(ui): abort controller in SAM module when destroying 2024-10-23 08:21:12 -04:00
psychedelicious
c08075946a feat(ui): only subscribe listeners when segmenting
Realized we are doing a lot of event listening even when segmenting is not occuring. I don't think this will have a meaningful performance impact, but it makes sense to remove these listeners when not in use.
2024-10-23 08:21:12 -04:00
psychedelicious
df8df914e8 docs(ui): add comments to CanvasSegmentAnythingModule 2024-10-23 08:21:12 -04:00
psychedelicious
33924e8491 feat(ui): ensure abort controllers are cleaned up 2024-10-23 08:21:12 -04:00
psychedelicious
7e5ce1d69d fix(ui): when last SAM point is deleted, reset ephemeral state 2024-10-23 08:21:12 -04:00
Riku
6a24594140 feat(ui): move model manager in-place install state to redux
- persists across sessions/refreshes
- shared state for all installers (local path, scan folder)
2024-10-23 21:17:31 +11:00
psychedelicious
61d26cffe6 chore: bump version to v5.3.0rc1 2024-10-23 16:11:20 +11:00
psychedelicious
fdbc244dbe tidy(ui): autoProcessFilter -> autoProcess
It's used for more than filters now.
2024-10-23 16:01:15 +11:00
psychedelicious
0eea84c90d chore(ui): lint 2024-10-23 16:01:15 +11:00
psychedelicious
e079a91800 feat(ui): reorder point type radios 2024-10-23 16:01:15 +11:00
psychedelicious
eb20173487 fix(ui): set hasProcessed on segment module when deleting a point 2024-10-23 16:01:15 +11:00
psychedelicious
20dd0779b5 feat(ui): use radio instead of drop-down for point label 2024-10-23 16:01:15 +11:00
psychedelicious
b384a92f5c fix(ui): let segment module handle cursor if segmenting 2024-10-23 16:01:15 +11:00
psychedelicious
116d32fbbe feat(ui): auto-process for segment anything 2024-10-23 16:01:15 +11:00
psychedelicious
b044f31a61 fix(ui): translation for isolated layer preview 2024-10-23 16:01:15 +11:00
psychedelicious
6c3c24403b feat(ui): rename "Segment" -> "Auto Mask" 2024-10-23 16:01:15 +11:00
psychedelicious
591f48bb95 chore(ui): lint 2024-10-23 16:01:15 +11:00
psychedelicious
dc6e45485c feat(ui): update CanvasSegmentAnythingModule for new nodes 2024-10-23 16:01:15 +11:00
psychedelicious
829820479d chore(ui): typegen 2024-10-23 16:01:15 +11:00
psychedelicious
48a471bfb8 fix(nodes): apply_tensor_mask_to_image transparent image handling
Fix an issue where if the input image is transparent in a region to be masked, that transparent region ends up opaque black. Need to respect the input image transparency by applying the mask to the alpha channel only.
2024-10-23 16:01:15 +11:00
psychedelicious
ff72315db2 feat(nodes): update SAM backend and nodes to work with SAM points 2024-10-23 16:01:15 +11:00
psychedelicious
790846297a feat(ui): add more data to canvas module reprs 2024-10-23 16:01:15 +11:00
psychedelicious
230b455a13 tidy(ui): $pointTypeEnglish -> $pointTypeString 2024-10-23 16:01:15 +11:00
psychedelicious
71f0fff55b fix(ui): right click on stage draws 2024-10-23 16:01:15 +11:00
psychedelicious
7f2c83b9e6 feat(ui): consolidate isolated preview settings
`isolatedFilteringPreview` and `isolatedTransformingPreview` are merged into `isolatedLayerPreview`. This is also used for segment anything.
2024-10-23 16:01:15 +11:00
psychedelicious
bc85bd4bd4 tidy(ui): clean up and document CanvasSegmentAnythingModule 2024-10-23 16:01:15 +11:00
psychedelicious
38b09d73e4 feat(ui): masking UX (wip - interaction state issue) 2024-10-23 16:01:15 +11:00
psychedelicious
606c4ae88c feat(ui): masking UX (wip - issue w/ positioning) 2024-10-23 16:01:15 +11:00
psychedelicious
f666bac77f tidy(ui): CanvasToolView -> CanvasViewToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
c9bf7da23a tidy(ui): CanvasToolRect -> CanvasRectToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
dfc65b93e9 tidy(ui): CanvasToolMove -> CanvasMoveToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
9ca40b4cf5 tidy(ui): CanvasToolErase -> CanvasEraserToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
d571e71d5e tidy(ui): CanvasToolColorPicker -> CanvasColorPickerToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
ad1e6c3fe6 tidy(ui): CanvasToolBrush -> CanvasBrushToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
21d02911dd tidy(ui): CanvasBboxModule -> CanvasBboxToolModule, move file 2024-10-23 16:01:15 +11:00
psychedelicious
43afe0bd9a feat(ui): move cursor handling to tool modules
Also add cursors for move tool and bbox tool - when pointer is over the layer or bbox, use the move cursor.
2024-10-23 16:01:15 +11:00
psychedelicious
e7a68c446d feat(ui): add CanvasToolView
It's nearly a noop but I think it makes sense to have a module for each tool...
2024-10-23 16:01:15 +11:00
psychedelicious
b9c68a2e7e feat(ui): add CanvasToolMove
It's essentially a noop but I think it makes sense to have a module for each tool...
2024-10-23 16:01:15 +11:00
psychedelicious
371a1b1af3 feat(ui): make CanvasBboxModule child of CanvasToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
dae4591de6 feat(ui): let tool modules set own visibility 2024-10-23 16:01:15 +11:00
psychedelicious
8ccb2e30ce feat(ui): bail on stage events when not targeting the stage 2024-10-23 16:01:15 +11:00
psychedelicious
b8106a4613 fix(ui): bail on drawing when mouse not down 2024-10-23 16:01:15 +11:00
psychedelicious
ce51e9582a feat(ui): add CanvasRectTool 2024-10-23 16:01:15 +11:00
psychedelicious
00848eb631 feat(ui): let color picker tool handle its events 2024-10-23 16:01:15 +11:00
psychedelicious
b48430a892 feat(ui): let eraser tool handle its events 2024-10-23 16:01:15 +11:00
psychedelicious
f94a218561 tidy(ui): remove extraneous checks from CanvasToolBrush 2024-10-23 16:01:15 +11:00
psychedelicious
9b6ed40875 fix(ui): edge case where pressure could be added erroneously to points 2024-10-23 16:01:15 +11:00
psychedelicious
26553dbb0e tidy(ui): CanvasToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
9eb695d0b4 docs(ui): update CanvasToolModule 2024-10-23 16:01:15 +11:00
psychedelicious
babab17e1d feat(ui): let brush tool handle its events
Move brush tool event logic to its class.
2024-10-23 16:01:15 +11:00
psychedelicious
d0a80f3347 feat(ui): create zCoordinateWithPressure & export type from canvas types 2024-10-23 16:01:15 +11:00
psychedelicious
9b30363177 tidy(ui): CanvasToolModule structure 2024-10-23 16:01:15 +11:00
psychedelicious
89bde36b0c feat(ui): support draggable SAM points 2024-10-23 16:01:15 +11:00
psychedelicious
86a8476d97 feat(ui): working segment anything flow 2024-10-23 16:01:15 +11:00
psychedelicious
afa0661e55 chore(ui): typegen 2024-10-23 16:01:15 +11:00
psychedelicious
ba09c1277f feat(nodes): hacked together nodes for segment anything w/ points 2024-10-23 16:01:15 +11:00
psychedelicious
80bf9ddb71 feat(ui): rough out points UI for segment anything module 2024-10-23 16:01:15 +11:00
psychedelicious
1dbc98d747 feat(ui): add CanvasSegmentAnythingModule (wip) 2024-10-23 16:01:15 +11:00
psychedelicious
0698188ea2 feat(ui): support readonly arrays in SerializableObject type 2024-10-23 16:01:15 +11:00
psychedelicious
59d0ad4505 chore(ui): migrate from ts-toolbelt to type-fest
`ts-toolbelt` is unmaintained while `type-fest` is very actively maintained. Both provide similar TS utilities.
2024-10-23 16:01:15 +11:00
Thomas Bolteau
074a5692dd translationBot(ui): update translation (French)
Currently translated at 100.0% (1509 of 1509 strings)

translationBot(ui): update translation (French)

Currently translated at 100.0% (1509 of 1509 strings)

Co-authored-by: Thomas Bolteau <thomas.bolteau50@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/fr/
Translation: InvokeAI/Web UI
2024-10-23 10:23:37 +11:00
Васянатор
bb0741146a translationBot(ui): update translation (Russian)
Currently translated at 99.6% (1504 of 1509 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-10-23 10:23:37 +11:00
Riccardo Giovanetti
1845d9a87a translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1492 of 1509 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-10-23 10:23:37 +11:00
Riku
748c393e71 translationBot(ui): update translation (German)
Currently translated at 71.0% (1072 of 1509 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-10-23 10:23:37 +11:00
David Burnett
9bd17ea02f Get flux working with MPS on 2.4.1, with GGUF support 2024-10-23 10:20:42 +11:00
David Burnett
24f9b46fbc ruff fix 2024-10-23 10:09:24 +11:00
David Burnett
54b3aa1d01 load t5 model in the same format as it is saved, seems to load as float32 on Macs 2024-10-23 10:09:24 +11:00
Maximilian Maag
d85733f22b fix(installer): pytorch and ROCm versions are incompatible
Each version of torch is only available for specific versions of CUDA and ROCm.
The Invoke installer and dockerfile try to install torch 2.4.1 with ROCm 5.6
support, which does not exist. As a result, the installation falls back to the
default CUDA version so AMD GPUs aren't detected. This commits fixes that by
bumping the ROCm version to 6.1, as suggested by the PyTorch documentation. [1]

The specified CUDA version of 12.4 is still correct according to [1] so it does
need to be changed.

Closes #7006
Closes #7146

[1]: https://pytorch.org/get-started/previous-versions/#v241
2024-10-23 09:59:00 +11:00
psychedelicious
aff6ad0316 FLUX XLabs IP-Adapter Support (#7157)
## Summary

This PR adds support for the XLabs IP-Adapter
(https://huggingface.co/XLabs-AI/flux-ip-adapter) in workflows. Linear
UI integration is coming in a follow-up PR. The XLabs IP-Adapter can be
installed in the Starter Models tab.

Usage tips:

- Use a `cfg_scale` value of 2.0 to 4.0
- Start with an IP-Adatper weight of ~0.6 and adjust from there.
- Set `cfg_scale_start_step = 1`
- Set `cfg_scale_end_step` to roughly the halfway point (it's
unnecessary to apply CFG to all steps, and this will improve processing
time).

Sample workflow:
<img width="976" alt="image"
src="https://github.com/user-attachments/assets/4627b459-7e5a-4703-80e7-f7575c5fce19">

Result:

![image](https://github.com/user-attachments/assets/220b6a4c-69c6-447f-8df6-8aa6a56f3b3f)

## Related Issues / Discussions

Prerequisite: https://github.com/invoke-ai/InvokeAI/pull/7152

## Remaining TODO:

- [ ] Update default workflows.

## QA Instructions

- [x] Test basic happy path
- [x] Test with multiple IP-Adapters (it runs, but results aren't great)
- [ ] ~Test with multiple images to a single IP-Adapter~ (this is not
supported for now)
- [ ] Test automatic runtime installation of CLIP-L, CLIP-H, and CLIP-G
image encoder models if they are not already installed.
- [ ] Test starter model installation of the XLabs FLUX IP-Adapter
- [ ] Test SD and SDXL IP-Adapters for regression.
- [ ] Check peak memory utilization.

## Merge Plan

- [ ] Merge #7152 
- [ ] Change target branch to main

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-10-23 09:57:39 +11:00
psychedelicious
61496fdcbc fix(nodes): load IP Adapter images as RGB
FLUX IP Adapter only works with RGB. Did the same for non-FLUX to be safe & consistent, though I don't think it's strictly necessary.
2024-10-23 08:34:15 +10:00
psychedelicious
ee8975401a fix(ui): remove special handling for flux in IPAdapterModel
This masked an issue w/ the CLIP Vision model. Issue is now handled in reducer/graph builder.
2024-10-23 08:31:10 +10:00
psychedelicious
bf3260446d fix(ui): use flux_ip_adapter for flux 2024-10-23 08:30:11 +10:00
psychedelicious
f53823b45e fix(ui): update CLIP Vision when ipa model changes 2024-10-23 08:29:14 +10:00
Ryan Dick
5cbe89afdd Merge branch 'main' into ryan/flux-ip-adapter-cfg-2 2024-10-22 21:17:36 +00:00
Ryan Dick
c466d50c3d FLUX CFG support (#7152)
## Summary

Add support for Classifier-Free Guidance with FLUX.

- Using CFG doubles the time for the denoising process. Running both the
positive and negative conditioning in a single batch is left for future
work, because most users are already VRAM-constrained (this would
probably be faster at the cost of higher peak VRAM).
- Negative text conditioning is optional and only required if `cfg_scale
!= 1.0`
- CFG is skipped if `cfg_scale == 1.0` (i.e. no compute overhead in this
case)
- `cfg_scale_start_step` and `cfg_scale_end_step` can be used to easily
control the range of steps that CFG is applied for.
- CFG is a prerequisite for IP-Adapter support.

## Example

Positive Caption: `Professional photography of a luxury hotel in the
Nevada desert`
CFG: 1.0

![image](https://github.com/user-attachments/assets/f25ff832-d69b-4c5f-88f4-9429ce96d598)

Positive Caption: `Professional photography of a luxury hotel in the
Nevada desert`
Negative Caption: `Swimming pool`
CFG: 2.0
Same seed

![image](https://github.com/user-attachments/assets/27e3b952-2795-469f-bb24-b7fddb726ba1)


## QA Instructions

- [ ] Test interactions with ControlNet
- [ ] Verify that peak RAM/VRAM utilization has not increased
significantly
- [ ] Test that CFG is skipped when cfg_scale == 1.0
- [ ] Test that negative text conditioning can be omitted when cfg_scale
== 1.0
- [ ] Test that a clear error message is returned when negative text
conditioning is omitted when cfg_scale != 1.0
- [ ] Test that the negative text prompt gets applied when cfg_scale
>1.0
- [ ] Test that a collection of cfg_scale values can be provided for
per-step control.
- [ ] Test that `cfg_scale_start_step` and `cfg_scale_end_step` control
the range of steps that CFG is applied

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-10-22 17:09:40 -04:00
Ryan Dick
d20b894a61 Add cfg_scale_start_step and cfg_scale_end_step to FLUX Denoise node. 2024-10-23 07:59:48 +11:00
Ryan Dick
20362448b9 Make negative_text_conditioning nullable on FLUX Denoise invocation. 2024-10-23 07:59:48 +11:00
Ryan Dick
5df10cc494 Add support for cfg_scale list on FLUX Denoise node. 2024-10-23 07:59:48 +11:00
Ryan Dick
da171114ea Naive implementation of CFG for FLUX. 2024-10-23 07:59:48 +11:00
Eugene Brodsky
62919a443c fix(installer): remove xformers before installation 2024-10-23 07:57:52 +11:00
Mary Hipp
ffcec91d87 Merge branch 'ryan/flux-ip-adapter-cfg-2' of https://github.com/invoke-ai/InvokeAI into ryan/flux-ip-adapter-cfg-2 2024-10-22 15:23:35 -04:00
Mary Hipp
0a96466b60 feat(ui): add IP adapters to FLUX in linear UI 2024-10-22 15:22:56 -04:00
Ryan Dick
e48cab0276 Only allow a single image prompt for FLUX IP-Adapters (haven't really looked into this much, but punting on it for now). 2024-10-22 16:32:01 +00:00
Ryan Dick
740f6eb19f Skip tests that use the meta device - they fail on the MacOS CI runners. 2024-10-22 15:56:49 +00:00
psychedelicious
d1bb4c2c70 fix(nodes): FluxDenoiseInvocation.controlnet_vae missing default=None 2024-10-22 10:54:15 +11:00
Ryan Dick
e545f18a45 (minor) Fix ruff. 2024-10-21 22:38:06 +00:00
Ryan Dick
e8cd1bb3d8 Add FLUX IP-Adapter starter models. 2024-10-21 22:17:42 +00:00
Ryan Dick
90a906e203 Simplify handling of CLIP ViT selection for FLUX IP-Adapter invocation. 2024-10-21 19:54:59 +00:00
Ryan Dick
5546110127 Add FluxIPAdapterInvocation. 2024-10-21 18:27:40 +00:00
Ryan Dick
73bbb12f7a Use a black image as the negative IP prompt for parity with X-Labs implementation. 2024-10-21 15:47:22 +00:00
Ryan Dick
dde54740c5 Test out IP-Adapter with CFG. 2024-10-21 15:47:17 +00:00
Ryan Dick
f70a8e2c1a A bunch of HACKS to get ViT-L CLIP vision encoder working for FLUX IP-Adapter. Need to revisit how to clean this all up long term. 2024-10-21 15:43:00 +00:00
Ryan Dick
fdccdd52d5 Fixes to get XLabsIpAdapterExtension running. 2024-10-21 15:43:00 +00:00
Ryan Dick
31ffd73423 Initial draft of integrating FLUX IP-Adapter inference support. 2024-10-21 15:42:56 +00:00
Ryan Dick
3fa1012879 Add IPAdapterDoubleBlocks wrapper to tidy FLUX ip-adapter handling. 2024-10-21 15:38:50 +00:00
Ryan Dick
c2a8fbd8d6 (minor) Move infer_xlabs_ip_adapter_params_from_state_dict(...) to state_dict_utils.py. 2024-10-21 15:38:50 +00:00
Ryan Dick
d6643d7263 Add model loading code for xlabs FLUX IP-Adapter (not tested). 2024-10-21 15:38:50 +00:00
Ryan Dick
412e79d8e6 Add model probing for XLabs FLUX IP-Adapter. 2024-10-21 15:38:50 +00:00
Ryan Dick
f939dbdc33 Add is_state_dict_xlabs_ip_adapter() utility function. 2024-10-21 15:38:50 +00:00
Ryan Dick
24a0ca86f5 Add logic for loading an Xlabs IP-Adapter from a state dict. 2024-10-21 15:38:50 +00:00
Ryan Dick
95c30f6a8b Add initial logic for inferring FLUX IP-Adapter params from a state_dict. 2024-10-21 15:38:50 +00:00
Ryan Dick
ac7441e606 Fixup typing/imports for IPDoubleStreamBlockProcessor. 2024-10-21 15:38:50 +00:00
Ryan Dick
9c9af312fe Copy IPDoubleStreamBlockProcessor from 47495425db/src/flux/modules/layers.py (L221). 2024-10-21 15:38:50 +00:00
Ryan Dick
7bf5927c43 Add XLabs IP-Adapter state dict for unit tests. 2024-10-21 15:38:50 +00:00
Ryan Dick
32c7cdd856 Add cfg_scale_start_step and cfg_scale_end_step to FLUX Denoise node. 2024-10-21 14:52:02 +00:00
Mary Hipp
bbd89d54b4 add it to list 2024-10-19 14:08:49 +11:00
Mary Hipp
ee61006a49 add starter model 2024-10-19 14:08:49 +11:00
psychedelicious
0b43f5fd64 docs(ui): improve docstrings for LoggingOverrides 2024-10-19 08:04:20 +11:00
psychedelicious
6c61266990 refactor(ui): logging config handling
Introduce two-stage logging configuration and overrides for enabled status, log level and log namespaces.

The first stage in `<InvokeAIUI />`, before we set up redux (and therefore before we have access to the user's configured logging setup). In this stage, we use the overrides or default values.

The second stage is in `<App />`, after we set up redux, via `useSyncLoggingConfig`. In this stage, we use the overrides or the user's configured logging setup. This hook also handles pushing changes made by the user into localstorage.

Other changes:
- Extract logging config to util function
- Remove the `useEffect` from `SettingsModal` that was changing the logging settings
- Remove extraneous log effects from `useLogger`
- Export new `LoggingOverrides` type
2024-10-19 08:04:20 +11:00
Maximilian Maag
2d5afe8094 fix(installer): Print maximize suggestion when Python is found, not when it's missing 2024-10-18 16:35:51 -04:00
Maximilian Maag
2430137d19 fix(installer): Avoid misleading error message when searching for python binary
which prints a message to stderr when it doesn't find anything. In this case,
not finding anything is expected so the error is misleading.
2024-10-18 16:35:51 -04:00
Ryan Dick
6df4ee5fc8 Make negative_text_conditioning nullable on FLUX Denoise invocation. 2024-10-18 20:31:27 +00:00
Ryan Dick
371742d8f9 Add support for cfg_scale list on FLUX Denoise node. 2024-10-18 20:14:47 +00:00
psychedelicious
5440c03767 fix(app): directory traversal when deleting images 2024-10-18 14:27:41 +11:00
psychedelicious
358dbdbf84 chore: bump version to v5.2.0 2024-10-17 22:24:51 +11:00
psychedelicious
5ec2d71be0 feat(ui): make debug logger middleware configurable
While troubleshooting an issue with this middleware, I found the inclusion of the nextState and diff to be very noisy. It's now a function that accepts some options to configure the output, and returns the middleware.
2024-10-17 08:04:51 +11:00
Mary Hipp
8f28903c81 remove extra slash in workflow share link 2024-10-17 08:02:27 +11:00
Mary Hipp
a071f2788a fix(ui): upload tooltip should only show plural if multiple upload is an option 2024-10-16 12:00:11 -04:00
131 changed files with 5191 additions and 1375 deletions

View File

@@ -38,7 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.6"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
fi &&\

View File

@@ -12,7 +12,7 @@ MINIMUM_PYTHON_VERSION=3.10.0
MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
for candidate in python3.11 python3.10 python3 python ; do
if ppath=`which $candidate`; then
if ppath=`which $candidate 2>/dev/null`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
@@ -30,10 +30,11 @@ done
if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found"
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit"
exit -1
fi
echo "For the best user experience we suggest enlarging or maximizing this window now."
exec $PYTHON ./lib/main.py ${@}
read -p "Press any key to exit"

View File

@@ -245,6 +245,9 @@ class InvokeAiInstance:
pip = local[self.pip]
# Uninstall xformers if it is present; the correct version of it will be reinstalled if needed
_ = pip["uninstall", "-yqq", "xformers"] & FG
pipeline = pip[
"install",
"--require-virtualenv",
@@ -407,7 +410,7 @@ def get_torch_source() -> Tuple[str | None, str | None]:
optional_modules: str | None = None
if OS == "Linux":
if device == GpuType.ROCM:
url = "https://download.pytorch.org/whl/rocm5.6"
url = "https://download.pytorch.org/whl/rocm6.1"
elif device == GpuType.CPU:
url = "https://download.pytorch.org/whl/cpu"
elif device == GpuType.CUDA:

View File

@@ -547,7 +547,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
single_ipa_images = [
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.

View File

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -133,6 +134,7 @@ class FieldDescriptions:
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -140,6 +142,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +249,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -89,12 +89,24 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
negative_text_conditioning: FluxConditioningField | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale_start_step: int = InputField(
default=0,
title="CFG Scale Start Step",
description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
+ "the last step (e.g. a value of -1 refers to the final step).",
)
cfg_scale_end_step: int = InputField(
default=-1,
title="CFG Scale End Step",
description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
+ "last step (e.g. a value of -1 refers to the final step).",
)
# TODO(ryand): Add support for cfg_scale to be a list of floats: one for each step.
# TODO(ryand): Add cfg_scale range validation.
cfg_scale: float = InputField(default=3.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
@@ -109,6 +121,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
default=None,
description=FieldDescriptions.vae,
input=Input.Connection,
)
@@ -148,9 +161,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
@@ -215,10 +231,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
@@ -247,6 +265,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
ip_adapter_fields, context
)
cfg_scale = self.prep_cfg_scale(
cfg_scale=self.cfg_scale,
timesteps=timesteps,
cfg_scale_start_step=self.cfg_scale_start_step,
cfg_scale_end_step=self.cfg_scale_end_step,
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -318,7 +343,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
cfg_scale=self.cfg_scale,
cfg_scale=cfg_scale,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
@@ -328,6 +353,55 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width)
return x
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
) -> list[float]:
"""Prepare the cfg_scale schedule.
- Clips the cfg_scale schedule based on cfg_scale_start_step and cfg_scale_end_step.
- If cfg_scale is a list, then it is assumed to be a schedule and is returned as-is.
- If cfg_scale is a scalar, then a linear schedule is created from cfg_scale_start_step to cfg_scale_end_step.
"""
# num_steps is the number of denoising steps, which is one less than the number of timesteps.
num_steps = len(timesteps) - 1
# Normalize cfg_scale to a list if it is a scalar.
cfg_scale_list: list[float]
if isinstance(cfg_scale, float):
cfg_scale_list = [cfg_scale] * num_steps
elif isinstance(cfg_scale, list):
cfg_scale_list = cfg_scale
else:
raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")
assert len(cfg_scale_list) == num_steps
# Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
start_step_index = cfg_scale_start_step
if start_step_index < 0:
start_step_index = num_steps + start_step_index
end_step_index = cfg_scale_end_step
if end_step_index < 0:
end_step_index = num_steps + end_step_index
# Validate the start and end step indices.
if not (0 <= start_step_index < num_steps):
raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
if not (0 <= end_step_index < num_steps):
raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
if start_step_index > end_step_index:
raise ValueError(
f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
+ f"({cfg_scale_end_step})."
)
# Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
# steps.
clipped_cfg_scale = [1.0] * num_steps
clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]
return clipped_cfg_scale
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
@@ -497,7 +571,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported IP-Adapter image type: {type(ip_adapter_field.image)}")
ipa_images = [context.images.get_pil(image.image_name) for image in ipa_image_fields]
if len(ipa_image_fields) != 1:
raise ValueError(
f"FLUX IP-Adapter only supports a single image prompt (received {len(ipa_image_fields)})."
)
ipa_images = [context.images.get_pil(image.image_name, mode="RGB") for image in ipa_image_fields]
pos_images: list[npt.NDArray[np.uint8]] = []
neg_images: list[npt.NDArray[np.uint8]] = []

View File

@@ -0,0 +1,89 @@
from builtins import float
from typing import List, Literal, Union
from pydantic import field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.ip_adapter import (
CLIP_VISION_MODEL_MAP,
IPAdapterField,
IPAdapterInvocation,
IPAdapterOutput,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
@invocation(
"flux_ip_adapter",
title="FLUX IP-Adapter",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxIPAdapterInvocation(BaseInvocation):
"""Collects FLUX IP-Adapter info to pass to other nodes."""
# FLUXIPAdapterInvocation is based closely on IPAdapterInvocation, but with some unsupported features removed.
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
)
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v: float) -> float:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> Self:
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model_id = image_encoder_starter_model.source
image_encoder_model_name = image_encoder_starter_model.name
image_encoder_model = IPAdapterInvocation.get_clip_image_encoder(
context, image_encoder_model_id, image_encoder_model_name
)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight,
target_blocks=[], # target_blocks is currently unused for FLUX IP-Adapters.
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=None, # mask is currently unused for FLUX IP-Adapters.
),
)

View File

@@ -0,0 +1,89 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -18,6 +18,12 @@ from invokeai.backend.model_manager.config import (
IPAdapterInvokeAIConfig,
ModelType,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
clip_vit_l_image_encoder,
ip_adapter_sd_image_encoder,
ip_adapter_sdxl_image_encoder,
)
class IPAdapterField(BaseModel):
@@ -56,10 +62,10 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
CLIP_VISION_MODEL_MAP = {
"ViT-L": ("InvokeAI/clip-vit-large-patch14", "clip-vit-large-patch14-full"),
"ViT-H": ("InvokeAI/ip_adapter_sd_image_encoder", "ip_adapter_sd_image_encoder"),
"ViT-G": ("InvokeAI/ip_adapter_sdxl_image_encoder", "ip_adapter_sdxl_image_encoder"),
CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] = {
"ViT-L": clip_vit_l_image_encoder,
"ViT-H": ip_adapter_sd_image_encoder,
"ViT-G": ip_adapter_sdxl_image_encoder,
}
@@ -75,7 +81,7 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = InputField(
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="ViT-H",
ui_order=2,
@@ -116,9 +122,11 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
image_encoder_model_id, image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model_id = image_encoder_starter_model.source
image_encoder_model_name = image_encoder_starter_model.name
image_encoder_model = self._get_image_encoder(context, image_encoder_model_id, image_encoder_model_name)
image_encoder_model = self.get_clip_image_encoder(context, image_encoder_model_id, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
@@ -152,8 +160,9 @@ class IPAdapterInvocation(BaseInvocation):
),
)
def _get_image_encoder(
self, context: InvocationContext, image_encoder_model_id: str, image_encoder_model_name: str
@classmethod
def get_clip_image_encoder(
cls, context: InvocationContext, image_encoder_model_id: str, image_encoder_model_name: str
) -> AnyModelConfig:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision

View File

@@ -5,6 +5,7 @@ from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
from invokeai.backend.image_util.util import pil_to_np
@invocation(
@@ -148,3 +149,51 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
mask_pil = Image.fromarray(mask_np, mode="L")
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)
@invocation(
"apply_tensor_mask_to_image",
title="Apply Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
)
class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies a tensor mask to an image.
The image is converted to RGBA and the mask is applied to the alpha channel."""
mask: TensorField = InputField(description="The mask tensor to apply.")
image: ImageField = InputField(description="The image to apply the mask to.")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
if mask.dim() == 3:
mask = mask.squeeze(0)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)
# Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
# image's transparency - else the transparent regions would end up as opaque black.
# Separate the image into R, G, B, and A channels
image_np = pil_to_np(image)
r, g, b, a = np.split(image_np, 4, axis=-1)
# Apply the mask to the alpha channel
new_alpha = np.where(a.squeeze() > 0, mask_np, a.squeeze())
# Stack the RGB channels with the modified alpha
masked_image_np = np.dstack([r.squeeze(), g.squeeze(), b.squeeze(), new_alpha])
# Convert back to an image (RGBA)
masked_image = Image.fromarray(masked_image_np.astype(np.uint8), "RGBA")
image_dto = context.images.save(image=masked_image)
return ImageOutput.build(image_dto)

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,241 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
# TODO(ryand): Construct a zero tensor of the correct shape to use as the T5 conditioning.
raise NotImplementedError("SD3 inference without T5 conditioning is not yet supported.")
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: _description_
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype, device
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype, device
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.num_steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
transformer_info = context.models.load(self.transformer.transformer)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None: ...
return step_callback

View File

@@ -0,0 +1,97 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
)
# TODO(brandon): Setup UI updates to support selecting a clip l model.
# clip_l_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_l_model,
# ui_type=UIType.CLIPEmbedModel,
# input=Input.Direct,
# title="CLIP L Encoder",
# )
# TODO(brandon): Setup UI updates to support selecting a clip g model.
# clip_g_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_g_model,
# ui_type=UIType.CLIPGModel,
# input=Input.Direct,
# title="CLIP G Encoder",
# )
# TODO(brandon): Setup UI updates to support selecting an SD3 vae model.
# vae_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE", default=None
# )
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
mmditx=TransformerField(transformer=mmditx, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,196 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_max_seq_len = 256
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, t5_max_seq_len)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -1,9 +1,11 @@
from enum import Enum
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -23,12 +25,31 @@ SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
}
class SAMPointLabel(Enum):
negative = -1
neutral = 0
positive = 1
class SAMPoint(BaseModel):
x: int = Field(..., description="The x-coordinate of the point")
y: int = Field(..., description="The y-coordinate of the point")
label: SAMPointLabel = Field(..., description="The label of the point")
class SAMPointsField(BaseModel):
points: list[SAMPoint] = Field(..., description="The points of the object")
def to_list(self) -> list[list[int]]:
return [[point.x, point.y, point.label.value] for point in self.points]
@invocation(
"segment_anything",
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
version="1.1.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
@@ -40,7 +61,13 @@ class SegmentAnythingInvocation(BaseInvocation):
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
bounding_boxes: list[BoundingBoxField] | None = InputField(
default=None, description="The bounding boxes to prompt the SAM model with."
)
point_lists: list[SAMPointsField] | None = InputField(
default=None,
description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
)
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
@@ -50,12 +77,22 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def check_point_lists_or_bounding_box(self):
if self.point_lists is None and self.bounding_boxes is None:
raise ValueError("Either point_lists or bounding_box must be provided.")
elif self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if len(self.bounding_boxes) == 0:
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
else:
masks = self._segment(context=context, image=image_pil)
@@ -83,14 +120,13 @@ class SegmentAnythingInvocation(BaseInvocation):
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
) -> list[torch.Tensor]:
def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
sam_bounding_boxes = (
[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
)
sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None
with (
context.models.load_remote_model(
@@ -98,7 +134,7 @@ class SegmentAnythingInvocation(BaseInvocation):
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:
@@ -141,9 +177,10 @@ class SegmentAnythingInvocation(BaseInvocation):
return masks
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
def _filter_masks(
self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField] | None
) -> list[torch.Tensor]:
"""Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all":
return masks
@@ -151,6 +188,10 @@ class SegmentAnythingInvocation(BaseInvocation):
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
assert (
bounding_boxes is not None
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
assert len(masks) == len(bounding_boxes)
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a

View File

@@ -110,15 +110,26 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
base_folder = self.__thumbnails_folder if thumbnail else self.__output_folder
filename = get_thumbnail_name(image_name) if thumbnail else image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
# Strip any path information from the filename
basename = Path(filename).name
return path
if basename != filename:
raise ValueError("Invalid image name, potential directory traversal detected")
image_path = base_folder / basename
# Ensure the image path is within the base folder to prevent directory traversal
resolved_base = base_folder.resolve()
resolved_image_path = image_path.resolve()
if not resolved_image_path.is_relative_to(resolved_base):
raise ValueError("Image path outside outputs folder, potential directory traversal detected")
return resolved_image_path
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail."""

View File

@@ -1,3 +1,4 @@
import math
from typing import Callable
import torch
@@ -22,14 +23,14 @@ def denoise(
txt_ids: torch.Tensor,
vec: torch.Tensor,
# negative text conditioning
neg_txt: torch.Tensor,
neg_txt_ids: torch.Tensor,
neg_vec: torch.Tensor,
neg_txt: torch.Tensor | None,
neg_txt_ids: torch.Tensor | None,
neg_vec: torch.Tensor | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: float,
cfg_scale: list[float],
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
@@ -46,20 +47,17 @@ def denoise(
latents=img,
),
)
step = 1
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
timestep_index = step - 1
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=timestep_index,
timestep_index=step_index,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
@@ -72,7 +70,7 @@ def denoise(
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to alculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
@@ -85,17 +83,23 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=timestep_index,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
)
# TODO(ryand): Add option to apply controlnet to negative conditioning as well.
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
if step > 1:
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_pred = model(
img=img,
img_ids=img_ids,
@@ -104,13 +108,13 @@ def denoise(
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=timestep_index,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
)
pred = neg_pred + cfg_scale * (pred - neg_pred)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
@@ -121,13 +125,12 @@ def denoise(
step_callback(
PipelineIntermediateState(
step=step,
step=step_index + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
latents=preview_img,
),
)
step += 1
return img

View File

@@ -168,8 +168,17 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
Returns:
torch.Tensor: Image position ids.
"""
if device.type == "mps":
orig_dtype = dtype
dtype = torch.float16
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids.to(orig_dtype)
return img_ids

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, TypeAlias
import torch
from PIL import Image
@@ -7,6 +7,14 @@ from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""
ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -27,20 +35,53 @@ class SegmentAnythingPipeline(RawModel):
return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
def segment(
self,
image: Image.Image,
bounding_boxes: list[list[int]] | None = None,
point_lists: list[list[list[int]]] | None = None,
) -> torch.Tensor:
"""Run the SAM model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Add batch dimension of 1 to the bounding boxes.
boxes = [bounding_boxes]
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
# Prep the inputs:
# - Create a list of bounding boxes or points and labels.
# - Add a batch dimension of 1 to the inputs.
if bounding_boxes:
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
input_points: list[ListOfPoints] | None = None
input_labels: list[ListOfPointLabels] | None = None
elif point_lists:
input_boxes: list[ListOfBoundingBoxes] | None = None
input_points: list[ListOfPoints] | None = []
input_labels: list[ListOfPointLabels] | None = []
for point_list in point_lists:
input_points.append([[p[0], p[1]] for p in point_list])
input_labels.append([p[2] for p in point_list])
else:
raise ValueError("Either bounding_boxes or points and labels must be provided.")
inputs = self._sam_processor(
images=image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt",
).to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,

View File

@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -147,6 +150,11 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +201,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):
@@ -394,6 +405,8 @@ class IPAdapterBaseConfig(ModelConfigBase):
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter diffusers format models."""
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]

View File

@@ -29,7 +29,7 @@ class ClipVisionLoader(ModelLoader):
raise ValueError("Only DiffusersConfigBase models are currently supported here.")
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
raise Exception("There are no submodels in CLIP Vision models.")
model_path = Path(config.path)

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,10 +172,10 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,

View File

@@ -20,7 +20,7 @@ from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

View File

@@ -19,7 +19,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -33,7 +33,10 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -112,6 +115,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -122,6 +126,8 @@ class ModelProbe(object):
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
@classmethod
@@ -178,7 +184,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["hash"] = "placeholder" # fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@@ -217,6 +223,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -746,18 +756,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -769,6 +794,21 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
)
return submodels
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the

View File

@@ -25,22 +25,6 @@ class StarterModelBundles(BaseModel):
models: list[StarterModel]
ip_adapter_sd_image_encoder = StarterModel(
name="IP Adapter SD1.5 Image Encoder",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_sd_image_encoder",
description="IP Adapter SD Image Encoder",
type=ModelType.CLIPVision,
)
ip_adapter_sdxl_image_encoder = StarterModel(
name="IP Adapter SDXL Image Encoder",
base=BaseModelType.StableDiffusionXL,
source="InvokeAI/ip_adapter_sdxl_image_encoder",
description="IP Adapter SDXL Image Encoder",
type=ModelType.CLIPVision,
)
cyberrealistic_negative = StarterModel(
name="CyberRealistic Negative v3",
base=BaseModelType.StableDiffusion1,
@@ -49,6 +33,32 @@ cyberrealistic_negative = StarterModel(
type=ModelType.TextualInversion,
)
# region CLIP Image Encoders
ip_adapter_sd_image_encoder = StarterModel(
name="IP Adapter SD1.5 Image Encoder",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_sd_image_encoder",
description="IP Adapter SD Image Encoder",
type=ModelType.CLIPVision,
)
ip_adapter_sdxl_image_encoder = StarterModel(
name="IP Adapter SDXL Image Encoder",
base=BaseModelType.StableDiffusionXL,
source="InvokeAI/ip_adapter_sdxl_image_encoder",
description="IP Adapter SDXL Image Encoder",
type=ModelType.CLIPVision,
)
# Note: This model is installed from the same source as the CLIPEmbed model below. The model contains both the image
# encoder and the text encoder, but we need separate model entries so that they get loaded correctly.
clip_vit_l_image_encoder = StarterModel(
name="clip-vit-large-patch14",
base=BaseModelType.Any,
source="InvokeAI/clip-vit-large-patch14",
description="CLIP ViT-L Image Encoder",
type=ModelType.CLIPVision,
)
# endregion
# region TextEncoders
t5_base_encoder = StarterModel(
name="t5_base_encoder",
@@ -186,6 +196,16 @@ dreamshaper_sdxl = StarterModel(
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
)
archvis_sdxl = StarterModel(
name="Architecture (RealVisXL5)",
base=BaseModelType.StableDiffusionXL,
source="SG161222/RealVisXL_V5.0",
description="A photorealistic model, with architecture among its many use cases",
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
)
sdxl_refiner = StarterModel(
name="SDXL Refiner",
base=BaseModelType.StableDiffusionXLRefiner,
@@ -254,6 +274,14 @@ ip_adapter_sdxl = StarterModel(
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],
)
ip_adapter_flux = StarterModel(
name="XLabs FLUX IP-Adapter",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
description="FLUX IP-Adapter",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],
)
# endregion
# region ControlNet
qr_code_cnet_sd1 = StarterModel(
@@ -545,6 +573,7 @@ STARTER_MODELS: list[StarterModel] = [
deliberate_inpainting_sd1,
juggernaut_sdxl,
dreamshaper_sdxl,
archvis_sdxl,
sdxl_refiner,
sdxl_fp16_vae_fix,
flux_vae,
@@ -555,6 +584,7 @@ STARTER_MODELS: list[StarterModel] = [
ip_adapter_plus_sd1,
ip_adapter_plus_face_sd1,
ip_adapter_sdxl,
ip_adapter_flux,
qr_code_cnet_sd1,
qr_code_cnet_sdxl,
canny_sd1,
@@ -642,6 +672,7 @@ flux_bundle: list[StarterModel] = [
t5_8b_quantized_encoder,
clip_l_encoder,
union_cnet_flux,
ip_adapter_flux,
]
STARTER_BUNDLES: dict[str, list[StarterModel]] = {

View File

@@ -54,6 +54,11 @@ GGML_TENSOR_OP_TABLE = {
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
}
if torch.backends.mps.is_available():
GGML_TENSOR_OP_TABLE.update(
{torch.ops.aten.linear.default: dequantize_and_run} # pyright: ignore
)
class GGMLTensor(torch.Tensor):
"""A torch.Tensor sub-class holding a quantized GGML tensor.

View File

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +32,9 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
# NOTE(ryand): I'm not the origina author of this code, but for future reference, it appears that this class was copied
# from diffusers in order to add support for the encoder_attention_mask argument.
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.

View File

@@ -114,8 +114,7 @@
},
"peerDependencies": {
"react": "^18.2.0",
"react-dom": "^18.2.0",
"ts-toolbelt": "^9.6.0"
"react-dom": "^18.2.0"
},
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
@@ -149,8 +148,8 @@
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.3.4",
"ts-toolbelt": "^9.6.0",
"tsafe": "^1.7.5",
"type-fest": "^4.26.1",
"typescript": "^5.6.2",
"vite": "^5.4.8",
"vite-plugin-css-injected-by-js": "^3.5.2",

View File

@@ -277,12 +277,12 @@ devDependencies:
storybook:
specifier: ^8.3.4
version: 8.3.4
ts-toolbelt:
specifier: ^9.6.0
version: 9.6.0
tsafe:
specifier: ^1.7.5
version: 1.7.5
type-fest:
specifier: ^4.26.1
version: 4.26.1
typescript:
specifier: ^5.6.2
version: 5.6.2
@@ -8830,10 +8830,6 @@ packages:
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
dev: false
/ts-toolbelt@9.6.0:
resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==}
dev: true
/tsafe@1.7.5:
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
dev: true

View File

@@ -93,7 +93,9 @@
"placeholderSelectAModel": "Modell auswählen",
"reset": "Zurücksetzen",
"none": "Keine",
"new": "Neu"
"new": "Neu",
"ok": "OK",
"close": "Schließen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -156,7 +158,11 @@
"displayBoardSearch": "Board durchsuchen",
"displaySearch": "Bild suchen",
"go": "Los",
"jump": "Springen"
"jump": "Springen",
"assetsTab": "Dateien, die Sie zur Verwendung in Ihren Projekten hochgeladen haben.",
"imagesTab": "Bilder, die Sie in Invoke erstellt und gespeichert haben.",
"boardsSettings": "Ordnereinstellungen",
"imagesSettings": "Galeriebildereinstellungen"
},
"hotkeys": {
"noHotkeysFound": "Kein Hotkey gefunden",
@@ -267,6 +273,18 @@
"applyFilter": {
"title": "Filter anwenden",
"desc": "Wende den ausstehenden Filter auf die ausgewählte Ebene an."
},
"cancelFilter": {
"title": "Filter abbrechen",
"desc": "Den ausstehenden Filter abbrechen."
},
"applyTransform": {
"desc": "Die ausstehende Transformation auf die ausgewählte Ebene anwenden.",
"title": "Transformation anwenden"
},
"cancelTransform": {
"title": "Transformation abbrechen",
"desc": "Die ausstehende Transformation abbrechen."
}
},
"viewer": {
@@ -563,7 +581,18 @@
"scanResults": "Ergebnisse des Scans",
"urlOrLocalPathHelper": "URLs sollten auf eine einzelne Datei deuten. Lokale Pfade können zusätzlich auch auf einen Ordner für ein einzelnes Diffusers-Modell hinweisen.",
"inplaceInstallDesc": "Installieren Sie Modelle, ohne die Dateien zu kopieren. Wenn Sie das Modell verwenden, wird es direkt von seinem Speicherort geladen. Wenn deaktiviert, werden die Dateien während der Installation in das von Invoke verwaltete Modellverzeichnis kopiert.",
"scanFolderHelper": "Der Ordner wird rekursiv nach Modellen durchsucht. Dies kann bei sehr großen Ordnern etwas dauern."
"scanFolderHelper": "Der Ordner wird rekursiv nach Modellen durchsucht. Dies kann bei sehr großen Ordnern etwas dauern.",
"includesNModels": "Enthält {{n}} Modelle und deren Abhängigkeiten",
"starterBundles": "Starterpakete",
"installingXModels_one": "{{count}} Modell wird installiert",
"installingXModels_other": "{{count}} Modelle werden installiert",
"skippingXDuplicates_one": ", überspringe {{count}} Duplikat",
"skippingXDuplicates_other": ", überspringe {{count}} Duplikate",
"installingModel": "Modell wird installiert",
"loraTriggerPhrases": "LoRA-Auslösephrasen",
"installingBundle": "Bündel wird installiert",
"triggerPhrases": "Auslösephrasen",
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen"
},
"parameters": {
"images": "Bilder",
@@ -667,7 +696,8 @@
"about": "Über",
"submitSupportTicket": "Support-Ticket senden",
"toggleRightPanel": "Rechtes Bedienfeld umschalten (G)",
"toggleLeftPanel": "Linkes Bedienfeld umschalten (T)"
"toggleLeftPanel": "Linkes Bedienfeld umschalten (T)",
"uploadImages": "Bild(er) hochladen"
},
"boards": {
"autoAddBoard": "Board automatisch erstellen",
@@ -702,7 +732,7 @@
"shared": "Geteilte Ordner",
"archiveBoard": "Ordner archivieren",
"archived": "Archiviert",
"noBoards": "Kein {boardType}} Ordner",
"noBoards": "Kein {{boardType}} Ordner",
"hideBoards": "Ordner verstecken",
"viewBoards": "Ordner ansehen",
"deletedPrivateBoardsCannotbeRestored": "Gelöschte Boards können nicht wiederhergestellt werden. Wenn Sie „Nur Board löschen“ wählen, werden die Bilder in einen privaten, nicht kategorisierten Status für den Ersteller des Bildes versetzt.",
@@ -811,7 +841,8 @@
"parameterSet": "Parameter {{parameter}} setzen",
"recallParameter": "{{label}} Abrufen",
"parsingFailed": "Parsing Fehlgeschlagen",
"canvasV2Metadata": "Leinwand"
"canvasV2Metadata": "Leinwand",
"guidance": "Führung"
},
"popovers": {
"noiseUseCPU": {
@@ -1137,7 +1168,9 @@
"workflowNotes": "Notizen",
"workflowTags": "Tags",
"workflowVersion": "Version",
"saveToGallery": "In Galerie speichern"
"saveToGallery": "In Galerie speichern",
"noWorkflows": "Keine Arbeitsabläufe",
"noMatchingWorkflows": "Keine passenden Arbeitsabläufe"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -1842,6 +1842,17 @@
"apply": "Apply",
"cancel": "Cancel"
},
"segment": {
"autoMask": "Auto Mask",
"pointType": "Point Type",
"foreground": "Foreground",
"background": "Background",
"neutral": "Neutral",
"reset": "Reset",
"apply": "Apply",
"cancel": "Cancel",
"process": "Process"
},
"settings": {
"snapToGrid": {
"label": "Snap to Grid",
@@ -1852,10 +1863,10 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"isolatedPreview": "Isolated Preview",
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedFilteringPreview": "Isolated Filtering Preview",
"isolatedTransformingPreview": "Isolated Transforming Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
"isolatedLayerPreviewDesc": "Whether to show only this layer when performing operations like filtering or transforming.",
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
"pressureSensitivity": "Pressure Sensitivity"
},

View File

@@ -6,7 +6,7 @@
"settingsLabel": "Paramètres",
"img2img": "Image vers Image",
"nodes": "Processus",
"upload": "Télécharger",
"upload": "Importer",
"load": "Charger",
"back": "Retour",
"statusDisconnected": "Hors ligne",
@@ -51,7 +51,7 @@
"green": "Vert",
"delete": "Supprimer",
"simple": "Simple",
"template": "Modèle",
"template": "Template",
"advanced": "Avancé",
"copy": "Copier",
"saveAs": "Enregistrer sous",
@@ -117,8 +117,8 @@
"bulkDownloadRequestFailed": "Problème lors de la préparation du téléchargement",
"copy": "Copier",
"autoAssignBoardOnClick": "Assigner automatiquement une Planche lors du clic",
"dropToUpload": "$t(gallery.drop) pour Charger",
"dropOrUpload": "$t(gallery.drop) ou Séléctioner",
"dropToUpload": "$t(gallery.drop) pour Importer",
"dropOrUpload": "$t(gallery.drop) ou Importer",
"oldestFirst": "Plus Ancien en premier",
"deleteImagePermanent": "Les Images supprimées ne peuvent pas être restorées.",
"displaySearch": "Recherche d'Image",
@@ -161,7 +161,7 @@
"unstarImage": "Retirer le marquage de l'Image",
"viewerImage": "Visualisation de l'Image",
"imagesSettings": "Paramètres des images de la galerie",
"assetsTab": "Fichiers que vous avez chargé pour vos projets.",
"assetsTab": "Fichiers que vous avez importé pour vos projets.",
"imagesTab": "Images que vous avez créées et enregistrées dans Invoke.",
"boardsSettings": "Paramètres des planches"
},
@@ -243,7 +243,7 @@
"noModelsInstalled": "Aucun modèle installé",
"urlOrLocalPath": "URL ou chemin local",
"prune": "Vider",
"uploadImage": "Charger une image",
"uploadImage": "Importer une image",
"addModels": "Ajouter des modèles",
"install": "Installer",
"localOnly": "local uniquement",
@@ -273,7 +273,18 @@
"spandrelImageToImage": "Image vers Image (Spandrel)",
"starterModelsInModelManager": "Les modèles de démarrage peuvent être trouvés dans le gestionnaire de modèles",
"t5Encoder": "Encodeur T5",
"learnMoreAboutSupportedModels": "En savoir plus sur les modèles que nous prenons en charge"
"learnMoreAboutSupportedModels": "En savoir plus sur les modèles que nous prenons en charge",
"includesNModels": "Contient {{n}} modèles et leurs dépendances",
"starterBundles": "Packs de démarrages",
"starterBundleHelpText": "Installe facilement tous les modèles nécessaire pour démarrer avec un modèle de base, incluant un modèle principal, ControlNets, IP Adapters et plus encore. Choisir un pack igniorera tous les modèles déjà installés.",
"installingXModels_one": "En cours d'installation de {{count}} modèle",
"installingXModels_many": "En cours d'installation de {{count}} modèles",
"installingXModels_other": "En cours d'installation de {{count}} modèles",
"skippingXDuplicates_one": ", en ignorant {{count}} doublon",
"skippingXDuplicates_many": ", en ignorant {{count}} doublons",
"skippingXDuplicates_other": ", en ignorant {{count}} doublons",
"installingModel": "Modèle en cours d'installation",
"installingBundle": "Pack en cours d'installation"
},
"parameters": {
"images": "Images",
@@ -414,16 +425,16 @@
"confirmOnNewSession": "Confirmer lors d'une nouvelle session"
},
"toast": {
"uploadFailed": "Téléchargement échoué",
"uploadFailed": "Importation échouée",
"imageCopied": "Image copiée",
"parametersNotSet": "Paramètres non rappelés",
"serverError": "Erreur du serveur",
"uploadFailedInvalidUploadDesc": "Doit être une unique image PNG ou JPEG",
"uploadFailedInvalidUploadDesc": "Doit être des images au format PNG ou JPEG.",
"problemCopyingImage": "Impossible de copier l'image",
"parameterSet": "Paramètre Rappelé",
"parameterNotSet": "Paramètre non Rappelé",
"canceled": "Traitement annulé",
"addedToBoard": "Ajouté à la planche",
"addedToBoard": "Ajouté aux ressources de la planche {{name}}",
"workflowLoaded": "Processus chargé",
"connected": "Connecté au serveur",
"setNodeField": "Définir comme champ de nœud",
@@ -436,7 +447,7 @@
"baseModelChangedCleared_one": "Effacé ou désactivé {{count}} sous-modèle incompatible",
"baseModelChangedCleared_many": "Effacé ou désactivé {{count}} sous-modèles incompatibles",
"baseModelChangedCleared_other": "Effacé ou désactivé {{count}} sous-modèles incompatibles",
"invalidUpload": "Téléchargement invalide",
"invalidUpload": "Importation invalide",
"problemDownloadingImage": "Impossible de télécharger l'image",
"problemRetrievingWorkflow": "Problème de récupération du processus",
"problemDeletingWorkflow": "Problème de suppression du processus",
@@ -468,10 +479,15 @@
"baseModelChanged": "Modèle de base changé",
"problemSavingLayer": "Impossible d'enregistrer la couche",
"imageNotLoadedDesc": "Image introuvable",
"linkCopied": "Lien copié"
"linkCopied": "Lien copié",
"imagesWillBeAddedTo": "Les images Importées seront ajoutées au ressources de la Planche {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Doit être au maximum une image PNG ou JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Doit être au maximum {{count}} images PNG ou JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Doit être au maximum {{count}} images PNG ou JPEG.",
"addedToUncategorized": "Ajouté aux ressources de la planche $t(boards.uncategorized)"
},
"accessibility": {
"uploadImage": "Charger une image",
"uploadImage": "Importer une image",
"reset": "Réinitialiser",
"nextImage": "Image suivante",
"previousImage": "Image précédente",
@@ -483,7 +499,8 @@
"submitSupportTicket": "Envoyer un ticket de support",
"resetUI": "$t(accessibility.reset) l'Interface Utilisateur",
"toggleRightPanel": "Afficher/Masquer le panneau de droite (G)",
"toggleLeftPanel": "Afficher/Masquer le panneau de gauche (T)"
"toggleLeftPanel": "Afficher/Masquer le panneau de gauche (T)",
"uploadImages": "Importer Image(s)"
},
"boards": {
"move": "Déplacer",
@@ -1400,13 +1417,14 @@
"parameterSet": "Paramètre {{parameter}} défini",
"parsingFailed": "L'analyse a échoué",
"recallParameter": "Rappeler {{label}}",
"canvasV2Metadata": "Toile"
"canvasV2Metadata": "Toile",
"guidance": "Guide"
},
"sdxl": {
"freePromptStyle": "Écriture de Prompt manuelle",
"concatPromptStyle": "Lier Prompt & Style",
"negStylePrompt": "Prompt Négatif",
"posStylePrompt": "Prompt Positif",
"negStylePrompt": "Style Prompt Négatif",
"posStylePrompt": "Style Prompt Positif",
"refinerStart": "Démarrer le Refiner",
"denoisingStrength": "Force de débruitage",
"steps": "Étapes",
@@ -1582,7 +1600,7 @@
"noDescription": "Aucune description",
"deleteWorkflow": "Supprimer le processus",
"openWorkflow": "Ouvrir le processus",
"uploadWorkflow": "Charger à partir du fichier",
"uploadWorkflow": "Charger à partir d'un fichier",
"workflowName": "Nom du processus",
"unnamedWorkflow": "Processus sans nom",
"saveWorkflowAs": "Enregistrer le processus sous",
@@ -1613,7 +1631,7 @@
"projectWorkflows": "Processus du projet",
"copyShareLink": "Copier le lien de partage",
"chooseWorkflowFromLibrary": "Choisir le Processus dans la Bibliothèque",
"uploadAndSaveWorkflow": "Charger dans la bibliothèque",
"uploadAndSaveWorkflow": "Importer dans la bibliothèque",
"edit": "Modifer",
"deleteWorkflow2": "Êtes-vous sûr de vouloir supprimer ce processus? Ceci ne peut pas être annulé.",
"download": "Télécharger",
@@ -1980,50 +1998,50 @@
"missingTileControlNetModel": "Aucun modèle ControlNet valide installé"
},
"stylePresets": {
"deleteTemplate": "Supprimer le modèle",
"editTemplate": "Modifier le modèle",
"deleteTemplate": "Supprimer le template",
"editTemplate": "Modifier le template",
"exportFailed": "Impossible de générer et de télécharger le CSV",
"name": "Nom",
"acceptedColumnsKeys": "Colonnes/clés acceptées :",
"promptTemplatesDesc1": "Les modèles de prompt ajoutent du texte aux prompts que vous écrivez dans la zone de saisie des prompts.",
"promptTemplatesDesc1": "Les templates de prompt ajoutent du texte aux prompts que vous écrivez dans la zone de saisie.",
"private": "Privé",
"searchByName": "Rechercher par nom",
"viewList": "Afficher la liste des modèles",
"noTemplates": "Aucun modèle",
"viewList": "Afficher la liste des templates",
"noTemplates": "Aucun templates",
"insertPlaceholder": "Insérer un placeholder",
"defaultTemplates": "Modèles par défaut",
"defaultTemplates": "Template pré-défini",
"deleteImage": "Supprimer l'image",
"createPromptTemplate": "Créer un modèle de prompt",
"createPromptTemplate": "Créer un template de prompt",
"negativePrompt": "Prompt négatif",
"promptTemplatesDesc3": "Si vous omettez le placeholder, le modèle sera ajouté à la fin de votre prompt.",
"promptTemplatesDesc3": "Si vous omettez le placeholder, le template sera ajouté à la fin de votre prompt.",
"positivePrompt": "Prompt positif",
"choosePromptTemplate": "Choisir un modèle de prompt",
"choosePromptTemplate": "Choisir un template de prompt",
"toggleViewMode": "Basculer le mode d'affichage",
"updatePromptTemplate": "Mettre à jour le modèle de prompt",
"flatten": "Intégrer le modèle sélectionné dans le prompt actuel",
"myTemplates": "Mes modèles",
"updatePromptTemplate": "Mettre à jour le template de prompt",
"flatten": "Intégrer le template sélectionné dans le prompt actuel",
"myTemplates": "Mes Templates",
"type": "Type",
"exportDownloaded": "Exportation téléchargée",
"clearTemplateSelection": "Supprimer la sélection de modèle",
"promptTemplateCleared": "Modèle de prompt effacé",
"templateDeleted": "Modèle de prompt supprimé",
"exportPromptTemplates": "Exporter mes modèles de prompt (CSV)",
"clearTemplateSelection": "Supprimer la sélection de template",
"promptTemplateCleared": "Template de prompt effacé",
"templateDeleted": "Template de prompt supprimé",
"exportPromptTemplates": "Exporter mes templates de prompt (CSV)",
"nameColumn": "'nom'",
"positivePromptColumn": "\"prompt\" ou \"prompt_positif\"",
"useForTemplate": "Utiliser pour le modèle de prompt",
"uploadImage": "Charger une image",
"importTemplates": "Importer des modèles de prompt (CSV/JSON)",
"useForTemplate": "Utiliser pour le template de prompt",
"uploadImage": "Importer une image",
"importTemplates": "Importer des templates de prompt (CSV/JSON)",
"negativePromptColumn": "'prompt_négatif'",
"deleteTemplate2": "Êtes-vous sûr de vouloir supprimer ce modèle? Cette action ne peut pas être annulée.",
"deleteTemplate2": "Êtes-vous sûr de vouloir supprimer ce template? Cette action ne peut pas être annulée.",
"preview": "Aperçu",
"shared": "Partagé",
"noMatchingTemplates": "Aucun modèle correspondant",
"sharedTemplates": "Modèles partagés",
"unableToDeleteTemplate": "Impossible de supprimer le modèle de prompt",
"noMatchingTemplates": "Aucun templates correspondant",
"sharedTemplates": "Template partagés",
"unableToDeleteTemplate": "Impossible de supprimer le template de prompt",
"active": "Actif",
"copyTemplate": "Copier le modèle",
"viewModeTooltip": "Voici à quoi ressemblera votre prompt avec le modèle actuellement sélectionné. Pour modifier votre prompt, cliquez n'importe où dans la zone de texte.",
"promptTemplatesDesc2": "Utilisez la chaîne de remplacement <Pre>{{placeholder}}</Pre> pour spécifier où votre prompt doit être inclus dans le modèle."
"copyTemplate": "Copier le template",
"viewModeTooltip": "Voici à quoi ressemblera votre prompt avec le template actuellement sélectionné. Pour modifier votre prompt, cliquez n'importe où dans la zone de texte.",
"promptTemplatesDesc2": "Utilisez la chaîne de remplacement <Pre>{{placeholder}}</Pre> pour spécifier où votre prompt doit être inclus dans le template."
},
"system": {
"logNamespaces": {
@@ -2051,8 +2069,12 @@
"enableLogging": "Activer la journalisation"
},
"newUserExperience": {
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un modèle de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
"gettingStartedSeries": "Vous souhaitez plus de conseils? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio."
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
"gettingStartedSeries": "Vous souhaitez plus de conseils? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio.",
"noModelsInstalled": "Il semblerait qu'aucun modèle ne soit installé",
"downloadStarterModels": "Télécharger les modèles de démarrage",
"importModels": "Importer Modèles",
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>."
},
"upsell": {
"shareAccess": "Partager l'accès",

View File

@@ -577,7 +577,18 @@
"noMatchingModels": "Nessun modello corrispondente",
"starterModelsInModelManager": "I modelli iniziali possono essere trovati in Gestione Modelli",
"spandrelImageToImage": "Immagine a immagine (Spandrel)",
"learnMoreAboutSupportedModels": "Scopri di più sui modelli che supportiamo"
"learnMoreAboutSupportedModels": "Scopri di più sui modelli che supportiamo",
"starterBundles": "Pacchetti per iniziare",
"installingBundle": "Installazione del pacchetto",
"skippingXDuplicates_one": ", saltando {{count}} duplicato",
"skippingXDuplicates_many": ", saltando {{count}} duplicati",
"skippingXDuplicates_other": ", saltando {{count}} duplicati",
"installingModel": "Installazione del modello",
"installingXModels_one": "Installazione di {{count}} modello",
"installingXModels_many": "Installazione di {{count}} modelli",
"installingXModels_other": "Installazione di {{count}} modelli",
"includesNModels": "Include {{n}} modelli e le loro dipendenze",
"starterBundleHelpText": "Installa facilmente tutti i modelli necessari per iniziare con un modello base, tra cui un modello principale, controlnet, adattatori IP e altro. Selezionando un pacchetto salterai tutti i modelli che hai già installato."
},
"parameters": {
"images": "Immagini",
@@ -722,7 +733,7 @@
"serverError": "Errore del Server",
"connected": "Connesso al server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"uploadFailedInvalidUploadDesc": "Devono essere immagini PNG o JPEG.",
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine",
@@ -731,7 +742,7 @@
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca",
"addedToBoard": "Aggiunto alle risorse della bacheca {{name}}",
"modelAddedSimple": "Modello aggiunto alla Coda",
"imageUploadFailed": "Caricamento immagine non riuscito",
"setControlImage": "Imposta come immagine di controllo",
@@ -770,7 +781,12 @@
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"layerCopiedToClipboard": "Livello copiato negli appunti",
"imageNotLoadedDesc": "Impossibile trovare l'immagine",
"linkCopied": "Collegamento copiato"
"linkCopied": "Collegamento copiato",
"addedToUncategorized": "Aggiunto alle risorse della bacheca $t(boards.uncategorized)",
"imagesWillBeAddedTo": "Le immagini caricate verranno aggiunte alle risorse della bacheca {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -785,7 +801,8 @@
"about": "Informazioni",
"submitSupportTicket": "Invia ticket di supporto",
"toggleLeftPanel": "Attiva/disattiva il pannello sinistro (T)",
"toggleRightPanel": "Attiva/disattiva il pannello destro (G)"
"toggleRightPanel": "Attiva/disattiva il pannello destro (G)",
"uploadImages": "Carica immagine(i)"
},
"nodes": {
"zoomOutNodes": "Rimpicciolire",
@@ -2006,7 +2023,11 @@
},
"newUserExperience": {
"gettingStartedSeries": "Desideri maggiori informazioni? Consulta la nostra <LinkComponent>Getting Started Series</LinkComponent> per suggerimenti su come sfruttare appieno il potenziale di Invoke Studio.",
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
"importModels": "Importa modelli",
"downloadStarterModels": "Scarica i modelli per iniziare",
"noModelsInstalled": "Sembra che tu non abbia installato alcun modello",
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
},
"whatsNew": {
"canvasV2Announcement": {

View File

@@ -94,7 +94,8 @@
"reset": "Сброс",
"none": "Ничего",
"new": "Новый",
"ok": "Ok"
"ok": "Ok",
"close": "Закрыть"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -160,7 +161,9 @@
"openViewer": "Открыть просмотрщик",
"closeViewer": "Закрыть просмотрщик",
"imagesTab": "Изображения, созданные и сохраненные в Invoke.",
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах."
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах.",
"boardsSettings": "Настройки доски",
"imagesSettings": "Настройки галереи изображений"
},
"hotkeys": {
"searchHotkeys": "Поиск горячих клавиш",
@@ -583,7 +586,18 @@
"learnMoreAboutSupportedModels": "Подробнее о поддерживаемых моделях",
"t5Encoder": "T5 энкодер",
"spandrelImageToImage": "Image to Image (Spandrel)",
"clipEmbed": "CLIP Embed"
"clipEmbed": "CLIP Embed",
"installingXModels_one": "Установка {{count}} модели",
"installingXModels_few": "Установка {{count}} моделей",
"installingXModels_many": "Установка {{count}} моделей",
"installingBundle": "Установка пакета",
"installingModel": "Установка модели",
"starterBundles": "Стартовые пакеты",
"skippingXDuplicates_one": ", пропуская {{count}} дубликат",
"skippingXDuplicates_few": ", пропуская {{count}} дубликата",
"skippingXDuplicates_many": ", пропуская {{count}} дубликатов",
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей",
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, сети управления, IP-адаптеры и многое другое. При выборе комплекта все уже установленные модели будут пропущены."
},
"parameters": {
"images": "Изображения",
@@ -730,7 +744,7 @@
"serverError": "Ошибка сервера",
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"uploadFailedInvalidUploadDesc": "Это должны быть изображения PNG или JPEG.",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"problemCopyingImage": "Не удается скопировать изображение",
@@ -742,7 +756,7 @@
"setNodeField": "Установить как поле узла",
"invalidUpload": "Неверная загрузка",
"imageUploaded": "Изображение загружено",
"addedToBoard": "Добавлено на доску",
"addedToBoard": "Добавлено в активы доски {{name}}",
"workflowLoaded": "Рабочий процесс загружен",
"problemDeletingWorkflow": "Проблема с удалением рабочего процесса",
"modelAddedSimple": "Модель добавлена в очередь",
@@ -777,7 +791,13 @@
"unableToLoadStylePreset": "Невозможно загрузить предустановку стиля",
"layerCopiedToClipboard": "Слой скопирован в буфер обмена",
"sentToUpscale": "Отправить на увеличение",
"layerSavedToAssets": "Слой сохранен в активах"
"layerSavedToAssets": "Слой сохранен в активах",
"linkCopied": "Ссылка скопирована",
"addedToUncategorized": "Добавлено в активы доски $t(boards.uncategorized)",
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Должно быть не более {{count}} изображения в формате PNG или JPEG.",
"uploadFailedInvalidUploadDesc_withCount_few": "Должно быть не более {{count}} изображений в формате PNG или JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Должно быть не более {{count}} изображений в формате PNG или JPEG."
},
"accessibility": {
"uploadImage": "Загрузить изображение",
@@ -792,7 +812,8 @@
"about": "Об этом",
"submitSupportTicket": "Отправить тикет в службу поддержки",
"toggleRightPanel": "Переключить правую панель (G)",
"toggleLeftPanel": "Переключить левую панель (T)"
"toggleLeftPanel": "Переключить левую панель (T)",
"uploadImages": "Загрузить изображения"
},
"nodes": {
"zoomInNodes": "Увеличьте масштаб",
@@ -933,7 +954,7 @@
"saveToGallery": "Сохранить в галерею",
"noWorkflows": "Нет рабочих процессов",
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>"
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>."
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@@ -1409,7 +1430,8 @@
"recallParameter": "Отозвать {{label}}",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения",
"canvasV2Metadata": "Холст"
"canvasV2Metadata": "Холст",
"guidance": "Точность"
},
"queue": {
"status": "Статус",
@@ -1561,7 +1583,12 @@
"defaultWorkflows": "Стандартные рабочие процессы",
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
"uploadAndSaveWorkflow": "Загрузить в библиотеку"
"uploadAndSaveWorkflow": "Загрузить в библиотеку",
"edit": "Редактировать",
"download": "Скачать",
"copyShareLink": "Скопировать ссылку на общий доступ",
"copyShareLinkForWorkflow": "Скопировать ссылку на общий доступ для рабочего процесса",
"delete": "Удалить"
},
"hrf": {
"enableHrf": "Включить исправление высокого разрешения",
@@ -1890,7 +1917,10 @@
"fitToBbox": "Вместить в рамку",
"reset": "Сбросить",
"apply": "Применить",
"cancel": "Отменить"
"cancel": "Отменить",
"fitModeContain": "Уместить",
"fitMode": "Режим подгонки",
"fitModeFill": "Заполнить"
},
"disableAutoNegative": "Отключить авто негатив",
"deleteReferenceImage": "Удалить эталонное изображение",
@@ -1920,7 +1950,8 @@
"globalReferenceImage": "Глобальное эталонное изображение",
"sendToGallery": "Отправить в галерею",
"referenceImage": "Эталонное изображение",
"addGlobalReferenceImage": "Добавить $t(controlLayers.globalReferenceImage)"
"addGlobalReferenceImage": "Добавить $t(controlLayers.globalReferenceImage)",
"newImg2ImgCanvasFromImage": "Новое img2img из изображения"
},
"ui": {
"tabs": {

View File

@@ -4,6 +4,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
@@ -59,6 +60,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
useGlobalModifiersInit();
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();

View File

@@ -2,6 +2,8 @@ import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -20,7 +22,7 @@ import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
@@ -46,6 +48,7 @@ interface Props extends PropsWithChildren {
isDebugging?: boolean;
logo?: ReactNode;
workflowCategories?: WorkflowCategory[];
loggingOverrides?: LoggingOverrides;
}
const InvokeAIUI = ({
@@ -65,7 +68,26 @@ const InvokeAIUI = ({
isDebugging = false,
logo,
workflowCategories,
loggingOverrides,
}: Props) => {
useLayoutEffect(() => {
/*
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
* possible opportunity.
*
* Once redux initializes, we will check the user's settings and update the logging config accordingly. See
* `useSyncLoggingConfig`.
*/
$loggingOverrides.set(loggingOverrides);
// Until we get the user's settings, we will use the overrides OR default values.
configureLogging(
loggingOverrides?.logIsEnabled ?? true,
loggingOverrides?.logLevel ?? 'debug',
loggingOverrides?.logNamespaces ?? '*'
);
}, [loggingOverrides]);
useEffect(() => {
// configure API client token
if (token) {

View File

@@ -9,11 +9,10 @@ const serializeMessage: MessageSerializer = (message) => {
};
ROARR.serializeMessage = serializeMessage;
ROARR.write = createLogWriter();
export const BASE_CONTEXT = {};
const BASE_CONTEXT = {};
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
@@ -35,8 +34,22 @@ export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fat
export type LogLevel = z.infer<typeof zLogLevel>;
export const isLogLevel = (v: unknown): v is LogLevel => zLogLevel.safeParse(v).success;
/**
* Override logging settings.
* @property logIsEnabled Override the enabled log state. Omit to use the user's settings.
* @property logNamespaces Override the enabled log namespaces. Use `"*"` for all namespaces. Omit to use the user's settings.
* @property logLevel Override the log level. Omit to use the user's settings.
*/
export type LoggingOverrides = {
logIsEnabled?: boolean;
logNamespaces?: LogNamespace[] | '*';
logLevel?: LogLevel;
};
export const $loggingOverrides = atom<LoggingOverrides | undefined>();
// Translate human-readable log levels to numbers, used for log filtering
export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
const LOG_LEVEL_MAP: Record<LogLevel, number> = {
trace: 10,
debug: 20,
info: 30,
@@ -44,3 +57,40 @@ export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
error: 50,
fatal: 60,
};
/**
* Configure logging, pushing settings to local storage.
*
* @param logIsEnabled Whether logging is enabled
* @param logLevel The log level
* @param logNamespaces A list of log namespaces to enable, or '*' to enable all
*/
export const configureLogging = (
logIsEnabled: boolean = true,
logLevel: LogLevel = 'warn',
logNamespaces: LogNamespace[] | '*'
): void => {
if (!logIsEnabled) {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
} else {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
const namespaces = logNamespaces === '*' ? zLogNamespace.options : logNamespaces;
if (namespaces.length > 0) {
filter += ` AND (${namespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
// This effectively hides all logs because we use namespaces for all logs
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
}
ROARR.write = createLogWriter();
};

View File

@@ -1,53 +1,9 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import { useMemo } from 'react';
import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
import { logger } from './logger';
export const useLogger = (namespace: LogNamespace) => {
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
// The provided Roarr browser log writer uses localStorage to config logging to console
useEffect(() => {
if (logIsEnabled) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
if (logNamespaces.length > 0) {
filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [logLevel, logIsEnabled, logNamespaces]);
// Update the module-scoped logger context as needed
useEffect(() => {
// TODO: type this properly
//eslint-disable-next-line @typescript-eslint/no-explicit-any
const newContext: Record<string, any> = {
...BASE_CONTEXT,
};
$logger.set(Roarr.child(newContext));
}, []);
const log = useMemo(() => logger(namespace), [namespace]);
return log;

View File

@@ -0,0 +1,43 @@
import { useStore } from '@nanostores/react';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useLayoutEffect } from 'react';
/**
* This hook synchronizes the logging configuration stored in Redux with the logging system, which uses localstorage.
*
* The sync is one-way: from Redux to localstorage. This means that changes made in the UI will be reflected in the
* logging system, but changes made directly to localstorage will not be reflected in the UI.
*
* See {@link configureLogging}
*/
export const useSyncLoggingConfig = () => {
useAssertSingleton('useSyncLoggingConfig');
const loggingOverrides = useStore($loggingOverrides);
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
useLayoutEffect(() => {
configureLogging(
loggingOverrides?.logIsEnabled ?? logIsEnabled,
loggingOverrides?.logLevel ?? logLevel,
loggingOverrides?.logNamespaces ?? logNamespaces
);
}, [
logIsEnabled,
logLevel,
logNamespaces,
loggingOverrides?.logIsEnabled,
loggingOverrides?.logLevel,
loggingOverrides?.logNamespaces,
]);
};

View File

@@ -7,12 +7,20 @@ import { diff } from 'jsondiffpatch';
/**
* Super simple logger middleware. Useful for debugging when the redux devtools are awkward.
*/
export const debugLoggerMiddleware: Middleware = (api: MiddlewareAPI) => (next) => (action) => {
const originalState = api.getState();
console.log('REDUX: dispatching', action);
const result = next(action);
const nextState = api.getState();
console.log('REDUX: next state', nextState);
console.log('REDUX: diff', diff(originalState, nextState));
return result;
};
export const getDebugLoggerMiddleware =
(options?: { withDiff?: boolean; withNextState?: boolean }): Middleware =>
(api: MiddlewareAPI) =>
(next) =>
(action) => {
const originalState = api.getState();
console.log('REDUX: dispatching', action);
const result = next(action);
const nextState = api.getState();
if (options?.withNextState) {
console.log('REDUX: next state', nextState);
}
if (options?.withDiff) {
console.log('REDUX: diff', diff(originalState, nextState));
}
return result;
};

View File

@@ -1,7 +1,7 @@
import type { FilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import type { O } from 'ts-toolbelt';
import type { PartialDeep } from 'type-fest';
/**
* A disable-able application feature
@@ -119,4 +119,4 @@ export type AppConfig = {
};
};
export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;
export type PartialAppConfig = PartialDeep<AppConfig>;

View File

@@ -1,4 +1,12 @@
type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject;
type SerializableValue =
| string
| number
| boolean
| null
| undefined
| SerializableValue[]
| readonly SerializableValue[]
| SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};

View File

@@ -34,7 +34,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.globalReferenceImage')}
</Button>

View File

@@ -0,0 +1,24 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectAutoProcess, settingsAutoProcessToggled } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAutoProcessSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoProcess = useAppSelector(selectAutoProcess);
const onChange = useCallback(() => {
dispatch(settingsAutoProcessToggled());
}, [dispatch]);
return (
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
<Switch size="sm" isChecked={autoProcess} onChange={onChange} />
</FormControl>
);
});
CanvasAutoProcessSwitch.displayName = 'CanvasAutoProcessSwitch';

View File

@@ -5,6 +5,7 @@ import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/componen
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import {
EntityIdentifierContext,
@@ -15,6 +16,7 @@ import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/sel
import {
isFilterableEntityIdentifier,
isSaveableEntityIdentifier,
isSegmentableEntityIdentifier,
isTransformableEntityIdentifier,
} from 'features/controlLayers/store/types';
import { memo } from 'react';
@@ -27,6 +29,7 @@ const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
<MenuGroup title={title}>
{isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
{isSegmentableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSegment />}
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCopyToClipboard />}
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSave />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCropToBbox />}

View File

@@ -40,7 +40,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>

View File

@@ -10,6 +10,7 @@ import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SegmentAnything } from 'features/controlLayers/components/SegmentAnything/SegmentAnything';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
@@ -101,6 +102,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SegmentAnything />
</CanvasManagerProviderGate>
</Flex>
<CanvasDropArea />

View File

@@ -0,0 +1,28 @@
import { FormControl, FormLabel, Switch, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectIsolatedLayerPreview,
settingsIsolatedLayerPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasOperationIsolatedLayerPreviewSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isolatedLayerPreview = useAppSelector(selectIsolatedLayerPreview);
const onChangeIsolatedPreview = useCallback(() => {
dispatch(settingsIsolatedLayerPreviewToggled());
}, [dispatch]);
return (
<Tooltip label={t('controlLayers.settings.isolatedLayerPreviewDesc')}>
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
<Switch size="sm" isChecked={isolatedLayerPreview} onChange={onChangeIsolatedPreview} />
</FormControl>
</Tooltip>
);
});
CanvasOperationIsolatedLayerPreviewSwitch.displayName = 'CanvasOperationIsolatedLayerPreviewSwitch';

View File

@@ -7,6 +7,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
@@ -23,6 +24,7 @@ export const ControlLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSegment />
<ControlLayerMenuItemsConvertControlToRaster />
<ControlLayerMenuItemsTransparencyEffect />
<MenuDivider />

View File

@@ -1,18 +1,15 @@
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import {
selectAutoProcessFilter,
selectIsolatedFilteringPreview,
settingsAutoProcessFilterToggled,
settingsIsolatedFilteringPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type { FilterConfig } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
@@ -23,19 +20,13 @@ import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const config = useStore(adapter.filterer.$filterConfig);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasProcessed = useStore(adapter.filterer.$hasProcessed);
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
const isolatedFilteringPreview = useAppSelector(selectIsolatedFilteringPreview);
const onChangeIsolatedPreview = useCallback(() => {
dispatch(settingsIsolatedFilteringPreviewToggled());
}, [dispatch]);
const autoProcess = useAppSelector(selectAutoProcess);
const onChangeFilterConfig = useCallback(
(filterConfig: FilterConfig) => {
@@ -51,10 +42,6 @@ const FilterContent = memo(
[adapter.filterer.$filterConfig]
);
const onChangeAutoProcessFilter = useCallback(() => {
dispatch(settingsAutoProcessFilterToggled());
}, [dispatch]);
const isValid = useMemo(() => {
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
@@ -94,14 +81,8 @@ const FilterContent = memo(
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
</FormControl>
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
<Switch size="sm" isChecked={isolatedFilteringPreview} onChange={onChangeIsolatedPreview} />
</FormControl>
<CanvasAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />
<FilterSettings filterConfig={config} onChange={onChangeFilterConfig} />
@@ -112,7 +93,7 @@ const FilterContent = memo(
onClick={adapter.filterer.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.process')}
isDisabled={!isValid || autoProcessFilter}
isDisabled={!isValid || autoProcess}
>
{t('controlLayers.filter.process')}
</Button>

View File

@@ -2,7 +2,7 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
@@ -11,9 +11,13 @@ import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';
const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];
type Props = {
@@ -47,6 +51,8 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
[onChangeCLIPVisionModel]
);
const isFLUX = useAppSelector(selectIsFLUX);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
@@ -64,10 +70,16 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
isLoading,
});
const clipVisionModelValue = useMemo(
() => CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel),
[clipVisionModel]
);
const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);
const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
}, [clipVisionModel]);
return (
<Flex gap={2}>
@@ -85,7 +97,7 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
{selectedModel?.format === 'checkpoint' && (
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
<Combobox
options={CLIP_VISION_OPTIONS}
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}

View File

@@ -16,6 +16,7 @@ import {
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { IPAImageDropData } from 'features/dnd/types';
@@ -90,6 +91,8 @@ export const IPAdapterSettings = memo(() => {
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
const isFLUX = useAppSelector(selectIsFLUX);
return (
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
@@ -113,7 +116,7 @@ export const IPAdapterSettings = memo(() => {
</Flex>
<Flex gap={2} w="full" alignItems="center">
<Flex flexDir="column" gap={2} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>

View File

@@ -7,6 +7,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
import { memo } from 'react';
@@ -22,6 +23,7 @@ export const RasterLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSegment />
<RasterLayerMenuItemsConvertRasterToControl />
<MenuDivider />
<CanvasEntityMenuItemsCropToBbox />

View File

@@ -0,0 +1,124 @@
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';
const SegmentAnythingContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
const autoProcess = useAppSelector(selectAutoProcess);
useRegisteredHotkeys({
id: 'applySegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelSegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
minW={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.segment.autoMask')}
</Heading>
<Spacer />
<CanvasAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<SegmentAnythingPointType adapter={adapter} />
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
leftIcon={<PiStarBold />}
onClick={adapter.segmentAnything.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.process')}
variant="ghost"
isDisabled={!hasPoints || autoProcess}
>
{t('controlLayers.segment.process')}
</Button>
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.segmentAnything.reset}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.reset')}
variant="ghost"
>
{t('controlLayers.segment.reset')}
</Button>
<Button
leftIcon={<PiCheckBold />}
onClick={adapter.segmentAnything.apply}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.apply')}
variant="ghost"
>
{t('controlLayers.segment.apply')}
</Button>
<Button
leftIcon={<PiXBold />}
onClick={adapter.segmentAnything.cancel}
isLoading={isProcessing}
loadingText={t('common.cancel')}
variant="ghost"
>
{t('controlLayers.segment.cancel')}
</Button>
</ButtonGroup>
</Flex>
);
}
);
SegmentAnythingContent.displayName = 'SegmentAnythingContent';
export const SegmentAnything = () => {
const canvasManager = useCanvasManager();
const adapter = useStore(canvasManager.stateApi.$segmentingAdapter);
if (!adapter) {
return null;
}
return <SegmentAnythingContent adapter={adapter} />;
};

View File

@@ -0,0 +1,44 @@
import { Flex, FormControl, FormLabel, Radio, RadioGroup, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SegmentAnythingPointType = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const pointType = useStore(adapter.segmentAnything.$pointTypeString);
const onChange = useCallback(
(v: string) => {
const labelAsString = zSAMPointLabelString.parse(v);
const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString];
adapter.segmentAnything.$pointType.set(labelAsNumber);
},
[adapter.segmentAnything.$pointType]
);
return (
<FormControl w="full">
<FormLabel>{t('controlLayers.segment.pointType')}</FormLabel>
<RadioGroup value={pointType} onChange={onChange} w="full" size="md">
<Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
<Radio value="foreground">
<Text>{t('controlLayers.segment.foreground')}</Text>
</Radio>
<Radio value="background">
<Text>{t('controlLayers.segment.background')}</Text>
</Radio>
<Radio value="neutral">
<Text>{t('controlLayers.segment.neutral')}</Text>
</Radio>
</Flex>
</RadioGroup>
</FormControl>
);
}
);
SegmentAnythingPointType.displayName = 'SegmentAnythingPointType';

View File

@@ -1,28 +1,28 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectIsolatedFilteringPreview,
settingsIsolatedFilteringPreviewToggled,
selectIsolatedLayerPreview,
settingsIsolatedLayerPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsIsolatedFilteringPreviewSwitch = memo(() => {
export const CanvasSettingsIsolatedLayerPreviewSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isolatedFilteringPreview = useAppSelector(selectIsolatedFilteringPreview);
const isolatedLayerPreview = useAppSelector(selectIsolatedLayerPreview);
const onChange = useCallback(() => {
dispatch(settingsIsolatedFilteringPreviewToggled());
dispatch(settingsIsolatedLayerPreviewToggled());
}, [dispatch]);
return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.settings.isolatedFilteringPreview')}
{t('controlLayers.settings.isolatedLayerPreview')}
</FormLabel>
<Switch size="sm" isChecked={isolatedFilteringPreview} onChange={onChange} />
<Switch size="sm" isChecked={isolatedLayerPreview} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsIsolatedFilteringPreviewSwitch.displayName = 'CanvasSettingsIsolatedFilteringPreviewSwitch';
CanvasSettingsIsolatedLayerPreviewSwitch.displayName = 'CanvasSettingsIsolatedLayerPreviewSwitch';

View File

@@ -1,28 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectIsolatedTransformingPreview,
settingsIsolatedTransformingPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsIsolatedTransformingPreviewSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isolatedTransformingPreview = useAppSelector(selectIsolatedTransformingPreview);
const onChange = useCallback(() => {
dispatch(settingsIsolatedTransformingPreviewToggled());
}, [dispatch]);
return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.settings.isolatedTransformingPreview')}
</FormLabel>
<Switch size="sm" isChecked={isolatedTransformingPreview} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsIsolatedTransformingPreviewSwitch.displayName = 'CanvasSettingsIsolatedTransformingPreviewSwitch';

View File

@@ -16,9 +16,8 @@ import { CanvasSettingsClipToBboxCheckbox } from 'features/controlLayers/compone
import { CanvasSettingsDynamicGridSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsDynamicGridSwitch';
import { CanvasSettingsSnapToGridCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsGridSize';
import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsInvertScrollCheckbox';
import { CanvasSettingsIsolatedFilteringPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedFilteringPreviewSwitch';
import { CanvasSettingsIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedLayerPreviewSwitch';
import { CanvasSettingsIsolatedStagingPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedStagingPreviewSwitch';
import { CanvasSettingsIsolatedTransformingPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedTransformingPreviewSwitch';
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
import { CanvasSettingsOutputOnlyMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsOutputOnlyMaskedRegionsCheckbox';
import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPreserveMaskCheckbox';
@@ -54,8 +53,7 @@ export const CanvasSettingsPopover = memo(() => {
<CanvasSettingsPressureSensitivityCheckbox />
<CanvasSettingsShowProgressOnCanvas />
<CanvasSettingsIsolatedStagingPreviewSwitch />
<CanvasSettingsIsolatedFilteringPreviewSwitch />
<CanvasSettingsIsolatedTransformingPreviewSwitch />
<CanvasSettingsIsolatedLayerPreviewSwitch />
<CanvasSettingsDynamicGridSwitch />
<CanvasSettingsBboxOverlaySwitch />
<CanvasSettingsShowHUDSwitch />

View File

@@ -10,8 +10,8 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
const canvasManager = useCanvasManager();
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
canvasManager.bbox.fitToLayers();
}, [canvasManager.bbox]);
canvasManager.tool.tools.bbox.fitToLayers();
}, [canvasManager.tool.tools.bbox]);
return (
<IconButton

View File

@@ -1,30 +1,21 @@
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { TransformFitToBboxButtons } from 'features/controlLayers/components/Transform/TransformFitToBboxButtons';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import {
selectIsolatedTransformingPreview,
settingsIsolatedTransformingPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useRef } from 'react';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiXBold } from 'react-icons/pi';
const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.transformer.$isProcessing);
const isolatedTransformingPreview = useAppSelector(selectIsolatedTransformingPreview);
const onChangeIsolatedPreview = useCallback(() => {
dispatch(settingsIsolatedTransformingPreviewToggled());
}, [dispatch]);
const silentTransform = useStore(adapter.transformer.$silentTransform);
useRegisteredHotkeys({
@@ -66,10 +57,7 @@ const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) =>
{t('controlLayers.transform.transform')}
</Heading>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
<Switch size="sm" isChecked={isolatedTransformingPreview} onChange={onChangeIsolatedPreview} />
</FormControl>
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<TransformFitToBboxButtons adapter={adapter} />

View File

@@ -0,0 +1,20 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsSegment = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const segmentAnything = useEntitySegmentAnything(entityIdentifier);
return (
<MenuItem onClick={segmentAnything.start} icon={<PiMaskHappyBold />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.segment.autoMask')}
</MenuItem>
);
});
CanvasEntityMenuItemsSegment.displayName = 'CanvasEntityMenuItemsSegment';

View File

@@ -0,0 +1,57 @@
import { useStore } from '@nanostores/react';
import { $false } from 'app/store/nanostores/util';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { useCallback, useMemo } from 'react';
export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const isBusy = useCanvasIsBusy();
const isInteractable = useStore(adapter?.$isInteractable ?? $false);
const isEmpty = useStore(adapter?.$isEmpty ?? $false);
const isDisabled = useMemo(() => {
if (!entityIdentifier) {
return true;
}
if (!isSegmentableEntityIdentifier(entityIdentifier)) {
return true;
}
if (!adapter) {
return true;
}
if (isBusy) {
return true;
}
if (!isInteractable) {
return true;
}
if (isEmpty) {
return true;
}
return false;
}, [entityIdentifier, adapter, isBusy, isInteractable, isEmpty]);
const start = useCallback(() => {
if (isDisabled) {
return;
}
if (!entityIdentifier) {
return;
}
if (!isSegmentableEntityIdentifier(entityIdentifier)) {
return;
}
const adapter = canvasManager.getAdapter(entityIdentifier);
if (!adapter) {
return;
}
adapter.segmentAnything.start();
}, [isDisabled, entityIdentifier, canvasManager]);
return { isDisabled, start } as const;
};

View File

@@ -10,11 +10,9 @@ import type { CanvasEntityTransformer } from 'features/controlLayers/konva/Canva
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
import { getKonvaNodeDebugAttrs, getRectIntersection } from 'features/controlLayers/konva/util';
import {
selectIsolatedFilteringPreview,
selectIsolatedTransformingPreview,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { selectIsolatedLayerPreview } from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectIsHidden,
buildSelectIsSelected,
@@ -72,6 +70,15 @@ export abstract class CanvasEntityAdapterBase<
// without requiring all adapters to implement this property and their own `destroy`?
abstract filterer?: CanvasEntityFilterer;
/**
* The segment anything module for this entity adapter. Entities that support segment anything should implement
* this property.
*/
// TODO(psyche): This is in the ABC and not in the concrete classes to allow all adapters to share the `destroy`
// method. If it wasn't in this ABC, we'd get a TS error in `destroy`. Maybe there's a better way to handle this
// without requiring all adapters to implement this property and their own `destroy`?
abstract segmentAnything?: CanvasSegmentAnythingModule;
/**
* Synchronizes the entity state with the canvas. This includes rendering the entity's objects, handling visibility,
* positioning, opacity, locked state, and any other properties.
@@ -264,13 +271,11 @@ export abstract class CanvasEntityAdapterBase<
*/
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsHidden, this.syncVisibility));
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectIsolatedFilteringPreview, this.syncVisibility)
this.manager.stateApi.createStoreSubscription(selectIsolatedLayerPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.$filteringAdapter.listen(this.syncVisibility));
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectIsolatedTransformingPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.$segmentingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsSelected, this.syncVisibility));
/**
@@ -435,8 +440,10 @@ export abstract class CanvasEntityAdapterBase<
return;
}
const isolatedLayerPreview = this.manager.stateApi.runSelector(selectIsolatedLayerPreview);
// Handle isolated preview modes - if another entity is filtering or transforming, we may need to hide this entity.
if (this.manager.stateApi.runSelector(selectIsolatedFilteringPreview)) {
if (isolatedLayerPreview) {
const filteringEntityIdentifier = this.manager.stateApi.$filteringAdapter.get()?.entityIdentifier;
if (filteringEntityIdentifier && filteringEntityIdentifier.id !== this.id) {
this.setVisibility(false);
@@ -444,7 +451,7 @@ export abstract class CanvasEntityAdapterBase<
}
}
if (this.manager.stateApi.runSelector(selectIsolatedTransformingPreview)) {
if (isolatedLayerPreview) {
const transformingEntity = this.manager.stateApi.$transformingAdapter.get();
if (
transformingEntity &&
@@ -457,6 +464,14 @@ export abstract class CanvasEntityAdapterBase<
}
}
if (isolatedLayerPreview) {
const segmentingEntity = this.manager.stateApi.$segmentingAdapter.get();
if (segmentingEntity && segmentingEntity.entityIdentifier.id !== this.id) {
this.setVisibility(false);
return;
}
}
// If the entity is not selected and offscreen, we can hide it
if (!this.$isOnScreen.get() && !this.manager.stateApi.getIsSelected(this.entityIdentifier.id)) {
this.setVisibility(false);
@@ -517,8 +532,17 @@ export abstract class CanvasEntityAdapterBase<
this.transformer.stopTransform();
}
this.transformer.destroy();
if (this.filterer?.$isFiltering.get()) {
this.filterer.cancel();
if (this.filterer) {
if (this.filterer.$isFiltering.get()) {
this.filterer.cancel();
}
this.filterer?.destroy();
}
if (this.segmentAnything) {
if (this.segmentAnything.$isSegmenting.get()) {
this.segmentAnything.cancel();
}
this.segmentAnything.destroy();
}
this.konva.layer.destroy();
this.manager.deleteAdapter(this.entityIdentifier);
@@ -534,6 +558,7 @@ export abstract class CanvasEntityAdapterBase<
transformer: this.transformer.repr(),
renderer: this.renderer.repr(),
bufferRenderer: this.bufferRenderer.repr(),
segmentAnything: this.segmentAnything?.repr(),
filterer: this.filterer?.repr(),
hasCache: this.$canvasCache.get() !== null,
isLocked: this.$isLocked.get(),

View File

@@ -5,6 +5,7 @@ import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
import type { CanvasControlLayerState, CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import { omit } from 'lodash-es';
@@ -17,6 +18,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
bufferRenderer: CanvasEntityBufferObjectRenderer;
transformer: CanvasEntityTransformer;
filterer: CanvasEntityFilterer;
segmentAnything: CanvasSegmentAnythingModule;
constructor(entityIdentifier: CanvasEntityIdentifier<'control_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, 'control_layer_adapter');
@@ -25,6 +27,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
this.bufferRenderer = new CanvasEntityBufferObjectRenderer(this);
this.transformer = new CanvasEntityTransformer(this);
this.filterer = new CanvasEntityFilterer(this);
this.segmentAnything = new CanvasSegmentAnythingModule(this);
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
}

View File

@@ -16,6 +16,7 @@ export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
bufferRenderer: CanvasEntityBufferObjectRenderer;
transformer: CanvasEntityTransformer;
filterer = undefined;
segmentAnything = undefined;
constructor(entityIdentifier: CanvasEntityIdentifier<'inpaint_mask'>, manager: CanvasManager) {
super(entityIdentifier, manager, 'inpaint_mask_adapter');

View File

@@ -5,6 +5,7 @@ import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import { omit } from 'lodash-es';
@@ -17,6 +18,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
bufferRenderer: CanvasEntityBufferObjectRenderer;
transformer: CanvasEntityTransformer;
filterer: CanvasEntityFilterer;
segmentAnything: CanvasSegmentAnythingModule;
constructor(entityIdentifier: CanvasEntityIdentifier<'raster_layer'>, manager: CanvasManager) {
super(entityIdentifier, manager, 'raster_layer_adapter');
@@ -25,6 +27,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
this.bufferRenderer = new CanvasEntityBufferObjectRenderer(this);
this.transformer = new CanvasEntityTransformer(this);
this.filterer = new CanvasEntityFilterer(this);
this.segmentAnything = new CanvasSegmentAnythingModule(this);
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
}

View File

@@ -16,6 +16,7 @@ export class CanvasEntityAdapterRegionalGuidance extends CanvasEntityAdapterBase
bufferRenderer: CanvasEntityBufferObjectRenderer;
transformer: CanvasEntityTransformer;
filterer = undefined;
segmentAnything = undefined;
constructor(entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>, manager: CanvasManager) {
super(entityIdentifier, manager, 'regional_guidance_adapter');

View File

@@ -4,7 +4,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type { FilterConfig } from 'features/controlLayers/store/filters';
import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import type { CanvasImageState } from 'features/controlLayers/store/types';
@@ -15,7 +15,6 @@ import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type CanvasEntityFiltererConfig = {
processDebounceMs: number;
@@ -56,30 +55,41 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.log = this.manager.buildLogger(this);
this.log.debug('Creating filter module');
}
subscribe = () => {
this.subscriptions.add(
this.$filterConfig.listen(() => {
if (this.manager.stateApi.getSettings().autoProcessFilter && this.$isFiltering.get()) {
if (this.manager.stateApi.getSettings().autoProcess && this.$isFiltering.get()) {
this.process();
}
})
);
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoProcessFilter, (autoPreviewFilter) => {
if (autoPreviewFilter && this.$isFiltering.get()) {
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
if (autoProcess && this.$isFiltering.get()) {
this.process();
}
})
);
}
};
unsubscribe = () => {
this.subscriptions.forEach((unsubscribe) => unsubscribe());
this.subscriptions.clear();
};
start = (config?: FilterConfig) => {
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
if (filteringAdapter) {
assert(false, `Already filtering an entity: ${filteringAdapter.id}`);
this.log.error(`Already filtering an entity: ${filteringAdapter.id}`);
return;
}
this.log.trace('Initializing filter');
this.subscribe();
if (config) {
this.$filterConfig.set(config);
} else if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
@@ -97,7 +107,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
}
this.$isFiltering.set(true);
this.manager.stateApi.$filteringAdapter.set(this.parent);
if (this.manager.stateApi.getSettings().autoProcessFilter) {
if (this.manager.stateApi.getSettings().autoProcess) {
this.processImmediate();
}
};
@@ -204,6 +214,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
replaceObjects: true,
});
this.imageState = null;
this.unsubscribe();
this.$isFiltering.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
@@ -225,6 +236,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.log.trace('Cancelling filter');
this.reset();
this.unsubscribe();
this.$isProcessing.set(false);
this.$isFiltering.set(false);
this.$hasProcessed.set(false);
@@ -243,4 +255,13 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
$filterConfig: this.$filterConfig.get(),
};
};
destroy = () => {
this.log.debug('Destroying module');
if (this.abortController && !this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
this.unsubscribe();
};
}

View File

@@ -234,8 +234,25 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.konva.transformer.on('transform', this.syncObjectGroupWithProxyRect);
this.konva.transformer.on('transformend', this.snapProxyRectToPixelGrid);
this.konva.transformer.on('pointerenter', () => {
this.manager.stage.setCursor('move');
});
this.konva.transformer.on('pointerleave', () => {
this.manager.stage.setCursor('default');
});
this.konva.proxyRect.on('dragmove', this.onDragMove);
this.konva.proxyRect.on('dragend', this.onDragEnd);
this.konva.proxyRect.on('pointerenter', () => {
this.manager.stage.setCursor('move');
});
this.konva.proxyRect.on('pointerleave', () => {
this.manager.stage.setCursor('default');
});
this.subscriptions.add(() => {
this.konva.transformer.off('transform transformend pointerenter pointerleave');
this.konva.proxyRect.off('dragmove dragend pointerenter pointerleave');
});
// When the stage scale changes, we may need to re-scale some of the transformer's components. For example,
// the bbox outline should always be 1 screen pixel wide, so we need to update its stroke width.
@@ -574,9 +591,9 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
syncInteractionState = () => {
this.log.trace('Syncing interaction state');
if (this.manager.$isBusy.get() && !this.$isTransforming.get()) {
// The canvas is busy, we can't interact with the transformer
this.parent.konva.layer.listening(false);
if (this.parent.segmentAnything?.$isSegmenting.get()) {
// When segmenting, the layer should listen but the transformer should not be interactable
this.parent.konva.layer.listening(true);
this._setInteractionMode('off');
return;
}
@@ -609,6 +626,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const tool = this.manager.tool.$tool.get();
const isSelected = this.manager.stateApi.getIsSelected(this.parent.id);
if (!isSelected) {
// The layer is not selected
this.parent.konva.layer.listening(false);
this._setInteractionMode('off');
return;
}
if (this.parent.$isEmpty.get()) {
// The layer is totally empty, we can just disable the layer
this.parent.konva.layer.listening(false);
@@ -616,14 +640,14 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
if (isSelected && !this.$isTransforming.get() && tool === 'move') {
if (!this.$isTransforming.get() && tool === 'move') {
// We are moving this layer, it must be listening
this.parent.konva.layer.listening(true);
this._setInteractionMode('drag');
return;
}
if (isSelected && this.$isTransforming.get()) {
if (this.$isTransforming.get()) {
// When transforming, we want the stage to still be movable if the view tool is selected. If the transformer is
// active, it will interrupt the stage drag events. So we should disable listening when the view tool is selected.
if (tool === 'view') {
@@ -633,11 +657,12 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.konva.layer.listening(true);
this._setInteractionMode('all');
}
} else {
// The layer is not selected, or we are using a tool that doesn't need the layer to be listening - disable interaction stuff
this.parent.konva.layer.listening(false);
this._setInteractionMode('off');
return;
}
// The layer is not selected
this.parent.konva.layer.listening(false);
this._setInteractionMode('off');
};
/**

View File

@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
import { CanvasBboxModule } from 'features/controlLayers/konva/CanvasBboxModule';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import { CanvasCompositorModule } from 'features/controlLayers/konva/CanvasCompositorModule';
import { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
@@ -62,7 +61,6 @@ export class CanvasManager extends CanvasModuleBase {
entityRenderer: CanvasEntityRendererModule;
compositor: CanvasCompositorModule;
tool: CanvasToolModule;
bbox: CanvasBboxModule;
stagingArea: CanvasStagingAreaModule;
progressImage: CanvasProgressImageModule;
@@ -111,11 +109,12 @@ export class CanvasManager extends CanvasModuleBase {
this.stateApi.$isFiltering,
this.stateApi.$isTransforming,
this.stateApi.$isRasterizing,
this.stateApi.$isSegmenting,
this.stagingArea.$isStaging,
this.compositor.$isBusy,
],
(isFiltering, isTransforming, isRasterizing, isStaging, isCompositing) => {
return isFiltering || isTransforming || isRasterizing || isStaging || isCompositing;
(isFiltering, isTransforming, isRasterizing, isSegmenting, isStaging, isCompositing) => {
return isFiltering || isTransforming || isRasterizing || isSegmenting || isStaging || isCompositing;
}
);
@@ -123,18 +122,16 @@ export class CanvasManager extends CanvasModuleBase {
this.stage.addLayer(this.background.konva.layer);
this.konva = {
previewLayer: new Konva.Layer({ listening: false, imageSmoothingEnabled: false }),
previewLayer: new Konva.Layer({ listening: true, imageSmoothingEnabled: false }),
};
this.stage.addLayer(this.konva.previewLayer);
this.tool = new CanvasToolModule(this);
this.progressImage = new CanvasProgressImageModule(this);
this.bbox = new CanvasBboxModule(this);
// Must add in this order for correct z-index
this.konva.previewLayer.add(this.stagingArea.konva.group);
this.konva.previewLayer.add(this.progressImage.konva.group);
this.konva.previewLayer.add(this.bbox.konva.group);
this.konva.previewLayer.add(this.tool.konva.group);
}
@@ -232,7 +229,6 @@ export class CanvasManager extends CanvasModuleBase {
getAllModules = (): CanvasModuleBase[] => {
return [
this.bbox,
this.stagingArea,
this.tool,
this.progressImage,
@@ -280,7 +276,6 @@ export class CanvasManager extends CanvasModuleBase {
inpaintMasks: Array.from(this.adapters.inpaintMasks.values()).map((adapter) => adapter.repr()),
regionMasks: Array.from(this.adapters.regionMasks.values()).map((adapter) => adapter.repr()),
stateApi: this.stateApi.repr(),
bbox: this.bbox.repr(),
stagingArea: this.stagingArea.repr(),
tool: this.tool.repr(),
progressImage: this.progressImage.repr(),

View File

@@ -1,10 +1,10 @@
import { Mutex } from 'async-mutex';
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
import type { CanvasStagingAreaModule } from 'features/controlLayers/konva/CanvasStagingAreaModule';
import { loadImage } from 'features/controlLayers/konva/util';
import type { CanvasImageState } from 'features/controlLayers/store/types';
@@ -21,7 +21,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
| CanvasEntityObjectRenderer
| CanvasEntityBufferObjectRenderer
| CanvasStagingAreaModule
| CanvasEntityFilterer;
| CanvasSegmentAnythingModule;
readonly manager: CanvasManager;
readonly log: Logger;
@@ -42,7 +42,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
| CanvasEntityObjectRenderer
| CanvasEntityBufferObjectRenderer
| CanvasStagingAreaModule
| CanvasEntityFilterer
| CanvasSegmentAnythingModule
) {
super();
this.id = state.id;

View File

@@ -8,9 +8,9 @@ import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import type { O } from 'ts-toolbelt';
import type { SetNonNullable } from 'type-fest';
type ProgressEventWithImage = O.NonNullable<S['InvocationProgressEvent'], 'image'>;
type ProgressEventWithImage = SetNonNullable<S['InvocationProgressEvent'], 'image'>;
const isProgressEventWithImage = (val: S['InvocationProgressEvent']): val is ProgressEventWithImage =>
Boolean(val.image);

View File

@@ -0,0 +1,789 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
CanvasImageState,
Coordinate,
RgbaColor,
SAMPoint,
SAMPointLabel,
SAMPointLabelString,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { debounce } from 'lodash-es';
import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
type CanvasSegmentAnythingModuleConfig = {
/**
* The radius of the SAM point Konva circle node.
*/
SAM_POINT_RADIUS: number;
/**
* The border width of the SAM point Konva circle node.
*/
SAM_POINT_BORDER_WIDTH: number;
/**
* The border color of the SAM point Konva circle node.
*/
SAM_POINT_BORDER_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is 1.
*/
SAM_POINT_FOREGROUND_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is -1.
*/
SAM_POINT_BACKGROUND_COLOR: RgbaColor;
/**
* The color of the SAM point Konva circle node when the label is 0.
*/
SAM_POINT_NEUTRAL_COLOR: RgbaColor;
/**
* The color to use for the mask preview overlay.
*/
MASK_COLOR: RgbaColor;
/**
* The debounce time in milliseconds for processing the points.
*/
PROCESS_DEBOUNCE_MS: number;
};
const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_RADIUS: 8,
SAM_POINT_BORDER_WIDTH: 2,
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
PROCESS_DEBOUNCE_MS: 1000,
};
/**
* The state of a SAM point.
* @property id - The unique identifier of the point.
* @property label - The label of the point. -1 is background, 0 is neutral, 1 is foreground.
* @property konva - The Konva node state of the point.
* @property konva.circle - The Konva circle node of the point. The x and y coordinates for the point are derived from
* this node.
*/
type SAMPointState = {
id: string;
label: SAMPointLabel;
konva: {
circle: Konva.Circle;
};
};
export class CanvasSegmentAnythingModule extends CanvasModuleBase {
readonly type = 'canvas_segment_anything';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasSegmentAnythingModuleConfig = DEFAULT_CONFIG;
subscriptions = new Set<() => void>();
/**
* The AbortController used to cancel the filter processing.
*/
abortController: AbortController | null = null;
/**
* Whether the module is currently segmenting an entity.
*/
$isSegmenting = atom<boolean>(false);
/**
* Whether the current set of points has been processed.
*/
$hasProcessed = atom<boolean>(false);
/**
* Whether the module is currently processing the points.
*/
$isProcessing = atom<boolean>(false);
/**
* The type of point to create when segmenting. This is a number representation of the SAMPointLabel enum.
*/
$pointType = atom<SAMPointLabel>(1);
/**
* The type of point to create when segmenting, as a string. This is a computed value based on $pointType.
*/
$pointTypeString = computed<SAMPointLabelString, Atom<SAMPointLabel>>(
this.$pointType,
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
);
/**
* Whether a point is currently being dragged. This is used to prevent the point additions and deletions during
* dragging.
*/
$isDraggingPoint = atom<boolean>(false);
/**
* The ephemeral image state of the processed image. Only used while segmenting.
*/
imageState: CanvasImageState | null = null;
/**
* The current input points.
*/
$points = atom<SAMPointState[]>([]);
/**
* Whether the module has points. This is a computed value based on $points.
*/
$hasPoints = computed(this.$points, (points) => points.length > 0);
/**
* The masked image object, if it exists.
*/
maskedImage: CanvasObjectImage | null = null;
/**
* The Konva nodes for the module.
*/
konva: {
/**
* The main Konva group node for the module.
*/
group: Konva.Group;
/**
* The Konva group node for the SAM points.
*
* This is a child of the main group node, rendered above the mask group.
*/
pointGroup: Konva.Group;
/**
* The Konva group node for the mask image and compositing rect.
*
* This is a child of the main group node, rendered below the point group.
*/
maskGroup: Konva.Group;
/**
* The Konva rect node for compositing the mask image.
*
* It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay.
*/
compositingRect: Konva.Rect;
};
KONVA_CIRCLE_NAME = `${this.type}:circle`;
KONVA_GROUP_NAME = `${this.type}:group`;
KONVA_POINT_GROUP_NAME = `${this.type}:point_group`;
KONVA_MASK_GROUP_NAME = `${this.type}:mask_group`;
KONVA_COMPOSITING_RECT_NAME = `${this.type}:compositing_rect`;
constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
// Create all konva nodes
this.konva = {
group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }),
maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }),
compositingRect: new Konva.Rect({
name: this.KONVA_COMPOSITING_RECT_NAME,
fill: rgbaColorToString(this.config.MASK_COLOR),
globalCompositeOperation: 'source-atop',
listening: false,
strokeEnabled: false,
perfectDrawEnabled: false,
visible: false,
}),
};
// Points should always be rendered above the mask group
this.konva.group.add(this.konva.maskGroup);
this.konva.group.add(this.konva.pointGroup);
// Compositing rect is added to the mask group - will also be above the mask image, but that doesn't get created
// until after processing
this.konva.maskGroup.add(this.konva.compositingRect);
}
/**
* Synchronizes the cursor style to crosshair.
*/
syncCursorStyle = (): void => {
if (this.$isProcessing.get()) {
this.manager.stage.setCursor('wait');
} else if (this.$isSegmenting.get()) {
this.manager.stage.setCursor('crosshair');
}
};
/**
* Creates a SAM point at the given coordinate with the given label. -1 is background, 0 is neutral, 1 is foreground.
* @param coord The coordinate
* @param label The label.
* @returns The SAM point state.
*/
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x: Math.round(coord.x),
y: Math.round(coord.y),
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), // We will scale this as the stage scale changes
draggable: true,
perfectDrawEnabled: true, // Required for the stroke/fill to draw correctly w/ partial opacity
opacity: 0.6,
dragDistance: 3,
});
// When the point is clicked, remove it
circle.on('pointerup', (e) => {
// Ignore if we are dragging
if (this.$isDraggingPoint.get()) {
return;
}
// This event should not bubble up to the parent, stage or any other nodes
e.cancelBubble = true;
circle.destroy();
this.$points.set(this.$points.get().filter((point) => point.id !== id));
if (this.$points.get().length === 0) {
this.resetEphemeralState();
} else {
this.$hasProcessed.set(false);
}
});
circle.on('dragstart', () => {
this.$isDraggingPoint.set(true);
});
circle.on('dragend', () => {
this.$isDraggingPoint.set(false);
// Point has changed!
this.$hasProcessed.set(false);
this.$points.notify();
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Moved SAM point'
);
});
this.konva.pointGroup.add(circle);
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Created SAM point'
);
return {
id,
label,
konva: { circle },
};
}
/**
* Synchronizes the scales of the SAM points to the stage scale.
*
* SAM points are always the same size, regardless of the stage scale.
*/
syncPointScales = () => {
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
for (const point of this.$points.get()) {
point.konva.circle.radius(radius);
point.konva.circle.strokeWidth(borderWidth);
}
};
/**
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
*/
getSAMPoints = (): SAMPoint[] => {
const points: SAMPoint[] = [];
for (const { konva, label } of this.$points.get()) {
points.push({
// Pull out and round the x and y values from Konva
x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
label,
});
}
return points;
};
/**
* Handles the pointerup event on the stage. This is used to add a SAM point to the module.
*/
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
// Only handle left-clicks
if (e.evt.button !== 0) {
return;
}
// Ignore if the stage is dragging/panning
if (this.manager.stage.getIsDragging()) {
return;
}
// Ignore if a point is being dragged
if (this.$isDraggingPoint.get()) {
return;
}
// Ignore if we are already processing
if (this.$isProcessing.get()) {
return;
}
// Ignore if the cursor is not within the stage (should never happen)
const cursorPos = this.manager.tool.$cursorPos.get();
if (!cursorPos) {
return;
}
// We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position
const pixelRect = this.parent.transformer.$pixelRect.get();
const parentPosition = addCoords(this.parent.state.position, pixelRect);
// Normalize the cursor position to the parent entity's position
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
// Create a SAM point at the normalized position
const point = this.createPoint(normalizedPoint, this.$pointType.get());
this.$points.set([...this.$points.get(), point]);
// Mark the module as having _not_ processed the points now that they have changed
this.$hasProcessed.set(false);
};
/**
* Adds event listeners needed while segmenting the entity.
*/
subscribe = () => {
this.manager.stage.konva.stage.on('pointerup', this.onStagePointerUp);
this.subscriptions.add(() => {
this.manager.stage.konva.stage.off('pointerup', this.onStagePointerUp);
});
// When we change the processing status, we should update the cursor style and the layer's listening status. For
// example, when processing, we should disable listening on the layer so the user can't add more points, else we
// should enable listening.
this.subscriptions.add(
this.$isProcessing.listen((isProcessing) => {
this.syncCursorStyle();
this.parent.konva.layer.listening(!isProcessing);
})
);
// Scale the SAM points when the stage scale changes
this.subscriptions.add(
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
if (stageAttrs.scale !== oldStageAttrs.scale) {
this.syncPointScales();
}
})
);
// When the points change, process them if autoProcess is enabled
this.subscriptions.add(
this.$points.listen((points) => {
if (points.length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcess) {
this.process();
}
})
);
// When auto-process is enabled, process the points if they have not been processed
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
if (this.$points.get().length === 0) {
return;
}
if (autoProcess && !this.$hasProcessed.get()) {
this.process();
}
})
);
};
/**
* Adds event listeners needed while segmenting the entity.
*/
unsubscribe = () => {
this.subscriptions.forEach((unsubscribe) => unsubscribe());
this.subscriptions.clear();
};
/**
* Starts the segmenting process.
*/
start = () => {
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
if (segmentingAdapter) {
this.log.error(`Already segmenting an entity: ${segmentingAdapter.id}`);
return;
}
this.log.trace('Starting segment anything');
// Reset the module's state
this.resetEphemeralState();
this.$isSegmenting.set(true);
// Update the konva group's position to match the parent entity
const pixelRect = this.parent.transformer.$pixelRect.get();
const position = addCoords(this.parent.state.position, pixelRect);
this.konva.group.setAttrs(position);
// Add the module's Konva group to the parent adapter's layer so it is rendered
this.parent.konva.layer.add(this.konva.group);
// Enable listening on the parent adapter's layer so the module can receive pointer events
this.parent.konva.layer.listening(true);
// Subscribe all listeners needed for segmenting (e.g. window pointerup, state listeners)
this.subscribe();
// Set the global segmenting adapter to this module
this.manager.stateApi.$segmentingAdapter.set(this.parent);
// Sync the cursor style to crosshair
this.syncCursorStyle();
};
/**
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
*/
processImmediate = async () => {
if (this.$isProcessing.get()) {
this.log.warn('Already processing');
return;
}
const points = this.getSAMPoints();
if (points.length === 0) {
this.log.trace('No points to segment');
return;
}
this.$isProcessing.set(true);
this.log.trace({ points }, 'Segmenting');
// Rasterize the entity in its current state
const rect = this.parent.transformer.getRelativeRect();
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } })
);
if (rasterizeResult.isErr()) {
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
this.$isProcessing.set(false);
return;
}
// Create an AbortController for the segmenting process
const controller = new AbortController();
this.abortController = controller;
// Build the graph for segmenting the image, using the rasterized image DTO
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
// Run the graph and get the segmented image output
const segmentResult = await withResultAsync(() =>
this.manager.stateApi.runGraphAndReturnImageOutput({
graph,
outputNodeId,
prepend: true,
signal: controller.signal,
})
);
// If there is an error, log it and bail out of this processing run
if (segmentResult.isErr()) {
this.log.error({ error: serializeError(segmentResult.error) }, 'Error segmenting');
this.$isProcessing.set(false);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
return;
}
this.log.trace({ imageDTO: segmentResult.value }, 'Segmented');
// Prepare the ephemeral image state
this.imageState = imageDTOToImageObject(segmentResult.value);
// Destroy any existing masked image and create a new one
if (this.maskedImage) {
this.maskedImage.destroy();
}
this.maskedImage = new CanvasObjectImage(this.imageState, this);
// Force update the masked image - after awaiting, the image will be rendered (in memory)
await this.maskedImage.update(this.imageState, true);
// Update the compositing rect to match the image size
this.konva.compositingRect.setAttrs({
width: this.imageState.image.width,
height: this.imageState.image.height,
visible: true,
});
// Now we can add the masked image to the mask group. It will be rendered above the compositing rect, but should be
// under it, so we will move the compositing rect to the top
this.konva.maskGroup.add(this.maskedImage.konva.group);
this.konva.compositingRect.moveToTop();
// Cache the group to ensure the mask is rendered correctly w/ opacity
this.konva.maskGroup.cache();
// We are done processing (still segmenting though!)
this.$isProcessing.set(false);
// The current points have been processed
this.$hasProcessed.set(true);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
};
/**
* Debounced version of processImmediate.
*/
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
/**
* Applies the segmented image to the entity.
*/
apply = () => {
if (!this.$hasProcessed.get()) {
this.log.error('Cannot apply unprocessed points');
return;
}
const imageState = this.imageState;
if (!imageState) {
this.log.error('No image state to apply');
return;
}
this.log.trace('Applying');
// Commit the buffer, which will move the buffer to from the layers' buffer renderer to its main renderer
this.parent.bufferRenderer.commitBuffer();
// Rasterize the entity, this time replacing the objects with the masked image
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
imageObject: imageState,
position: {
x: Math.round(rect.x),
y: Math.round(rect.y),
},
replaceObjects: true,
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
/**
* Resets the module (e.g. remove all points and the mask image).
*
* Does not cancel or otherwise complete the segmenting process.
*/
reset = () => {
this.log.trace('Resetting');
this.resetEphemeralState();
};
/**
* Cancels the segmenting process.
*/
cancel = () => {
this.log.trace('Canceling');
// Reset the module's state and tear down, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
/**
* Performs teardown of the module. This shared logic is used for canceling and applying - when the segmenting is
* complete and the module is deactivated.
*
* This method:
* - Removes the module's main Konva node from the parent adapter's layer
* - Removes segmenting event listeners (e.g. window pointerup)
* - Resets the segmenting state
* - Resets the global segmenting adapter
*/
teardown = () => {
this.konva.group.remove();
this.unsubscribe();
this.$isSegmenting.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
};
/**
* Resets the module's ephemeral state. This shared logic is used for resetting, canceling, and applying.
*
* This method:
* - Aborts any processing
* - Destroys ephemeral Konva nodes
* - Resets internal module state
* - Resets non-ephemeral Konva nodes
* - Clears the parent module's buffer
*/
resetEphemeralState = () => {
// First we need to bail out of any processing
this.abortController?.abort();
this.abortController = null;
// Destroy ephemeral konva nodes
for (const point of this.$points.get()) {
point.konva.circle.destroy();
}
if (this.maskedImage) {
this.maskedImage.destroy();
}
// Empty internal module state
this.$points.set([]);
this.imageState = null;
this.$pointType.set(1);
this.$hasProcessed.set(false);
this.$isProcessing.set(false);
// Reset non-ephemeral konva nodes
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
// The parent module's buffer should be reset & forcibly sync the cache
this.parent.bufferRenderer.clearBuffer();
this.parent.renderer.syncKonvaCache(true);
};
/**
* Builds a graph for segmenting an image with the given image DTO.
*/
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
// TODO(psyche): When SAM2 is available in transformers, use it here
// See: https://github.com/huggingface/transformers/pull/32317
const segmentAnything = graph.addNode({
id: getPrefixedId('segment_anything'),
type: 'segment_anything',
model: 'segment-anything-huge',
image: { image_name },
point_lists: [{ points: this.getSAMPoints() }],
mask_filter: 'largest',
});
// Apply the mask to the image, outputting an image w/ alpha transparency
const applyMask = graph.addNode({
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name },
});
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');
return {
graph,
outputNodeId: applyMask.id,
};
};
/**
* Gets the color of a SAM point based on its label.
*/
getSAMPointColor(label: SAMPointLabel): RgbaColor {
if (label === 0) {
return this.config.SAM_POINT_NEUTRAL_COLOR;
} else if (label === 1) {
return this.config.SAM_POINT_FOREGROUND_COLOR;
} else {
// label === -1
return this.config.SAM_POINT_BACKGROUND_COLOR;
}
}
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
parent: this.parent.id,
points: this.$points.get().map(({ id, konva, label }) => ({
id,
label,
circle: getKonvaNodeDebugAttrs(konva.circle),
})),
imageState: deepClone(this.imageState),
maskedImage: this.maskedImage?.repr(),
config: deepClone(this.config),
$isSegmenting: this.$isSegmenting.get(),
$hasProcessed: this.$hasProcessed.get(),
$isProcessing: this.$isProcessing.get(),
$pointType: this.$pointType.get(),
$pointTypeString: this.$pointTypeString.get(),
$isDraggingPoint: this.$isDraggingPoint.get(),
konva: {
group: getKonvaNodeDebugAttrs(this.konva.group),
compositingRect: getKonvaNodeDebugAttrs(this.konva.compositingRect),
maskGroup: getKonvaNodeDebugAttrs(this.konva.maskGroup),
pointGroup: getKonvaNodeDebugAttrs(this.konva.pointGroup),
},
};
};
destroy = () => {
this.log.debug('Destroying module');
if (this.abortController && !this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
this.unsubscribe();
this.konva.group.destroy();
};
}

View File

@@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(true);
// Then start dragging the stage if it's not already being dragged
if (!this.konva.stage.isDragging()) {
if (!this.getIsDragging()) {
this.konva.stage.startDrag();
}
@@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase {
this.setIsDraggable(this.manager.tool.$tool.get() === 'view');
// Stop dragging the stage if it's being dragged
if (this.konva.stage.isDragging()) {
if (this.getIsDragging()) {
this.konva.stage.stopDrag();
}
@@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase {
this.konva.stage.draggable(isDraggable);
};
getIsDragging = () => {
return this.konva.stage.isDragging();
};
addLayer = (layer: Konva.Layer) => {
this.konva.stage.add(layer);
};

View File

@@ -613,10 +613,20 @@ export class CanvasStateApiModule extends CanvasModuleBase {
$rasterizingAdapter = atom<CanvasEntityAdapter | null>(null);
/**
* Whether an entity is currently being transformed. Derived from `$transformingAdapter`.
* Whether an entity is currently being rasterized. Derived from `$rasterizingAdapter`.
*/
$isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter));
/**
* The entity adapter being segmented, if any.
*/
$segmentingAdapter = atom<CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | null>(null);
/**
* Whether an entity is currently being segmented. Derived from `$segmentingAdapter`.
*/
$isSegmenting = computed(this.$segmentingAdapter, (segmentingAdapter) => Boolean(segmentingAdapter));
/**
* Whether the space key is currently pressed.
*/

View File

@@ -6,11 +6,13 @@ import {
} from 'common/util/roundDownToMultiple';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { assert } from 'tsafe';
@@ -31,11 +33,11 @@ const NO_ANCHORS: string[] = [];
/**
* Renders the bounding box. The bounding box can be transformed by the user.
*/
export class CanvasBboxModule extends CanvasModuleBase {
export class CanvasBboxToolModule extends CanvasModuleBase {
readonly type = 'bbox';
readonly id: string;
readonly path: string[];
readonly parent: CanvasManager;
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
@@ -61,18 +63,18 @@ export class CanvasBboxModule extends CanvasModuleBase {
*/
$aspectRatioBuffer = atom(1);
constructor(manager: CanvasManager) {
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = manager;
this.manager = manager;
this.parent = parent;
this.manager = parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating bbox module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:group`, listening: true }),
group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
// We will use a Konva.Transformer for the generation bbox. Transformers need some shape to transform, so we will
// create a transparent rect for this purpose.
proxyRect: new Konva.Rect({
@@ -127,6 +129,7 @@ export class CanvasBboxModule extends CanvasModuleBase {
perfectDrawEnabled: false,
}),
transformer: new Konva.Transformer({
listening: false,
name: `${this.type}:transformer`,
borderDash: [5, 5],
borderStroke: 'rgba(212,216,234,1)',
@@ -135,7 +138,6 @@ export class CanvasBboxModule extends CanvasModuleBase {
rotateEnabled: false,
keepRatio: false,
ignoreStroke: true,
listening: false,
flipEnabled: false,
anchorFill: 'rgba(212,216,234,1)',
anchorStroke: 'rgb(42,42,42)',
@@ -149,9 +151,18 @@ export class CanvasBboxModule extends CanvasModuleBase {
};
this.konva.proxyRect.on('dragmove', this.onDragMove);
this.konva.proxyRect.on('pointerenter', () => {
this.manager.stage.setCursor('move');
});
this.konva.proxyRect.on('pointerleave', () => {
this.manager.stage.setCursor('default');
});
this.konva.transformer.on('transform', this.onTransform);
this.konva.transformer.on('transformend', this.onTransformEnd);
this.subscriptions.add(() => {
this.konva.proxyRect.off('dragmove pointerenter pointerleave');
this.konva.transformer.off('transform transformend');
});
// The transformer will always be transforming the proxy rect
this.konva.transformer.nodes([this.konva.proxyRect]);
@@ -161,7 +172,7 @@ export class CanvasBboxModule extends CanvasModuleBase {
this.konva.group.add(this.konva.transformer);
// We will listen to the tool state to determine if the bbox should be visible or not.
this.subscriptions.add(this.manager.tool.$tool.listen(this.render));
this.subscriptions.add(this.parent.$tool.listen(this.render));
// Also listen to redux state to update the bbox's position and dimensions.
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBbox, this.render));
@@ -176,6 +187,9 @@ export class CanvasBboxModule extends CanvasModuleBase {
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
}
// This is a noop. The cursor is changed when the cursor enters or leaves the bbox.
syncCursorStyle = noop;
initialize = () => {
this.log.debug('Initializing module');
// We need to retain a copy of the bbox state because
@@ -189,16 +203,13 @@ export class CanvasBboxModule extends CanvasModuleBase {
* Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
*/
render = () => {
this.log.trace('Rendering');
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
const tool = this.manager.tool.$tool.get();
this.konva.group.visible(true);
const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
// We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
// If the mangaer is busy, we disable listening so the bbox cannot be interacted with.
this.manager.konva.previewLayer.listening(tool === 'bbox' && !this.manager.$isBusy.get());
this.konva.group.listening(tool === 'bbox' && !this.manager.$isBusy.get());
this.konva.proxyRect.setAttrs({
x,

View File

@@ -0,0 +1,427 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import {
alignCoordForTool,
getLastPointOfLastLine,
getLastPointOfLastLineWithPressure,
getLastPointOfLine,
getPrefixedId,
isDistanceMoreThanMin,
offsetCoord,
} from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';
type CanvasBrushToolModuleConfig = {
/**
* The inner border color for the brush tool preview.
*/
BORDER_INNER_COLOR: string;
/**
* The outer border color for the brush tool preview.
*/
BORDER_OUTER_COLOR: string;
/**
* The number of milliseconds to wait before hiding the brush preview's fill circle after the mouse is released.
*/
HIDE_FILL_TIMEOUT_MS: number;
};
const DEFAULT_CONFIG: CanvasBrushToolModuleConfig = {
BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
HIDE_FILL_TIMEOUT_MS: 1500, // same as Affinity
};
/**
* Renders a preview of the brush tool on the canvas.
*/
export class CanvasBrushToolModule extends CanvasModuleBase {
readonly type = 'brush_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasBrushToolModuleConfig = DEFAULT_CONFIG;
hideFillTimeoutId: number | null = null;
/**
* The Konva objects that make up the brush tool preview:
* - A group to hold the fill circle and borders
* - A circle to fill the brush area
* - An inner border ring
* - An outer border ring
*/
konva: {
group: Konva.Group;
fillCircle: Konva.Circle;
innerBorder: Konva.Ring;
outerBorder: Konva.Ring;
};
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:brush_group`, listening: false }),
fillCircle: new Konva.Circle({
name: `${this.type}:brush_fill_circle`,
listening: false,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
innerBorder: new Konva.Ring({
name: `${this.type}:brush_inner_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_INNER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
outerBorder: new Konva.Ring({
name: `${this.type}:brush_outer_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_OUTER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.fillCircle, this.konva.innerBorder, this.konva.outerBorder);
}
syncCursorStyle = () => {
this.manager.stage.setCursor('none');
};
render = () => {
if (this.parent.$tool.get() !== 'brush') {
this.setVisibility(false);
return;
}
if (!this.parent.getCanDraw()) {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
this.setVisibility(false);
return;
}
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const lastPointerType = this.parent.$lastPointerType.get();
if (lastPointerType !== 'mouse' && isPrimaryPointerDown) {
this.setVisibility(false);
return;
}
this.setVisibility(true);
if (this.hideFillTimeoutId !== null) {
window.clearTimeout(this.hideFillTimeoutId);
this.hideFillTimeoutId = null;
}
const settings = this.manager.stateApi.getSettings();
const brushPreviewFill = this.manager.stateApi.getBrushPreviewColor();
const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.brushWidth);
const radius = settings.brushWidth / 2;
// The circle is scaled
this.konva.fillCircle.setAttrs({
x: alignedCursorPos.x,
y: alignedCursorPos.y,
radius,
fill: rgbaColorToString(brushPreviewFill),
visible: !isPrimaryPointerDown && lastPointerType === 'mouse',
});
// But the borders are in screen-pixels
const onePixel = this.manager.stage.unscale(1);
const twoPixels = this.manager.stage.unscale(2);
this.konva.innerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius,
outerRadius: radius + onePixel,
});
this.konva.outerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius + onePixel,
outerRadius: radius + twoPixels,
});
this.hideFillTimeoutId = window.setTimeout(() => {
this.konva.fillCircle.visible(false);
this.hideFillTimeoutId = null;
}, this.config.HIDE_FILL_TIMEOUT_MS);
};
setVisibility = (visible: boolean) => {
this.konva.group.visible(visible);
};
/**
* Handles the pointer enter event on the stage, when the brush tool is active. This may create a new brush line if
* the mouse is down as the cursor enters the stage.
*
* The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerEnter = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - The mouse is down: the user is not drawing
* - A selected entity: there is no entity to draw on
*/
return;
}
const settings = this.manager.stateApi.getSettings();
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
// If the pen is down and pressure sensitivity is enabled, add the point with pressure
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line_with_pressure'),
type: 'brush_line_with_pressure',
points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.parent.getClip(selectedEntity.state),
});
} else {
// Else, add the point without pressure
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points: [alignedPoint.x, alignedPoint.y],
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.parent.getClip(selectedEntity.state),
});
}
};
/**
* Handles the pointer down event on the stage, when the brush tool is active. If the shift key is held, this will
* create a straight line from the last point of the last line to the current point. Else, it will create a new line
* with the current point.
*
* The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !selectedEntity || !isPrimaryPointerDown) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - The mouse is down: the user is not drawing
* - A selected entity: there is no entity to draw on
*/
return;
}
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
const settings = this.manager.stateApi.getSettings();
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
// We need to get the last point of the last line to create a straight line if shift is held
const lastLinePoint = getLastPointOfLastLineWithPressure(
selectedEntity.state.objects,
'brush_line_with_pressure'
);
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [
lastLinePoint.x,
lastLinePoint.y,
lastLinePoint.pressure,
alignedPoint.x,
alignedPoint.y,
e.evt.pressure,
];
} else {
// Create a new line with the current point
points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line_with_pressure'),
type: 'brush_line_with_pressure',
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.parent.getClip(selectedEntity.state),
});
} else {
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
} else {
// Create a new line with the current point
points = [alignedPoint.x, alignedPoint.y];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.parent.getClip(selectedEntity.state),
});
}
};
/**
* Handles the pointer up event on the stage, when the brush tool is active. This handles finalizing the brush line
* that was being drawn (if any).
*
* The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
if (
(selectedEntity.bufferRenderer.state?.type === 'brush_line' ||
selectedEntity.bufferRenderer.state?.type === 'brush_line_with_pressure') &&
selectedEntity.bufferRenderer.hasBuffer()
) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
};
/**
* Handles the pointer move event on the stage, when the brush tool is active. This handles extending the brush line
* that is being drawn (if any).
*
* The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
if (!this.parent.$isPrimaryPointerDown.get()) {
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
const bufferState = selectedEntity.bufferRenderer.state;
if (!bufferState) {
return;
}
if (bufferState.type !== 'brush_line' && bufferState.type !== 'brush_line_with_pressure') {
return;
}
const settings = this.manager.stateApi.getSettings();
const lastPoint = getLastPointOfLine(bufferState.points);
const minDistance = settings.brushWidth * this.parent.config.BRUSH_SPACING_TARGET_SCALE;
if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
// Do not add duplicate points
return;
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
// Add pressure if the pen is down and pressure sensitivity is enabled
if (bufferState.type === 'brush_line_with_pressure' && settings.pressureSensitivity) {
bufferState.points.push(e.evt.pressure);
}
await selectedEntity.bufferRenderer.setBuffer(bufferState);
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
config: this.config,
};
};
destroy = () => {
this.log.debug('Destroying module');
this.konva.group.destroy();
};
}

View File

@@ -2,11 +2,16 @@ import { rgbColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { getColorAtCoordinate, getPrefixedId } from 'features/controlLayers/konva/util';
import type { RgbColor } from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import rafThrottle from 'raf-throttle';
import type { Logger } from 'roarr';
type CanvasToolColorPickerConfig = {
type CanvasColorPickerToolModuleConfig = {
/**
* The inner radius of the ring.
*/
@@ -49,7 +54,7 @@ type CanvasToolColorPickerConfig = {
CROSSHAIR_BORDER_COLOR: string;
};
const DEFAULT_CONFIG: CanvasToolColorPickerConfig = {
const DEFAULT_CONFIG: CanvasColorPickerToolModuleConfig = {
RING_INNER_RADIUS: 25,
RING_OUTER_RADIUS: 35,
RING_BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
@@ -65,7 +70,7 @@ const DEFAULT_CONFIG: CanvasToolColorPickerConfig = {
/**
* Renders a preview of the color picker tool on the canvas.
*/
export class CanvasToolColorPicker extends CanvasModuleBase {
export class CanvasColorPickerToolModule extends CanvasModuleBase {
readonly type = 'color_picker_tool';
readonly id: string;
readonly path: string[];
@@ -73,7 +78,12 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasToolColorPickerConfig = DEFAULT_CONFIG;
config: CanvasColorPickerToolModuleConfig = DEFAULT_CONFIG;
/**
* The color currently under the cursor. Only has a value when the color picker tool is active.
*/
$colorUnderCursor = atom<RgbColor>(RGBA_BLACK);
/**
* The Konva objects that make up the color picker tool preview:
@@ -110,6 +120,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
this.konva = {
group: new Konva.Group({ name: `${this.type}:color_picker_group`, listening: false }),
ringCandidateColor: new Konva.Ring({
listening: false,
name: `${this.type}:color_picker_candidate_color_ring`,
innerRadius: 0,
outerRadius: 0,
@@ -117,6 +128,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
perfectDrawEnabled: false,
}),
ringCurrentColor: new Konva.Arc({
listening: false,
name: `${this.type}:color_picker_current_color_arc`,
innerRadius: 0,
outerRadius: 0,
@@ -125,6 +137,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
perfectDrawEnabled: false,
}),
ringInnerBorder: new Konva.Ring({
listening: false,
name: `${this.type}:color_picker_inner_border_ring`,
innerRadius: 0,
outerRadius: 0,
@@ -133,6 +146,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
perfectDrawEnabled: false,
}),
ringOuterBorder: new Konva.Ring({
listening: false,
name: `${this.type}:color_picker_outer_border_ring`,
innerRadius: 0,
outerRadius: 0,
@@ -141,41 +155,49 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
perfectDrawEnabled: false,
}),
crosshairNorthInner: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_north1_line`,
stroke: this.config.CROSSHAIR_LINE_COLOR,
perfectDrawEnabled: false,
}),
crosshairNorthOuter: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_north2_line`,
stroke: this.config.CROSSHAIR_BORDER_COLOR,
perfectDrawEnabled: false,
}),
crosshairEastInner: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_east1_line`,
stroke: this.config.CROSSHAIR_LINE_COLOR,
perfectDrawEnabled: false,
}),
crosshairEastOuter: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_east2_line`,
stroke: this.config.CROSSHAIR_BORDER_COLOR,
perfectDrawEnabled: false,
}),
crosshairSouthInner: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_south1_line`,
stroke: this.config.CROSSHAIR_LINE_COLOR,
perfectDrawEnabled: false,
}),
crosshairSouthOuter: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_south2_line`,
stroke: this.config.CROSSHAIR_BORDER_COLOR,
perfectDrawEnabled: false,
}),
crosshairWestInner: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_west1_line`,
stroke: this.config.CROSSHAIR_LINE_COLOR,
perfectDrawEnabled: false,
}),
crosshairWestOuter: new Konva.Line({
listening: false,
name: `${this.type}:color_picker_crosshair_west2_line`,
stroke: this.config.CROSSHAIR_BORDER_COLOR,
perfectDrawEnabled: false,
@@ -198,21 +220,27 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
);
}
syncCursorStyle = () => {
this.manager.stage.setCursor('none');
};
/**
* Renders the color picker tool preview on the canvas.
*/
render = () => {
const tool = this.parent.$tool.get();
if (this.parent.$tool.get() !== 'colorPicker') {
this.setVisibility(false);
return;
}
if (tool !== 'colorPicker') {
if (!this.parent.getCanDraw()) {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
const canDraw = this.parent.getCanDraw();
if (!cursorPos || tool !== 'colorPicker' || !canDraw) {
if (!cursorPos) {
this.setVisibility(false);
return;
}
@@ -222,7 +250,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
const { x, y } = cursorPos.relative;
const settings = this.manager.stateApi.getSettings();
const colorUnderCursor = this.parent.$colorUnderCursor.get();
const colorUnderCursor = this.$colorUnderCursor.get();
const colorPickerInnerRadius = this.manager.stage.unscale(this.config.RING_INNER_RADIUS);
const colorPickerOuterRadius = this.manager.stage.unscale(this.config.RING_OUTER_RADIUS);
const onePixel = this.manager.stage.unscale(1);
@@ -299,12 +327,38 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
this.konva.group.visible(visible);
};
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
const color = this.$colorUnderCursor.get();
if (color) {
const settings = this.manager.stateApi.getSettings();
// This will update the color but not the alpha value
this.manager.stateApi.setColor({ ...settings.color, ...color });
}
};
onStagePointerMove = (_e: KonvaEventObject<PointerEvent>) => {
this.syncColorUnderCursor();
};
syncColorUnderCursor = rafThrottle(() => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
const color = getColorAtCoordinate(this.manager.stage.konva.stage, cursorPos.absolute);
if (color) {
this.$colorUnderCursor.set(color);
}
});
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
config: this.config,
$colorUnderCursor: this.$colorUnderCursor.get(),
};
};

View File

@@ -0,0 +1,394 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import {
alignCoordForTool,
getLastPointOfLastLine,
getLastPointOfLastLineWithPressure,
getLastPointOfLine,
getPrefixedId,
isDistanceMoreThanMin,
offsetCoord,
} from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';
type CanvasEraserToolModuleConfig = {
/**
* The inner border color for the eraser tool preview.
*/
BORDER_INNER_COLOR: string;
/**
* The outer border color for the eraser tool preview.
*/
BORDER_OUTER_COLOR: string;
};
const DEFAULT_CONFIG: CanvasEraserToolModuleConfig = {
BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
};
export class CanvasEraserToolModule extends CanvasModuleBase {
readonly type = 'eraser_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasEraserToolModuleConfig = DEFAULT_CONFIG;
konva: {
group: Konva.Group;
cutoutCircle: Konva.Circle;
innerBorder: Konva.Ring;
outerBorder: Konva.Ring;
};
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:eraser_group`, listening: false }),
cutoutCircle: new Konva.Circle({
name: `${this.type}:eraser_cutout_circle`,
listening: false,
strokeEnabled: false,
// The fill is used only to erase what is underneath it, so its color doesn't matter - just needs to be opaque
fill: 'white',
globalCompositeOperation: 'destination-out',
perfectDrawEnabled: false,
}),
innerBorder: new Konva.Ring({
name: `${this.type}:eraser_inner_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_INNER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
outerBorder: new Konva.Ring({
listening: false,
name: `${this.type}:eraser_outer_border_ring`,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_OUTER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.cutoutCircle, this.konva.innerBorder, this.konva.outerBorder);
}
syncCursorStyle = () => {
this.manager.stage.setCursor('none');
};
render = () => {
if (this.parent.$tool.get() !== 'eraser') {
this.setVisibility(false);
return;
}
if (!this.parent.getCanDraw()) {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
this.setVisibility(false);
return;
}
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const lastPointerType = this.parent.$lastPointerType.get();
if (lastPointerType !== 'mouse' && isPrimaryPointerDown) {
this.setVisibility(false);
return;
}
this.setVisibility(true);
const settings = this.manager.stateApi.getSettings();
const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.eraserWidth);
const radius = settings.eraserWidth / 2;
// The circle is scaled
this.konva.cutoutCircle.setAttrs({
x: alignedCursorPos.x,
y: alignedCursorPos.y,
radius,
});
// But the borders are in screen-pixels
const onePixel = this.manager.stage.unscale(1);
const twoPixels = this.manager.stage.unscale(2);
this.konva.innerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius,
outerRadius: radius + onePixel,
});
this.konva.outerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius + onePixel,
outerRadius: radius + twoPixels,
});
};
setVisibility = (visible: boolean) => {
this.konva.group.visible(visible);
};
/**
* Handles the pointer enter event on the stage, when the eraser tool is active. This may create a new eraser line if
* the mouse is down as the cursor enters the stage.
*
* The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerEnter = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - The mouse is down: the user is not drawing
* - A selected entity: there is no entity to draw on
*/
return;
}
const settings = this.manager.stateApi.getSettings();
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
// If the pen is down and pressure sensitivity is enabled, add the point with pressure
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line_with_pressure'),
type: 'eraser_line_with_pressure',
points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
strokeWidth: settings.eraserWidth,
clip: this.parent.getClip(selectedEntity.state),
});
} else {
// Else, add the point without pressure
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points: [alignedPoint.x, alignedPoint.y],
strokeWidth: settings.eraserWidth,
clip: this.parent.getClip(selectedEntity.state),
});
}
};
/**
* Handles the pointer down event on the stage, when the eraser tool is active. If the shift key is held, this will
* create a straight line from the last point of the last line to the current point. Else, it will create a new line
* with the current point.
*
* The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !selectedEntity) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - A selected entity: there is no entity to draw on
*/
return;
}
const settings = this.manager.stateApi.getSettings();
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
// We need to get the last point of the last line to create a straight line if shift is held
const lastLinePoint = getLastPointOfLastLineWithPressure(
selectedEntity.state.objects,
'eraser_line_with_pressure'
);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [
lastLinePoint.x,
lastLinePoint.y,
lastLinePoint.pressure,
alignedPoint.x,
alignedPoint.y,
e.evt.pressure,
];
} else {
// Create a new line with the current point
points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line_with_pressure'),
type: 'eraser_line_with_pressure',
points,
strokeWidth: settings.eraserWidth,
clip: this.parent.getClip(selectedEntity.state),
});
} else {
// We need to get the last point of the last line to create a straight line if shift is held
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'eraser_line');
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
} else {
// Create a new line with the current point
points = [alignedPoint.x, alignedPoint.y];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points,
strokeWidth: settings.eraserWidth,
clip: this.parent.getClip(selectedEntity.state),
});
}
};
/**
* Handles the pointer up event on the stage, when the eraser tool is active. This handles finalizing the eraser line
* that was being drawn (if any).
*
* The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
if (
(selectedEntity.bufferRenderer.state?.type === 'eraser_line' ||
selectedEntity.bufferRenderer.state?.type === 'eraser_line_with_pressure') &&
selectedEntity.bufferRenderer.hasBuffer()
) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
};
/**
* Handles the pointer move event on the stage, when the brush tool is active. This handles extending the brush line
* that is being drawn (if any).
*
* The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
* and non-tool-specific handling.
*
* @param e The Konva event object.
*/
onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
if (!this.parent.$isPrimaryPointerDown.get()) {
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
const bufferState = selectedEntity.bufferRenderer.state;
if (!bufferState) {
return;
}
if (bufferState.type !== 'eraser_line' && bufferState.type !== 'eraser_line_with_pressure') {
return;
}
const settings = this.manager.stateApi.getSettings();
const lastPoint = getLastPointOfLine(bufferState.points);
const minDistance = settings.eraserWidth * this.parent.config.BRUSH_SPACING_TARGET_SCALE;
if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
// Do not add duplicate points
return;
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
// Add pressure if the pen is down and pressure sensitivity is enabled
if (bufferState.type === 'eraser_line_with_pressure' && settings.pressureSensitivity) {
bufferState.points.push(e.evt.pressure);
}
await selectedEntity.bufferRenderer.setBuffer(bufferState);
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
config: this.config,
};
};
destroy = () => {
this.log.debug('Destroying eraser tool preview module');
this.konva.group.destroy();
};
}

View File

@@ -0,0 +1,31 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { noop } from 'lodash-es';
import type { Logger } from 'roarr';
export class CanvasMoveToolModule extends CanvasModuleBase {
readonly type = 'move_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
}
/**
* This is a noop. Entity transformers handle cursor style when the move tool is active.
*/
syncCursorStyle = noop;
}

View File

@@ -0,0 +1,102 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { floorCoord, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';
export class CanvasRectToolModule extends CanvasModuleBase {
readonly type = 'rect_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
}
syncCursorStyle = () => {
this.manager.stage.setCursor('crosshair');
};
onStagePointerDown = async (_e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - The mouse is down: the user is not drawing
* - A selected entity: there is no entity to draw on
*/
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('rect'),
type: 'rect',
rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
color: this.manager.stateApi.getCurrentColor(),
});
};
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
};
onStagePointerMove = async (_e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
if (!this.parent.$isPrimaryPointerDown.get()) {
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
const bufferState = selectedEntity.bufferRenderer.state;
if (!bufferState) {
return;
}
if (bufferState.type !== 'rect') {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = floorCoord(normalizedPoint);
bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
await selectedEntity.bufferRenderer.setBuffer(bufferState);
};
}

View File

@@ -1,182 +0,0 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { Logger } from 'roarr';
type CanvasToolBrushConfig = {
/**
* The inner border color for the brush tool preview.
*/
BORDER_INNER_COLOR: string;
/**
* The outer border color for the brush tool preview.
*/
BORDER_OUTER_COLOR: string;
/**
* The number of milliseconds to wait before hiding the brush preview's fill circle after the mouse is released.
*/
HIDE_FILL_TIMEOUT_MS: number;
};
const DEFAULT_CONFIG: CanvasToolBrushConfig = {
BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
HIDE_FILL_TIMEOUT_MS: 1500, // same as Affinity
};
/**
* Renders a preview of the brush tool on the canvas.
*/
export class CanvasToolBrush extends CanvasModuleBase {
readonly type = 'brush_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasToolBrushConfig = DEFAULT_CONFIG;
hideFillTimeoutId: number | null = null;
/**
* The Konva objects that make up the brush tool preview:
* - A group to hold the fill circle and borders
* - A circle to fill the brush area
* - An inner border ring
* - An outer border ring
*/
konva: {
group: Konva.Group;
fillCircle: Konva.Circle;
innerBorder: Konva.Ring;
outerBorder: Konva.Ring;
};
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:brush_group`, listening: false }),
fillCircle: new Konva.Circle({
name: `${this.type}:brush_fill_circle`,
listening: false,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
innerBorder: new Konva.Ring({
name: `${this.type}:brush_inner_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_INNER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
outerBorder: new Konva.Ring({
name: `${this.type}:brush_outer_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_OUTER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.fillCircle, this.konva.innerBorder, this.konva.outerBorder);
}
render = () => {
const tool = this.parent.$tool.get();
if (tool !== 'brush') {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
const canDraw = this.parent.getCanDraw();
if (!cursorPos || !canDraw) {
this.setVisibility(false);
return;
}
const isMouseDown = this.parent.$isMouseDown.get();
const lastPointerType = this.parent.$lastPointerType.get();
if (lastPointerType !== 'mouse' && isMouseDown) {
this.setVisibility(false);
return;
}
this.setVisibility(true);
if (this.hideFillTimeoutId !== null) {
window.clearTimeout(this.hideFillTimeoutId);
this.hideFillTimeoutId = null;
}
const settings = this.manager.stateApi.getSettings();
const brushPreviewFill = this.manager.stateApi.getBrushPreviewColor();
const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.brushWidth);
const radius = settings.brushWidth / 2;
// The circle is scaled
this.konva.fillCircle.setAttrs({
x: alignedCursorPos.x,
y: alignedCursorPos.y,
radius,
fill: rgbaColorToString(brushPreviewFill),
visible: !isMouseDown && lastPointerType === 'mouse',
});
// But the borders are in screen-pixels
const onePixel = this.manager.stage.unscale(1);
const twoPixels = this.manager.stage.unscale(2);
this.konva.innerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius,
outerRadius: radius + onePixel,
});
this.konva.outerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius + onePixel,
outerRadius: radius + twoPixels,
});
this.hideFillTimeoutId = window.setTimeout(() => {
this.konva.fillCircle.visible(false);
this.hideFillTimeoutId = null;
}, this.config.HIDE_FILL_TIMEOUT_MS);
};
setVisibility = (visible: boolean) => {
this.konva.group.visible(visible);
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
config: this.config,
};
};
destroy = () => {
this.log.debug('Destroying module');
this.konva.group.destroy();
};
}

View File

@@ -1,155 +0,0 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { Logger } from 'roarr';
type CanvasToolEraserConfig = {
/**
* The inner border color for the eraser tool preview.
*/
BORDER_INNER_COLOR: string;
/**
* The outer border color for the eraser tool preview.
*/
BORDER_OUTER_COLOR: string;
};
const DEFAULT_CONFIG: CanvasToolEraserConfig = {
BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
};
export class CanvasToolEraser extends CanvasModuleBase {
readonly type = 'eraser_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasToolEraserConfig = DEFAULT_CONFIG;
konva: {
group: Konva.Group;
cutoutCircle: Konva.Circle;
innerBorder: Konva.Ring;
outerBorder: Konva.Ring;
};
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:eraser_group`, listening: false }),
cutoutCircle: new Konva.Circle({
name: `${this.type}:eraser_cutout_circle`,
listening: false,
strokeEnabled: false,
// The fill is used only to erase what is underneath it, so its color doesn't matter - just needs to be opaque
fill: 'white',
globalCompositeOperation: 'destination-out',
perfectDrawEnabled: false,
}),
innerBorder: new Konva.Ring({
name: `${this.type}:eraser_inner_border_ring`,
listening: false,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_INNER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
outerBorder: new Konva.Ring({
name: `${this.type}:eraser_outer_border_ring`,
innerRadius: 0,
outerRadius: 0,
fill: this.config.BORDER_OUTER_COLOR,
strokeEnabled: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.cutoutCircle, this.konva.innerBorder, this.konva.outerBorder);
}
render = () => {
const tool = this.parent.$tool.get();
if (tool !== 'eraser') {
this.setVisibility(false);
return;
}
const cursorPos = this.parent.$cursorPos.get();
const canDraw = this.parent.getCanDraw();
if (!cursorPos || !canDraw) {
this.setVisibility(false);
return;
}
const isMouseDown = this.parent.$isMouseDown.get();
const lastPointerType = this.parent.$lastPointerType.get();
if (lastPointerType !== 'mouse' && isMouseDown) {
this.setVisibility(false);
return;
}
this.setVisibility(true);
const settings = this.manager.stateApi.getSettings();
const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.eraserWidth);
const radius = settings.eraserWidth / 2;
// The circle is scaled
this.konva.cutoutCircle.setAttrs({
x: alignedCursorPos.x,
y: alignedCursorPos.y,
radius,
});
// But the borders are in screen-pixels
const onePixel = this.manager.stage.unscale(1);
const twoPixels = this.manager.stage.unscale(2);
this.konva.innerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius,
outerRadius: radius + onePixel,
});
this.konva.outerBorder.setAttrs({
x: cursorPos.relative.x,
y: cursorPos.relative.y,
innerRadius: radius + onePixel,
outerRadius: radius + twoPixels,
});
};
setVisibility = (visible: boolean) => {
this.konva.group.visible(visible);
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
config: this.config,
};
};
destroy = () => {
this.log.debug('Destroying eraser tool preview module');
this.konva.group.destroy();
};
}

View File

@@ -1,20 +1,16 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasToolBrush } from 'features/controlLayers/konva/CanvasTool/CanvasToolBrush';
import { CanvasToolColorPicker } from 'features/controlLayers/konva/CanvasTool/CanvasToolColorPicker';
import { CanvasToolEraser } from 'features/controlLayers/konva/CanvasTool/CanvasToolEraser';
import { CanvasBboxToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBboxToolModule';
import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBrushToolModule';
import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule';
import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasEraserToolModule';
import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule';
import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule';
import { CanvasViewToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasViewToolModule';
import {
alignCoordForTool,
calculateNewBrushSizeFromWheelDelta,
floorCoord,
getColorAtCoordinate,
getIsPrimaryMouseDown,
getLastPointOfLastLine,
getLastPointOfLastLineWithPressure,
getLastPointOfLine,
getPrefixedId,
isDistanceMoreThanMin,
offsetCoord,
} from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
@@ -24,14 +20,11 @@ import type {
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
Coordinate,
RgbColor,
Tool,
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import rafThrottle from 'raf-throttle';
import type { Logger } from 'roarr';
// Konva's docs say the default drag buttons are [0], but it's actually [0,1]. We only want left-click to drag, so we
@@ -63,9 +56,15 @@ export class CanvasToolModule extends CanvasModuleBase {
config: CanvasToolModuleConfig = DEFAULT_CONFIG;
brushToolPreview: CanvasToolBrush;
eraserToolPreview: CanvasToolEraser;
colorPickerToolPreview: CanvasToolColorPicker;
tools: {
brush: CanvasBrushToolModule;
eraser: CanvasEraserToolModule;
rect: CanvasRectToolModule;
colorPicker: CanvasColorPickerToolModule;
bbox: CanvasBboxToolModule;
view: CanvasViewToolModule;
move: CanvasMoveToolModule;
};
/**
* The currently selected tool.
@@ -77,17 +76,22 @@ export class CanvasToolModule extends CanvasModuleBase {
*/
$toolBuffer = atom<Tool | null>(null);
/**
* Whether the mouse is currently down.
* Whether the primary pointer (left mouse, pen, first touch) is currently down on the stage.
*
* This is set true when the pointer down is fired on the stage and false when the pointer up is fired anywhere,
* including outside of the stage. This flag is thus true when the user is actively drawing on the stage.
*
* For example, if the pointer down was fired on the stage and the cursor then leaves the stage without a pointer up
* event, this will still be true. If the cursor then moves back onto the stage, this will still be true.
*
* However, if the pointer down was initially fired _outside_ the stage, and the cursor moves onto the stage, this
* will be false.
*/
$isMouseDown = atom<boolean>(false);
$isPrimaryPointerDown = atom<boolean>(false);
/**
* The last cursor position.
*/
$cursorPos = atom<{ relative: Coordinate; absolute: Coordinate } | null>(null);
/**
* The color currently under the cursor. Only has a value when the color picker tool is active.
*/
$colorUnderCursor = atom<RgbColor>(RGBA_BLACK);
/**
* The last pointer type that was used on the stage. This is used to determine if we should show a tool preview. For
* example, when using a pen, we should not show a brush preview.
@@ -109,18 +113,25 @@ export class CanvasToolModule extends CanvasModuleBase {
this.log.debug('Creating tool module');
this.brushToolPreview = new CanvasToolBrush(this);
this.eraserToolPreview = new CanvasToolEraser(this);
this.colorPickerToolPreview = new CanvasToolColorPicker(this);
this.tools = {
brush: new CanvasBrushToolModule(this),
eraser: new CanvasEraserToolModule(this),
rect: new CanvasRectToolModule(this),
colorPicker: new CanvasColorPickerToolModule(this),
bbox: new CanvasBboxToolModule(this),
view: new CanvasViewToolModule(this),
move: new CanvasMoveToolModule(this),
};
this.konva = {
stage: this.manager.stage.konva.stage,
group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
group: new Konva.Group({ name: `${this.type}:group`, listening: true }),
};
this.konva.group.add(this.brushToolPreview.konva.group);
this.konva.group.add(this.eraserToolPreview.konva.group);
this.konva.group.add(this.colorPickerToolPreview.konva.group);
this.konva.group.add(this.tools.brush.konva.group);
this.konva.group.add(this.tools.eraser.konva.group);
this.konva.group.add(this.tools.colorPicker.konva.group);
this.konva.group.add(this.tools.bbox.konva.group);
this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render));
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
@@ -129,7 +140,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.subscriptions.add(
this.$tool.listen(() => {
// On tool switch, reset mouse state
this.manager.tool.$isMouseDown.set(false);
this.manager.tool.$isPrimaryPointerDown.set(false);
this.render();
})
);
@@ -145,71 +156,47 @@ export class CanvasToolModule extends CanvasModuleBase {
this.syncCursorStyle();
};
setToolVisibility = (tool: Tool, isDrawable: boolean) => {
this.brushToolPreview.setVisibility(isDrawable && tool === 'brush');
this.eraserToolPreview.setVisibility(isDrawable && tool === 'eraser');
this.colorPickerToolPreview.setVisibility(tool === 'colorPicker');
};
syncCursorStyle = () => {
const stage = this.manager.stage;
const tool = this.$tool.get();
const isStageDragging = this.manager.stage.konva.stage.isDragging();
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
if (tool === 'view' && !isStageDragging) {
stage.setCursor('grab');
} else if (this.manager.stage.konva.stage.isDragging()) {
stage.setCursor('grabbing');
} else if (this.manager.stateApi.$isTransforming.get()) {
stage.setCursor('default');
if ((this.manager.stage.getIsDragging() || tool === 'view') && !segmentingAdapter) {
this.tools.view.syncCursorStyle();
} else if (segmentingAdapter) {
segmentingAdapter.segmentAnything.syncCursorStyle();
} else if (this.manager.stateApi.$isFiltering.get()) {
stage.setCursor('not-allowed');
} else if (this.manager.stagingArea.$isStaging.get()) {
stage.setCursor('not-allowed');
} else if (tool === 'bbox') {
stage.setCursor('default');
this.tools.bbox.syncCursorStyle();
} else if (this.manager.stateApi.getRenderedEntityCount() === 0) {
stage.setCursor('not-allowed');
} else if (!this.manager.stateApi.getSelectedEntityAdapter()?.$isInteractable.get()) {
stage.setCursor('not-allowed');
} else if (tool === 'colorPicker' || tool === 'brush' || tool === 'eraser') {
stage.setCursor('none');
} else if (tool === 'brush') {
this.tools.brush.syncCursorStyle();
} else if (tool === 'eraser') {
this.tools.eraser.syncCursorStyle();
} else if (tool === 'colorPicker') {
this.tools.colorPicker.syncCursorStyle();
} else if (tool === 'move') {
stage.setCursor('default');
this.tools.move.syncCursorStyle();
} else if (tool === 'rect') {
stage.setCursor('crosshair');
this.tools.rect.syncCursorStyle();
} else {
stage.setCursor('not-allowed');
}
};
render = () => {
const renderedEntityCount = this.manager.stateApi.getRenderedEntityCount();
const cursorPos = this.$cursorPos.get();
const isFiltering = this.manager.stateApi.$isFiltering.get();
const isStaging = this.manager.stagingArea.$isStaging.get();
const isStageDragging = this.manager.stage.konva.stage.isDragging();
this.syncCursorStyle();
/**
* The tool should not be rendered when:
* - There is no cursor position (i.e. the cursor is outside of the stage)
* - The user is filtering, in which case the user is not allowed to use the tools. Note that we do not disable
* the group while transforming, bc that requires use of the move tool.
* - The canvas is staging, in which case the user is not allowed to use the tools.
* - There are no entities rendered on the canvas. Maybe we should allow the user to draw on an empty canvas,
* creating a new layer when they start?
* - The stage is being dragged, in which case the user is not allowed to use the tools.
*/
if (!cursorPos || isFiltering || isStaging || renderedEntityCount === 0 || isStageDragging) {
this.konva.group.visible(false);
} else {
this.konva.group.visible(true);
this.brushToolPreview.render();
this.eraserToolPreview.render();
this.colorPickerToolPreview.render();
}
this.tools.brush.render();
this.tools.eraser.render();
this.tools.colorPicker.render();
this.tools.bbox.render();
};
syncCursorPositions = () => {
@@ -282,6 +269,14 @@ export class CanvasToolModule extends CanvasModuleBase {
};
};
/**
* Gets whether the user is allowed to draw on the canvas.
* - There must be at least one entity rendered on the canvas.
* - The canvas must not be busy (e.g. transforming, filtering, rasterizing, staging, compositing, segment-anything-ing).
* - There must be a selected entity.
* - The selected entity must be interactable (e.g. not hidden, disabled or locked).
* @returns Whether the user is allowed to draw on the canvas.
*/
getCanDraw = (): boolean => {
if (this.manager.stateApi.getRenderedEntityCount() === 0) {
return false;
@@ -291,6 +286,10 @@ export class CanvasToolModule extends CanvasModuleBase {
return false;
}
if (this.manager.stage.getIsDragging()) {
return false;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
@@ -313,71 +312,19 @@ export class CanvasToolModule extends CanvasModuleBase {
}
this.syncCursorPositions();
const cursorPos = this.$cursorPos.get();
const isMouseDown = this.$isMouseDown.get();
const settings = this.manager.stateApi.getSettings();
const tool = this.$tool.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isMouseDown || !selectedEntity?.$isInteractable.get()) {
return;
}
if (selectedEntity.bufferRenderer.state?.type !== 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
if (selectedEntity?.bufferRenderer.state?.type !== 'rect' && selectedEntity?.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
return;
}
if (tool === 'brush') {
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line_with_pressure'),
type: 'brush_line_with_pressure',
points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.getClip(selectedEntity.state),
});
} else {
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points: [alignedPoint.x, alignedPoint.y],
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.getClip(selectedEntity.state),
});
}
return;
}
if (tool === 'eraser') {
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (selectedEntity.bufferRenderer.state && selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line_with_pressure'),
type: 'eraser_line_with_pressure',
points: [alignedPoint.x, alignedPoint.y],
strokeWidth: settings.eraserWidth,
clip: this.getClip(selectedEntity.state),
});
} else {
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points: [alignedPoint.x, alignedPoint.y],
strokeWidth: settings.eraserWidth,
clip: this.getClip(selectedEntity.state),
});
}
return;
await this.tools.brush.onStagePointerEnter(e);
} else if (tool === 'eraser') {
await this.tools.eraser.onStagePointerEnter(e);
}
} finally {
this.render();
@@ -385,6 +332,10 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
if (e.target !== this.konva.stage) {
return;
}
try {
this.$lastPointerType.set(e.evt.pointerType);
@@ -392,147 +343,18 @@ export class CanvasToolModule extends CanvasModuleBase {
return;
}
const isMouseDown = getIsPrimaryMouseDown(e);
this.$isMouseDown.set(isMouseDown);
this.$isPrimaryPointerDown.set(getIsPrimaryMouseDown(e));
this.syncCursorPositions();
const cursorPos = this.$cursorPos.get();
const tool = this.$tool.get();
const settings = this.manager.stateApi.getSettings();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !isMouseDown || !selectedEntity?.$isInteractable.get()) {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
if (tool === 'brush') {
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
const lastLinePoint = getLastPointOfLastLineWithPressure(
selectedEntity.state.objects,
'brush_line_with_pressure'
);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [
lastLinePoint.x,
lastLinePoint.y,
lastLinePoint.pressure,
alignedPoint.x,
alignedPoint.y,
e.evt.pressure,
];
} else {
points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line_with_pressure'),
type: 'brush_line_with_pressure',
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.getClip(selectedEntity.state),
});
} else {
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
} else {
points = [alignedPoint.x, alignedPoint.y];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('brush_line'),
type: 'brush_line',
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
clip: this.getClip(selectedEntity.state),
});
}
}
if (tool === 'eraser') {
if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
const lastLinePoint = getLastPointOfLastLineWithPressure(
selectedEntity.state.objects,
'eraser_line_with_pressure'
);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [
lastLinePoint.x,
lastLinePoint.y,
lastLinePoint.pressure,
alignedPoint.x,
alignedPoint.y,
e.evt.pressure,
];
} else {
points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line_with_pressure'),
type: 'eraser_line_with_pressure',
points,
strokeWidth: settings.eraserWidth,
clip: this.getClip(selectedEntity.state),
});
} else {
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'eraser_line');
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
let points: number[];
if (e.evt.shiftKey && lastLinePoint) {
// Create a straight line from the last line point
points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
} else {
points = [alignedPoint.x, alignedPoint.y];
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('eraser_line'),
type: 'eraser_line',
points,
strokeWidth: settings.eraserWidth,
clip: this.getClip(selectedEntity.state),
});
}
}
if (tool === 'rect') {
if (selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
}
await selectedEntity.bufferRenderer.setBuffer({
id: getPrefixedId('rect'),
type: 'rect',
rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
color: this.manager.stateApi.getCurrentColor(),
});
await this.tools.brush.onStagePointerDown(e);
} else if (tool === 'eraser') {
await this.tools.eraser.onStagePointerDown(e);
} else if (tool === 'rect') {
await this.tools.rect.onStagePointerDown(e);
}
} finally {
this.render();
@@ -540,6 +362,10 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
if (e.target !== this.konva.stage) {
return;
}
try {
this.$lastPointerType.set(e.evt.pointerType);
@@ -548,160 +374,46 @@ export class CanvasToolModule extends CanvasModuleBase {
}
const tool = this.$tool.get();
const settings = this.manager.stateApi.getSettings();
if (tool === 'colorPicker') {
const color = this.$colorUnderCursor.get();
if (color) {
this.manager.stateApi.setColor({ ...settings.color, ...color });
}
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity?.$isInteractable.get()) {
return;
}
if (tool === 'brush') {
if (
(selectedEntity.bufferRenderer.state?.type === 'brush_line' ||
selectedEntity.bufferRenderer.state?.type === 'brush_line_with_pressure') &&
selectedEntity.bufferRenderer.hasBuffer()
) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
}
if (tool === 'eraser') {
if (
(selectedEntity.bufferRenderer.state?.type === 'eraser_line' ||
selectedEntity.bufferRenderer.state?.type === 'eraser_line_with_pressure') &&
selectedEntity.bufferRenderer.hasBuffer()
) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
}
if (tool === 'rect') {
if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
selectedEntity.bufferRenderer.commitBuffer();
} else {
selectedEntity.bufferRenderer.clearBuffer();
}
this.tools.colorPicker.onStagePointerUp(e);
} else if (tool === 'brush') {
this.tools.brush.onStagePointerUp(e);
} else if (tool === 'eraser') {
this.tools.eraser.onStagePointerUp(e);
} else if (tool === 'rect') {
this.tools.rect.onStagePointerUp(e);
}
} finally {
this.render();
}
};
syncColorUnderCursor = rafThrottle(() => {
const cursorPos = this.$cursorPos.get();
if (!cursorPos) {
onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
if (e.target !== this.konva.stage) {
return;
}
const color = getColorAtCoordinate(this.konva.stage, cursorPos.absolute);
if (color) {
this.$colorUnderCursor.set(color);
}
});
onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
try {
this.$lastPointerType.set(e.evt.pointerType);
this.syncCursorPositions();
if (!this.getCanDraw()) {
return;
}
this.syncCursorPositions();
const cursorPos = this.$cursorPos.get();
if (!cursorPos) {
return;
}
const tool = this.$tool.get();
if (tool === 'colorPicker') {
this.syncColorUnderCursor();
}
const isMouseDown = this.$isMouseDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!isMouseDown || !selectedEntity?.$isInteractable.get()) {
return;
}
const bufferState = selectedEntity.bufferRenderer.state;
if (!bufferState) {
return;
}
const settings = this.manager.stateApi.getSettings();
if (tool === 'brush' && (bufferState.type === 'brush_line' || bufferState.type === 'brush_line_with_pressure')) {
const lastPoint = getLastPointOfLine(bufferState.points);
const minDistance = settings.brushWidth * this.config.BRUSH_SPACING_TARGET_SCALE;
if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
// Do not add duplicate points
return;
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
if (bufferState.type === 'brush_line_with_pressure') {
bufferState.points.push(e.evt.pressure);
}
await selectedEntity.bufferRenderer.setBuffer(bufferState);
} else if (
tool === 'eraser' &&
(bufferState.type === 'eraser_line' || bufferState.type === 'eraser_line_with_pressure')
) {
const lastPoint = getLastPointOfLine(bufferState.points);
const minDistance = settings.eraserWidth * this.config.BRUSH_SPACING_TARGET_SCALE;
if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
return;
}
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
// Do not add duplicate points
return;
}
bufferState.points.push(alignedPoint.x, alignedPoint.y);
if (bufferState.type === 'eraser_line_with_pressure') {
bufferState.points.push(e.evt.pressure);
}
await selectedEntity.bufferRenderer.setBuffer(bufferState);
} else if (tool === 'rect' && bufferState.type === 'rect') {
const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
const alignedPoint = floorCoord(normalizedPoint);
bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
await selectedEntity.bufferRenderer.setBuffer(bufferState);
this.tools.colorPicker.onStagePointerMove(e);
} else if (tool === 'brush') {
await this.tools.brush.onStagePointerMove(e);
} else if (tool === 'eraser') {
await this.tools.eraser.onStagePointerMove(e);
} else if (tool === 'rect') {
await this.tools.rect.onStagePointerMove(e);
} else {
selectedEntity?.bufferRenderer.clearBuffer();
this.manager.stateApi.getSelectedEntityAdapter()?.bufferRenderer.clearBuffer();
}
} finally {
this.render();
@@ -709,6 +421,10 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onStagePointerLeave = (e: PointerEvent) => {
if (e.target !== this.manager.stage.container) {
return;
}
try {
this.$lastPointerType.set(e.pointerType);
this.$cursorPos.set(null);
@@ -732,6 +448,10 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onStageMouseWheel = (e: KonvaEventObject<WheelEvent>) => {
if (e.target !== this.konva.stage) {
return;
}
if (!this.getCanDraw()) {
return;
}
@@ -770,7 +490,7 @@ export class CanvasToolModule extends CanvasModuleBase {
*/
onWindowPointerUp = (_: PointerEvent) => {
try {
this.$isMouseDown.set(false);
this.$isPrimaryPointerDown.set(false);
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) {
@@ -872,12 +592,18 @@ export class CanvasToolModule extends CanvasModuleBase {
config: this.config,
$tool: this.$tool.get(),
$toolBuffer: this.$toolBuffer.get(),
$isMouseDown: this.$isMouseDown.get(),
$isPrimaryPointerDown: this.$isPrimaryPointerDown.get(),
$cursorPos: this.$cursorPos.get(),
$colorUnderCursor: this.$colorUnderCursor.get(),
brushToolPreview: this.brushToolPreview.repr(),
eraserToolPreview: this.eraserToolPreview.repr(),
colorPickerToolPreview: this.colorPickerToolPreview.repr(),
$lastPointerType: this.$lastPointerType.get(),
tools: {
brush: this.tools.brush.repr(),
eraser: this.tools.eraser.repr(),
colorPicker: this.tools.colorPicker.repr(),
rect: this.tools.rect.repr(),
bbox: this.tools.bbox.repr(),
view: this.tools.view.repr(),
move: this.tools.move.repr(),
},
};
};

View File

@@ -0,0 +1,29 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Logger } from 'roarr';
export class CanvasViewToolModule extends CanvasModuleBase {
readonly type = 'view_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
}
syncCursorStyle = () => {
this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab');
};
}

View File

@@ -1,6 +1,12 @@
import type { Selector, Store } from '@reduxjs/toolkit';
import { $authToken } from 'app/store/nanostores/authToken';
import type { CanvasEntityIdentifier, CanvasObjectState, Coordinate, Rect } from 'features/controlLayers/store/types';
import type {
CanvasEntityIdentifier,
CanvasObjectState,
Coordinate,
CoordinateWithPressure,
Rect,
} from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Vector2d } from 'konva/lib/types';
@@ -74,6 +80,18 @@ export const offsetCoord = (coord: Coordinate, offset: Coordinate): Coordinate =
};
};
/**
* Adds two coordinates together.
* @param a The first coordinate
* @param b The second coordinate
*/
export const addCoords = (a: Coordinate, b: Coordinate): Coordinate => {
return {
x: a.x + b.x,
y: a.y + b.y,
};
};
/**
* Snaps a position to the edge of the stage if within a threshold of the edge
* @param pos The position to snap
@@ -134,7 +152,7 @@ export const snapToRect = (pos: Vector2d, rect: Rect, threshold = 10): Vector2d
* Checks if the left mouse button is currently pressed
* @param e The konva event
*/
export const getIsMouseDown = (e: KonvaEventObject<MouseEvent>): boolean => e.evt.buttons === 1;
export const getIsPrimaryPointerDown = (e: KonvaEventObject<PointerEvent>): boolean => e.evt.buttons === 1;
/**
* Checks if the stage is currently focused
@@ -545,11 +563,6 @@ export const exhaustiveCheck = (value: never): never => {
assert(false, `Unhandled value: ${value}`);
};
type CoordinateWithPressure = {
x: number;
y: number;
pressure: number;
};
export const getLastPointOfLastLineWithPressure = (
objects: CanvasObjectState[],
type: 'brush_line_with_pressure' | 'eraser_line_with_pressure'
@@ -615,6 +628,7 @@ export const getKonvaNodeDebugAttrs = (node: Konva.Node) => {
isCached: node.isCached(),
visible: node.visible(),
listening: node.listening(),
zIndex: node.zIndex(),
};
};

View File

@@ -48,9 +48,9 @@ type CanvasSettingsState = {
*/
outputOnlyMaskedRegions: boolean;
/**
* Whether to automatically process the filter when the filter configuration changes.
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcessFilter: boolean;
autoProcess: boolean;
/**
* The snap-to-grid setting for the canvas.
*/
@@ -72,13 +72,9 @@ type CanvasSettingsState = {
*/
isolatedStagingPreview: boolean;
/**
* Whether to show only the selected layer while filtering.
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedFilteringPreview: boolean;
/**
* Whether to show only the selected layer while transforming.
*/
isolatedTransformingPreview: boolean;
isolatedLayerPreview: boolean;
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
@@ -95,14 +91,13 @@ const initialState: CanvasSettingsState = {
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
sendToCanvas: false,
outputOnlyMaskedRegions: false,
autoProcessFilter: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedFilteringPreview: true,
isolatedTransformingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
};
@@ -137,8 +132,8 @@ export const canvasSettingsSlice = createSlice({
settingsOutputOnlyMaskedRegionsToggled: (state) => {
state.outputOnlyMaskedRegions = !state.outputOnlyMaskedRegions;
},
settingsAutoProcessFilterToggled: (state) => {
state.autoProcessFilter = !state.autoProcessFilter;
settingsAutoProcessToggled: (state) => {
state.autoProcess = !state.autoProcess;
},
settingsSnapToGridToggled: (state) => {
state.snapToGrid = !state.snapToGrid;
@@ -155,11 +150,8 @@ export const canvasSettingsSlice = createSlice({
settingsIsolatedStagingPreviewToggled: (state) => {
state.isolatedStagingPreview = !state.isolatedStagingPreview;
},
settingsIsolatedFilteringPreviewToggled: (state) => {
state.isolatedFilteringPreview = !state.isolatedFilteringPreview;
},
settingsIsolatedTransformingPreviewToggled: (state) => {
state.isolatedTransformingPreview = !state.isolatedTransformingPreview;
settingsIsolatedLayerPreviewToggled: (state) => {
state.isolatedLayerPreview = !state.isolatedLayerPreview;
},
settingsPressureSensitivityToggled: (state) => {
state.pressureSensitivity = !state.pressureSensitivity;
@@ -185,14 +177,13 @@ export const {
settingsInvertScrollForToolWidthChanged,
settingsSendToCanvasChanged,
settingsOutputOnlyMaskedRegionsToggled,
settingsAutoProcessFilterToggled,
settingsAutoProcessToggled,
settingsSnapToGridToggled,
settingsShowProgressOnCanvasToggled,
settingsBboxOverlayToggled,
settingsPreserveMaskToggled,
settingsIsolatedStagingPreviewToggled,
settingsIsolatedFilteringPreviewToggled,
settingsIsolatedTransformingPreviewToggled,
settingsIsolatedLayerPreviewToggled,
settingsPressureSensitivityToggled,
} = canvasSettingsSlice.actions;
@@ -219,17 +210,12 @@ export const selectOutputOnlyMaskedRegions = createCanvasSettingsSelector(
export const selectDynamicGrid = createCanvasSettingsSelector((settings) => settings.dynamicGrid);
export const selectBboxOverlay = createCanvasSettingsSelector((settings) => settings.bboxOverlay);
export const selectShowHUD = createCanvasSettingsSelector((settings) => settings.showHUD);
export const selectAutoProcessFilter = createCanvasSettingsSelector((settings) => settings.autoProcessFilter);
export const selectAutoProcess = createCanvasSettingsSelector((settings) => settings.autoProcess);
export const selectSnapToGrid = createCanvasSettingsSelector((settings) => settings.snapToGrid);
export const selectSendToCanvas = createCanvasSettingsSelector((canvasSettings) => canvasSettings.sendToCanvas);
export const selectShowProgressOnCanvas = createCanvasSettingsSelector(
(canvasSettings) => canvasSettings.showProgressOnCanvas
);
export const selectIsolatedStagingPreview = createCanvasSettingsSelector((settings) => settings.isolatedStagingPreview);
export const selectIsolatedFilteringPreview = createCanvasSettingsSelector(
(settings) => settings.isolatedFilteringPreview
);
export const selectIsolatedTransformingPreview = createCanvasSettingsSelector(
(settings) => settings.isolatedTransformingPreview
);
export const selectIsolatedLayerPreview = createCanvasSettingsSelector((settings) => settings.isolatedLayerPreview);
export const selectPressureSensitivity = createCanvasSettingsSelector((settings) => settings.pressureSensitivity);

View File

@@ -381,6 +381,13 @@ export const canvasSlice = createSlice({
return;
}
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
state,
@@ -577,6 +584,13 @@ export const canvasSlice = createSlice({
return;
}
referenceImage.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (referenceImage.ipAdapter.model?.base === 'flux') {
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
}
},
rgIPAdapterCLIPVisionModelChanged: (
state,

View File

@@ -46,7 +46,7 @@ const zControlModeV2 = z.enum(['balanced', 'more_prompt', 'more_control', 'unbal
export type ControlModeV2 = z.infer<typeof zControlModeV2>;
export const isControlModeV2 = (v: unknown): v is ControlModeV2 => zControlModeV2.safeParse(v).success;
const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G']);
const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G', 'ViT-L']);
export type CLIPVisionModelV2 = z.infer<typeof zCLIPVisionModelV2>;
export const isCLIPVisionModelV2 = (v: unknown): v is CLIPVisionModelV2 => zCLIPVisionModelV2.safeParse(v).success;
@@ -89,6 +89,49 @@ const zCoordinate = z.object({
y: z.number(),
});
export type Coordinate = z.infer<typeof zCoordinate>;
const zCoordinateWithPressure = z.object({
x: z.number(),
y: z.number(),
pressure: z.number(),
});
export type CoordinateWithPressure = z.infer<typeof zCoordinateWithPressure>;
const SAM_POINT_LABELS = {
background: -1,
neutral: 0,
foreground: 1,
} as const;
const zSAMPointLabel = z.nativeEnum(SAM_POINT_LABELS);
export type SAMPointLabel = z.infer<typeof zSAMPointLabel>;
export const zSAMPointLabelString = z.enum(['background', 'neutral', 'foreground']);
export type SAMPointLabelString = z.infer<typeof zSAMPointLabelString>;
/**
* A mapping of SAM point labels (as numbers) to their string representations.
*/
export const SAM_POINT_LABEL_NUMBER_TO_STRING: Record<SAMPointLabel, SAMPointLabelString> = {
'-1': 'background',
0: 'neutral',
1: 'foreground',
};
/**
* A mapping of SAM point labels (as strings) to their numeric representations.
*/
export const SAM_POINT_LABEL_STRING_TO_NUMBER: Record<SAMPointLabelString, SAMPointLabel> = {
background: -1,
neutral: 0,
foreground: 1,
};
const zSAMPoint = z.object({
x: z.number().int().gte(0),
y: z.number().int().gte(0),
label: zSAMPointLabel,
});
export type SAMPoint = z.infer<typeof zSAMPoint>;
const zRect = z.object({
x: z.number(),
@@ -107,6 +150,9 @@ const zCanvasBrushLineState = z.object({
id: zId,
type: z.literal('brush_line'),
strokeWidth: z.number().min(1),
/**
* Points without pressure are in the format [x1, y1, x2, y2, ...]
*/
points: zPoints,
color: zRgbaColor,
clip: zRect.nullable(),
@@ -117,6 +163,9 @@ const zCanvasBrushLineWithPressureState = z.object({
id: zId,
type: z.literal('brush_line_with_pressure'),
strokeWidth: z.number().min(1),
/**
* Points with pressure are in the format [x1, y1, pressure1, x2, y2, pressure2, ...]
*/
points: zPointsWithPressure,
color: zRgbaColor,
clip: zRect.nullable(),
@@ -127,6 +176,9 @@ const zCanvasEraserLineState = z.object({
id: zId,
type: z.literal('eraser_line'),
strokeWidth: z.number().min(1),
/**
* Points without pressure are in the format [x1, y1, x2, y2, ...]
*/
points: zPoints,
clip: zRect.nullable(),
});
@@ -136,6 +188,9 @@ const zCanvasEraserLineWithPressureState = z.object({
id: zId,
type: z.literal('eraser_line_with_pressure'),
strokeWidth: z.number().min(1),
/**
* Points with pressure are in the format [x1, y1, pressure1, x2, y2, pressure2, ...]
*/
points: zPointsWithPressure,
clip: zRect.nullable(),
});
@@ -450,6 +505,12 @@ export function isFilterableEntityIdentifier(
return isRasterLayerEntityIdentifier(entityIdentifier) || isControlLayerEntityIdentifier(entityIdentifier);
}
export function isSegmentableEntityIdentifier(
entityIdentifier: CanvasEntityIdentifier
): entityIdentifier is CanvasEntityIdentifier<'raster_layer'> | CanvasEntityIdentifier<'control_layer'> {
return isRasterLayerEntityIdentifier(entityIdentifier) || isControlLayerEntityIdentifier(entityIdentifier);
}
export function isTransformableEntityIdentifier(
entityIdentifier: CanvasEntityIdentifier
): entityIdentifier is

View File

@@ -1,5 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { t } from 'i18next';
import { PiUploadBold } from 'react-icons/pi';
@@ -7,14 +9,23 @@ const options = { postUploadAction: { type: 'TOAST' }, allowMultiple: true } as
export const GalleryUploadButton = () => {
const uploadApi = useImageUploadButton(options);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
return (
<>
<IconButton
size="sm"
alignSelf="stretch"
variant="link"
aria-label={t('accessibility.uploadImages')}
tooltip={t('accessibility.uploadImages')}
aria-label={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
tooltip={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>

View File

@@ -1,5 +1,5 @@
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlLayers/store/types';
import type { O } from 'ts-toolbelt';
import type { SetNonNullable } from 'type-fest';
/**
* Renders a value of type T as a React node.
@@ -145,6 +145,6 @@ export type BuildMetadataHandlers = <TValue, TItem>(
arg: BuildMetadataHandlersArg<TValue, TItem>
) => MetadataHandlers<TValue, TItem>;
export type ControlNetConfigMetadata = O.NonNullable<ControlNetConfig, 'model'>;
export type T2IAdapterConfigMetadata = O.NonNullable<T2IAdapterConfig, 'model'>;
export type IPAdapterConfigMetadata = O.NonNullable<IPAdapterConfig, 'model'>;
export type ControlNetConfigMetadata = SetNonNullable<ControlNetConfig, 'model'>;
export type T2IAdapterConfigMetadata = SetNonNullable<T2IAdapterConfig, 'model'>;
export type IPAdapterConfigMetadata = SetNonNullable<IPAdapterConfig, 'model'>;

View File

@@ -12,6 +12,7 @@ type ModelManagerState = {
searchTerm: string;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
shouldInstallInPlace: boolean;
};
const initialModelManagerState: ModelManagerState = {
@@ -21,6 +22,7 @@ const initialModelManagerState: ModelManagerState = {
filteredModelType: null,
searchTerm: '',
scanPath: undefined,
shouldInstallInPlace: true,
};
export const modelManagerV2Slice = createSlice({
@@ -37,18 +39,26 @@ export const modelManagerV2Slice = createSlice({
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
state.filteredModelType = action.payload;
},
setScanPath: (state, action: PayloadAction<string | undefined>) => {
state.scanPath = action.payload;
},
shouldInstallInPlaceChanged: (state, action: PayloadAction<boolean>) => {
state.shouldInstallInPlace = action.payload;
},
},
});
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
modelManagerV2Slice.actions;
export const {
setSelectedModelKey,
setSearchTerm,
setFilteredModelType,
setSelectedModelMode,
setScanPath,
shouldInstallInPlaceChanged,
} = modelManagerV2Slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateModelManagerState = (state: any): any => {
@@ -74,3 +84,4 @@ export const selectSelectedModelKey = createModelManagerSelector((modelManager)
export const selectSelectedModelMode = createModelManagerSelector((modelManager) => modelManager.selectedModelMode);
export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm);
export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType);
export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace);

View File

@@ -1,22 +1,28 @@
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import {
selectShouldInstallInPlace,
shouldInstallInPlaceChanged,
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
type SimpleImportModelConfig = {
location: string;
inplace: boolean;
};
export const InstallModelForm = memo(() => {
const inplace = useAppSelector(selectShouldInstallInPlace);
const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModel();
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: {
location: '',
inplace: true,
},
mode: 'onChange',
});
@@ -31,12 +37,19 @@ export const InstallModelForm = memo(() => {
installModel({
source: values.location,
inplace: values.inplace,
inplace: inplace,
onSuccess: resetForm,
onError: resetForm,
});
},
[installModel, resetForm]
[installModel, resetForm, inplace]
);
const onChangeInplace = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldInstallInPlaceChanged(e.target.checked));
},
[dispatch]
);
return (
@@ -63,7 +76,7 @@ export const InstallModelForm = memo(() => {
<FormControl>
<Flex flexDir="column" gap={2}>
<Flex gap={4}>
<Checkbox {...register('inplace')} />
<Checkbox isChecked={inplace} onChange={onChangeInplace} />
<FormLabel>
{t('modelManager.inplaceInstall')} ({t('modelManager.localOnly')})
</FormLabel>

View File

@@ -11,8 +11,13 @@ import {
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import {
selectShouldInstallInPlace,
shouldInstallInPlaceChanged,
} from 'features/modelManagerV2/store/modelManagerV2Slice';
import type { ChangeEvent, ChangeEventHandler } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
@@ -26,9 +31,10 @@ type ScanModelResultsProps = {
};
export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
const inplace = useAppSelector(selectShouldInstallInPlace);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const [inplace, setInplace] = useState(true);
const [installModel] = useInstallModel();
const filteredResults = useMemo(() => {
@@ -42,9 +48,12 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
setSearchTerm(e.target.value.trim());
}, []);
const onChangeInplace = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setInplace(e.target.checked);
}, []);
const onChangeInplace = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldInstallInPlaceChanged(e.target.checked));
},
[dispatch]
);
const clearSearch = useCallback(() => {
setSearchTerm('');

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',

View File

@@ -34,6 +34,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -66,6 +68,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
@@ -168,10 +171,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
(value: CLIPEmbedModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

Some files were not shown because too many files have changed in this diff Show More