Compare commits

...

128 Commits

Author SHA1 Message Date
psychedelicious
75acece1f1 fix(ui): excessive toasts when generating on canvas
- Add `withToast` flag to `uploadImage` util
- Skip the toast if this is not set
- Use the flag to disable toasts when canvas does internal image-uploading stuff that should be invisible to the user (see sketch below)
2024-11-08 10:30:04 +11:00
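The gist of the `withToast` flag, as a minimal sketch (hypothetical types and helpers; `uploadImage` is the real util named above, the rest is assumed):

```typescript
// Sketch only: `toast` and `uploadToServer` stand in for the real UI helpers.
declare const toast: (args: { status: 'success' | 'error'; title: string }) => void;
declare const uploadToServer: (file: File) => Promise<void>;

type UploadImageArg = {
  file: File;
  // When false, suppress toasts - used for internal canvas uploads that
  // should be invisible to the user.
  withToast?: boolean;
};

export const uploadImage = async ({ file, withToast = true }: UploadImageArg): Promise<void> => {
  try {
    await uploadToServer(file);
    if (withToast) {
      toast({ status: 'success', title: 'Image uploaded' });
    }
  } catch (e) {
    if (withToast) {
      toast({ status: 'error', title: 'Image upload failed' });
    }
    throw e;
  }
};
```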
psychedelicious
a9db2ffefd fix(ui): ensure clip vision model is set correctly for FLUX IP Adapters 2024-11-08 10:02:41 +11:00
psychedelicious
cdd148b4d1 feat(ui): add toast for graph building errors 2024-11-08 10:02:41 +11:00
psychedelicious
730fabe2de feat(ui): add util to extract message from a tsafe AssertionError 2024-11-08 10:02:41 +11:00
psychedelicious
6c59790a7f chore: bump version to v5.4.1rc2 2024-11-08 10:00:20 +11:00
psychedelicious
c37251d6f7 tweak(ui): workflow linear field styling 2024-11-08 07:39:09 +11:00
psychedelicious
2854210162 fix(ui): dnd autoscroll on elements w/ custom scrollbar
Have to do a bit of finagling to get it to work and get `pragmatic-drag-and-drop` to not complain.
2024-11-08 07:39:09 +11:00
psychedelicious
5545b980af fix(ui): workflow field sorting doesn't use unique identifier for fields 2024-11-08 07:39:09 +11:00
psychedelicious
0c9434c464 chore(ui): lint 2024-11-08 07:39:09 +11:00
psychedelicious
8771de917d feat(ui): migrate fullscreen drop zone to pdnd 2024-11-08 07:39:09 +11:00
psychedelicious
122946ef4c feat(ui): DndDropOverlay supports react node for label 2024-11-08 07:39:09 +11:00
psychedelicious
2d974f670c feat(ui): restore missing upload buttons 2024-11-08 07:39:09 +11:00
psychedelicious
75f0da9c35 fix(ui): use revised uploader for CL empty state 2024-11-08 07:39:09 +11:00
psychedelicious
5df3c00e28 feat(ui): remove SerializableObject, use type-fest's JsonObject 2024-11-08 07:39:09 +11:00
psychedelicious
b049880502 fix(ui): uploads initiated from canvas 2024-11-08 07:39:09 +11:00
psychedelicious
e5293fdd1a fix(ui): match new default controlnet behaviour 2024-11-08 07:39:09 +11:00
psychedelicious
8883775762 feat(ui): rework image uploads (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
cfadb313d2 fix(ui): ts issues 2024-11-08 07:39:09 +11:00
psychedelicious
b5cadd9a1a fix(ui): scroll issue w/ boards list 2024-11-08 07:39:09 +11:00
psychedelicious
5361b6e014 refactor(ui): image actions sep of concerns 2024-11-08 07:39:09 +11:00
psychedelicious
ff346172af feat(ui): use new image actions system for image menu 2024-11-08 07:39:09 +11:00
psychedelicious
92f660018b refactor(ui): dnd actions to image actions
We don't need a "dnd" image system. We need an "image action" system. We need to execute specific flows with images from various "origins":
- internal dnd e.g. from gallery
- external dnd e.g. user drags an image file into the browser
- direct file upload e.g. user clicks an upload button
- some other internal app button e.g. a context menu

The actions are now generalized to better support these various use-cases (see the sketch below).
2024-11-08 07:39:09 +11:00
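As a rough sketch, a generalized action of this kind might look like the following (all names here are hypothetical, not the actual InvokeAI types):

```typescript
// Hypothetical sketch of a generalized image action system: an action says
// *what* to do with an image, the origin says *where* it came from, so one
// handler can serve internal dnd, external dnd, uploads and menu buttons.
type ImageOrigin = 'internal-dnd' | 'external-dnd' | 'file-upload' | 'app-button';

type ImageActionContext = {
  origin: ImageOrigin;
  imageName: string; // server-side image identifier (assumed)
};

type ImageAction = {
  id: string;
  isEnabled: (ctx: ImageActionContext) => boolean;
  execute: (ctx: ImageActionContext) => void;
};

// The same action can then be triggered from a gallery drag, a dropped
// file, an upload button or a context menu item.
export const setControlLayerImage: ImageAction = {
  id: 'set-control-layer-image',
  isEnabled: () => true,
  execute: ({ origin, imageName }) => {
    console.log(`setting control layer image ${imageName} (origin: ${origin})`);
  },
};
```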
psychedelicious
1afc2cba4e feat(ui): support different labels for external drop targets (e.g. uploads) 2024-11-08 07:39:09 +11:00
psychedelicious
ee8359242c feat(ui): more dnd cleanup and tidy 2024-11-08 07:39:09 +11:00
psychedelicious
f0c80a8d7a tidy(ui): dnd stuff 2024-11-08 07:39:09 +11:00
psychedelicious
8da9e7c1f6 fix(ui): min height for workflow image field drop target 2024-11-08 07:39:09 +11:00
psychedelicious
6d7a486e5b feat(ui): restore dnd to workflow fields 2024-11-08 07:39:09 +11:00
psychedelicious
57122c6aa3 feat(ui): layer reordering styling 2024-11-08 07:39:09 +11:00
psychedelicious
54abd8d4d1 feat(ui): dnd layer reordering (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
06283cffed feat(ui): use custom drag previews for images 2024-11-08 07:39:09 +11:00
psychedelicious
27fa0e1140 tidy(ui): more efficient dnd overlay styling 2024-11-08 07:39:09 +11:00
psychedelicious
533d48abdb feat(ui): multi-image drag preview 2024-11-08 07:39:09 +11:00
psychedelicious
6845cae4c9 tidy(ui): move new dnd impl into features/dnd 2024-11-08 07:39:09 +11:00
psychedelicious
31c9acb1fa tidy(ui): clean up old dnd stuff 2024-11-08 07:39:09 +11:00
psychedelicious
fb5e462300 tidy(ui): document & clean up dnd 2024-11-08 07:39:09 +11:00
psychedelicious
2f3abc29b1 feat(ui): better types for getData 2024-11-08 07:39:09 +11:00
psychedelicious
c5c071f285 feat(ui): better type name 2024-11-08 07:39:09 +11:00
psychedelicious
93a3ed56e7 feat(ui): simpler dnd typing implementation 2024-11-08 07:39:09 +11:00
psychedelicious
406fc58889 feat(ui): migrate to pragmatic-drag-and-drop (wip 4) 2024-11-08 07:39:09 +11:00
psychedelicious
cf67d084fd feat(ui): migrate to pragmatic-drag-and-drop (wip 3) 2024-11-08 07:39:09 +11:00
psychedelicious
d4a95af14f perf(ui): more gallery perf improvements 2024-11-08 07:39:09 +11:00
psychedelicious
8c8e7102c2 perf(ui): improved gallery perf 2024-11-08 07:39:09 +11:00
psychedelicious
b6b9ea9d70 feat(ui): migrate to pragmatic-drag-and-drop (wip 2) 2024-11-08 07:39:09 +11:00
psychedelicious
63126950bc feat(ui): migrate to pragmatic-drag-and-drop (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
29d63d5dea fix(app): silence pydantic protected namespace warning
Closes #7287
2024-11-08 07:36:50 +11:00
Jonathan
2f6b035138 Update flux_denoise.py
Added a bool to allow the node user to add noise into the initial latents (default) or leave them alone.
2024-11-07 08:44:10 -05:00
psychedelicious
4f9ae44472 chore(ui): bump version to v5.4.1rc1 2024-11-07 12:19:28 +11:00
psychedelicious
c682330852 feat(ui): updated whats new handling and v5.4.1 items 2024-11-07 12:19:28 +11:00
Brandon Rising
c064257759 fix: Look in known subfolders for configs for clip variants 2024-11-07 12:01:02 +11:00
Brandon Rising
8a4c629576 fix: Avoid downloading unsafe .bin files if a safetensors file is available 2024-11-06 19:31:18 -05:00
psychedelicious
a01d44f813 chore(ui): lint 2024-11-06 10:25:46 -05:00
psychedelicious
63fb3a15e9 feat(ui): default to no control model selected for control layers 2024-11-06 10:25:46 -05:00
psychedelicious
4d0837541b feat(ui): add simple mode filtering 2024-11-06 10:25:46 -05:00
psychedelicious
999809b4c7 fix(ui): minor viewer close button styling 2024-11-06 10:25:46 -05:00
psychedelicious
c452edfb9f feat(ui): add control layer empty state 2024-11-06 10:25:46 -05:00
psychedelicious
ad2cdbd8a2 feat(ui): tooltip for canvas preview image 2024-11-06 10:25:46 -05:00
psychedelicious
f15c24bfa7 feat(ui): add " (recommended)" to balanced control mode label 2024-11-06 10:25:46 -05:00
psychedelicious
d1f653f28c feat(ui): make default control end step 0.75 2024-11-06 10:25:46 -05:00
psychedelicious
244465d3a6 feat(ui): make default control weight 0.75 2024-11-06 10:25:46 -05:00
psychedelicious
c6236ab70c feat(ui): add menubar-ish header on comparison 2024-11-06 10:25:46 -05:00
psychedelicious
644d5cb411 feat(ui): add menubar-ish header on viewer 2024-11-06 10:25:46 -05:00
Riku
bb0a630416 fix(ui): adjust knip config to ignore parameter schema exports 2024-11-06 22:51:17 +11:00
Riku
2148ae9287 feat(ui): simplify parameter schema declaration and type inference 2024-11-06 22:51:17 +11:00
psychedelicious
42d242609c chore(gh): update pr template w/ reminder for what's new copy 2024-11-06 19:03:31 +11:00
psychedelicious
fd0a52392b feat(ui): added line about when denoising str is disabled 2024-11-06 19:01:33 +11:00
psychedelicious
e64415d59a feat(ui): revised logic to disable denoising str 2024-11-06 19:01:33 +11:00
psychedelicious
1871e0bdbf feat(ui): tweaked denoise str styling 2024-11-06 19:01:33 +11:00
Mary Hipp
3ae9a965c2 lint 2024-11-06 19:01:33 +11:00
Mary Hipp
85932e35a7 update copy again 2024-11-06 19:01:33 +11:00
Mary Hipp
41b07a56cc update popover copy and add image 2024-11-06 19:01:33 +11:00
Mary Hipp
54064c0cb8 fix(ui): match badge height to slider height so layout does not shift 2024-11-06 19:01:33 +11:00
Mary Hipp
68284b37fa remove opacity logic from WavyLine, add badge explaining disabled state, add translations 2024-11-06 19:01:33 +11:00
Mary Hipp
ae5bc6f5d6 feat(ui): move denoising strength to layers panel w/ visualization of how much change will be applied, only enable if 1+ enabled raster layer 2024-11-06 19:01:33 +11:00
Mary Hipp
6dc16c9f54 wip 2024-11-06 19:01:33 +11:00
Brandon Rising
faa9ac4e15 fix: get_clip_variant_type should never return None 2024-11-06 09:59:50 +11:00
Mary Hipp Rogers
d0460849b0 fix bad merge conflict (#7273)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-11-05 16:02:03 -05:00
Mary Hipp Rogers
bed3c2dd77 update Whats New for 5.3.1 (#7272)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-11-05 15:43:16 -05:00
Mary Hipp
916ddd17d7 fix(ui): fix link for infill method popover 2024-11-05 15:39:03 -05:00
Mary Hipp
accfa7407f fix undefined 2024-11-05 15:30:17 -05:00
Mary Hipp
908db31e48 feat(api,ui): allow Whats New module to get content from back-end 2024-11-05 15:30:17 -05:00
Mary Hipp
b70f632b26 fix(ui): add some feedback while layers are merging 2024-11-05 12:38:50 -05:00
Brandon Rising
d07a6385ab Always default to ClipVariantType.L instead of None 2024-11-05 12:03:40 -05:00
Brandon Rising
68df612fa1 fix: Never throw an exception when finding the clip variant type 2024-11-05 12:03:40 -05:00
psychedelicious
3b96c79461 chore: bump version to v5.4.0 2024-11-05 10:09:21 +11:00
psychedelicious
89bda5b983 Ryan/sd3 diffusers (#7222)
## Summary

Nodes to support SD3.5 txt2img generations
* adds SD3.5 to starter models
* adds default workflow for SD3.5 txt2img

## Related Issues / Discussions


## QA Instructions


## Merge Plan


## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-11-05 08:21:28 +11:00
Brandon Rising
22bff1fb22 Fix conditional within filter_by_variant to not read all candidates as default 2024-11-04 12:42:09 -05:00
Mary Hipp
55ba6488d1 fix up types file 2024-11-04 12:42:09 -05:00
brandonrising
2d78859171 Create bespoke latents to image node for sd3 2024-11-04 12:42:09 -05:00
Mary Hipp
3a661bac34 fix(ui): exclude submodels from model manager 2024-11-04 12:42:09 -05:00
Mary Hipp
bb8a02de18 update schema 2024-11-04 12:42:09 -05:00
maryhipp
78155344f6 update node fields for SD3 to match other SD nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
391a24b0f6 Re-add erroneously removed hash code 2024-11-04 12:42:09 -05:00
Brandon Rising
e75903389f Run ruff, fix bug in hf downloading code which failed to download parts of a model 2024-11-04 12:42:09 -05:00
Brandon Rising
27567052f2 Create new latent factors for sd35 2024-11-04 12:42:09 -05:00
Brandon Rising
6f447f7169 Rather than .fp16., some repos start the suffix with .fp16... for weights spread across multiple files 2024-11-04 12:42:09 -05:00
Mary Hipp
8b370cc182 (ui): don't show SD3 in main model dropdown yet 2024-11-04 12:42:09 -05:00
maryhipp
af583d2971 ruff format 2024-11-04 12:42:09 -05:00
Mary Hipp
0ebe8fb1bd (ui): add required/optional logic to other submodel fields 2024-11-04 12:42:09 -05:00
maryhipp
befb629f46 add default workflow 2024-11-04 12:42:09 -05:00
maryhipp
874d67cb37 add SD3.5 to starter models 2024-11-04 12:42:09 -05:00
Mary Hipp
19f7a1295a (ui): add fields for CLIP-L and CLIP-G, remove MainModelConfig type changes 2024-11-04 12:42:09 -05:00
maryhipp
78bd605617 (nodes,api): expose the submodels on SD3 model loader as optional, add types needed for CLIP-L and CLIP-G fields 2024-11-04 12:42:09 -05:00
Brandon Rising
b87f4e59a5 Create clip variant type, create new functions for discerning clipL and clipG in the frontend 2024-11-04 12:42:09 -05:00
Ryan Dick
1eca4f12c8 Make T5 encoder optional in SD3 workflows. 2024-11-04 12:42:09 -05:00
Ryan Dick
f1de11d6bf Make 3.5 the default CFG for SD3. 2024-11-04 12:42:09 -05:00
Ryan Dick
9361ed9d70 Add progress images to SD3 and make denoising cancellable. 2024-11-04 12:42:09 -05:00
Brandon Rising
ebabf4f7a8 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
606f3321f5 Initial wave of frontend updates for sd-3 node inputs 2024-11-04 12:42:09 -05:00
Brandon Rising
3970aa30fb define submodels on sd3 models during probe 2024-11-04 12:42:09 -05:00
Ryan Dick
678436e07c Add tqdm progress bar for SD3. 2024-11-04 12:42:09 -05:00
Ryan Dick
c620581699 Bug fixes to get SD3 text-to-image workflow running. 2024-11-04 12:42:09 -05:00
Ryan Dick
c331d42ce4 Temporary hack for testing SD3 model loader. 2024-11-04 12:42:09 -05:00
Ryan Dick
1ac9b502f1 Fix Sd3TextEncoderInvocation output type. 2024-11-04 12:42:09 -05:00
Ryan Dick
3fa478a12f Initial draft of SD3DenoiseInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
2d86298b7f Add first draft of Sd3TextEncoderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
009cdb714c Add Sd3ModelLoaderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
9d3f5427b4 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-11-04 12:42:09 -05:00
Ryan Dick
e4b17f019a Get diffusers SD3 model probing working. 2024-11-04 12:42:09 -05:00
Ryan Dick
586c00bc02 (minor) Remove unused dict. 2024-11-04 12:42:09 -05:00
Eugene Brodsky
0f11fda65a fix(deps): pin mediapipe strictly to a known working version 2024-11-04 10:16:19 -05:00
psychedelicious
3e75331ef7 fix(ui): load workflow from file
In a8de6406c5 a change was made to many menus in an effort to improve performance. The menus were made to be lazy, so that they are mounted only while open.

This causes unexpected behaviour when there is some logic in the menu that may need to execute after the user selects a menu item.

In this case, when you click to load a workflow from file, the file picker opens but then the menuitem unmounts, taking the input element and all uploading logic with it. When you select a file, nothing happens because we've nuked the handlers by unmounting everything.

Easy fix - un-lazy-fy the menu.

Closes #7240
2024-11-04 08:02:55 -05:00
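In React terms, the failure mode looks roughly like this (simplified, hypothetical component; the fix keeps it mounted rather than rendering it lazily):

```tsx
import { useRef } from 'react';

// If this component is rendered lazily inside a menu, closing the menu
// unmounts the <input> - and its onChange handler - before the user picks
// a file, so the selection is silently dropped. Keeping the menu (and this
// input) mounted for the whole interaction is the fix.
export const LoadWorkflowMenuItem = () => {
  const inputRef = useRef<HTMLInputElement>(null);
  return (
    <>
      <button onClick={() => inputRef.current?.click()}>Load from file</button>
      <input
        ref={inputRef}
        type="file"
        hidden
        onChange={(e) => {
          // Never fires if the element was unmounted with a lazy menu.
          console.log(e.target.files?.[0]?.name);
        }}
      />
    </>
  );
};
```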
psychedelicious
be133408ac fix(nodes): relaxed validation for segment anything
The validation on this node causes graph validation to fail. It must be run _after_ instantiation.

Also, it was a bit too strict. The only case we explicitly do not handle is when both bboxes and points are provided. It's acceptable if neither is provided.

Closes #7248
2024-11-04 08:00:52 -05:00
psychedelicious
7e1e0d6928 fix(ui): non-default filters can erase layer
When filtering, we use a listener to trigger processing the image whenever a filter setting changes. For example, if the user changes from canny to depth, and auto-process is enabled, we re-process the layer with new filter settings.

The filterer has a method to reset its ephemeral state. This includes the filter settings, so resetting the ephemeral state is expected to trigger processing of the filter.

When we exit filtering, we reset the ephemeral state before resetting everything else, like the listeners.

This can cause a problem when we exit filtering. The sequence:
- Start filtering a layer.
- Auto-process the filter in response to starting the filter process.
- Change the filter settings.
- Auto-process the filter in response to the changed settings.
- Apply the filter.
- Exit filtering, first by resetting the ephemeral state.
- Auto-process the filter in response to the reset settings.*
- Finish exiting, including unsubscribing from listeners.

*Whoops! That last auto-process has now borked the layer's rendering by processing a filter when we shouldn't be processing a filter.

We need to first unsubscribe from listeners, so we don't react to that change to the filter settings and erroneously process the layer.

Also, add a check to the `processImmediate` method to prevent processing if that method is accidentally called without first starting the filterer.

The same issue could affect the segment anything module - the same fixes are implemented there.
2024-11-04 07:11:20 -05:00
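The fix boils down to teardown order: unsubscribe before resetting ephemeral state. A minimal sketch under those assumptions (not the actual filterer module):

```typescript
// Sketch: exit() must unsubscribe *before* resetting ephemeral state, or the
// reset itself is observed as a settings change and triggers one last
// auto-process against a layer that is no longer being filtered.
type Unsubscribe = () => void;
declare function subscribe(listener: () => void): Unsubscribe;

export class FiltererSketch {
  private subscriptions: Unsubscribe[] = [];
  private isStarted = false;

  start(onSettingsChanged: () => void): void {
    this.isStarted = true;
    this.subscriptions.push(subscribe(onSettingsChanged));
  }

  processImmediate(): void {
    // Guard: do nothing if the filterer was never started.
    if (!this.isStarted) {
      return;
    }
    // ...process the layer with the current filter settings...
  }

  exit(): void {
    // 1. Unsubscribe first, so the reset below is not observed.
    for (const unsubscribe of this.subscriptions) {
      unsubscribe();
    }
    this.subscriptions = [];
    // 2. Now it is safe to reset the ephemeral state.
    this.resetEphemeralState();
    this.isStarted = false;
  }

  private resetEphemeralState(): void {
    // ...restore default filter settings etc...
  }
}
```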
psychedelicious
cd3d8df5a8 fix(ui): save canvas to gallery does nothing
The root issue is the compositing cache. When we save the canvas to gallery, we need to first composite raster layers together and then upload the image.

The compositor makes extensive use of caching to reduce the number of images created and improve performance. There are two "layers" of caching:
1. Caching the composite canvas element, which is used both for uploading the canvas and for generation mode analysis.
2. Caching the uploaded composite canvas element as an image.

The combination of these caches allows for the various processes that require composite canvases to do minimal work.

But this causes a problem in this situation, because the user expects a new image to be uploaded when they click save to gallery.

For example, suppose we have already composited and uploaded the raster layer state for use in a generation. Then, we ask the compositor to save the canvas to gallery.

The compositor sees that we are requesting an image for the current canvas state, and instead of recompositing and uploading the image again, it just returns the cached image.

In this case, no image is uploaded and the button does nothing.

We need to be able to opt out of the caching at some level, for certain actions. A `forceUpload` arg is added to the compositor's high-level `getCompositeImageDTO` method to do this.

When true, we ignore the uppermost caching layer (the uploaded image layer), but still use the lower caching layer (the canvas element layer). So we don't recompute the canvas element, but we do upload it as a new image to the server.
2024-11-04 07:11:20 -05:00
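A rough sketch of the two cache layers and the opt-out (`getCompositeImageDTO` and `forceUpload` are named above; everything else here is an assumption):

```typescript
// Sketch of the two cache layers: the composited canvas element is always
// reused, but forceUpload bypasses the uploaded-image cache so a genuinely
// new image is created on the server.
declare function compositeRasterLayers(): HTMLCanvasElement;
declare function uploadCanvas(canvas: HTMLCanvasElement): Promise<{ image_name: string }>;

const canvasCache = new Map<string, HTMLCanvasElement>();
const imageCache = new Map<string, { image_name: string }>();

export const getCompositeImageDTO = async (stateHash: string, forceUpload = false) => {
  // Lower cache layer: the composited canvas element (always used, so we
  // never recomposite the same canvas state twice).
  let canvas = canvasCache.get(stateHash);
  if (!canvas) {
    canvas = compositeRasterLayers();
    canvasCache.set(stateHash, canvas);
  }
  // Upper cache layer: the uploaded image. Skipped when forceUpload is set,
  // e.g. for "save to gallery", where the user expects a new image.
  if (!forceUpload) {
    const cached = imageCache.get(stateHash);
    if (cached) {
      return cached;
    }
  }
  const dto = await uploadCanvas(canvas);
  imageCache.set(stateHash, dto);
  return dto;
};
```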
psychedelicious
24d3c22017 fix(ui): temp fix for stuck tooltips 2024-11-04 07:11:20 -05:00
psychedelicious
b0d37f4e51 fix(ui): progress image does not reset when canceling generation
Previously, we cleared the canvas progress image when the canvas had no active generations. This allowed for a brief flash of canvas state between the last progress image for a given generation, and when the output image for that generation rendered. Here's the sequence:
- Progress images are received and rendered
- Generation completes - no active canvas generations
- Clear the progress image -> canvas layers visible unexpectedly, creating an awkward jarring change
- Generation output image is rendered -> output image overlaid on canvas layers

In 83538c4b2b I attempted to fix this by only clearing the progress image while we were not staging.

This isn't quite right, though. We are often staging with no active generations - for example, you have a few images completed and are waiting to choose one.

In this situation, if you cancel a pending generation, the logic to clear the progress image doesn't fire because it sees staging is in progress.

What we really need is:
- Staging area module clears the progress image once it has rendered an output image.
- Progress image module clears the progress image when a generation is canceled or failed, in which case there will be no output image.

To do this, we can add an event listener to the progress image module to listen for queue item status changes, and when we get a cancelation or failure, clear the progress image.
2024-11-04 07:11:20 -05:00
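Sketched as an event listener (hypothetical event and helper names; the statuses mirror the queue semantics described above):

```typescript
// Sketch: the progress image module clears itself on cancelation/failure,
// since no output image will arrive to trigger the staging-area clear.
type QueueItemStatus = 'pending' | 'in_progress' | 'completed' | 'canceled' | 'failed';

declare function onQueueItemStatusChanged(listener: (status: QueueItemStatus) => void): void;
declare function clearProgressImage(): void;

onQueueItemStatusChanged((status) => {
  if (status === 'canceled' || status === 'failed') {
    // No output image is coming - clear the stale progress image now.
    clearProgressImage();
  }
});
```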
psychedelicious
3559124674 feat(ui): use nanostores in CanvasProgressImageModule for internal state 2024-11-04 07:11:20 -05:00
Eugene Brodsky
6c33e02141 fix(pkg): pin torch to <2.5.0 to prevent unnecessary downloads
pip's dependency resolution doesn't take into account transitive
dependencies when choosing package versions for download.
Even though `torch~=2.4.1` is required by `diffusers`, pip will
download 2.5.0 and higher, but only install 2.4.1.
Pinning torch to <2.5.0 prevents this behaviour.
2024-11-01 12:27:28 -04:00
217 changed files with 7002 additions and 3820 deletions


@@ -19,3 +19,4 @@
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_


@@ -40,6 +40,8 @@ class AppVersion(BaseModel):
version: str = Field(description="App version")
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""


@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@@ -131,8 +134,10 @@ class FieldDescriptions:
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -140,6 +145,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""


@@ -56,7 +56,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.2.0",
version="3.2.1",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -81,6 +81,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
@@ -207,9 +208,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"to be poor. Consider using a FLUX dev model instead."
)
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
if self.add_noise:
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
x = init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:


@@ -0,0 +1,89 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)


@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",


@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""


@@ -0,0 +1,260 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
joint_attention_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
t5_embeds = torch.zeros(
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
device=device,
dtype=dtype,
)
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: The CFG scale to apply at each timestep.
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
return step_callback


@@ -0,0 +1,73 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_l2i",
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)
vae.disable_tiling()
tiling_context = nullcontext()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
img = vae.decode(latents, return_dict=False)[0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)


@@ -0,0 +1,108 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
)
clip_l_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPLEmbedModel,
input=Input.Direct,
title="CLIP L Encoder",
default=None,
)
clip_g_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_g_model,
ui_type=UIType.CLIPGEmbedModel,
input=Input.Direct,
title="CLIP G Encoder",
default=None,
)
vae_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = (
self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
if self.vae_model
else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
)
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = (
self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
if self.clip_l_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
)
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = (
self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)


@@ -0,0 +1,199 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 max sequence length is set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
default=None,
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding models are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -5,7 +5,7 @@ from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def check_point_lists_or_bounding_box(self):
if self.point_lists is None and self.bounding_boxes is None:
raise ValueError("Either point_lists or bounding_box must be provided.")
elif self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):


@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ClipVariantType,
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)


@@ -0,0 +1,382 @@
{
"name": "SD3.5 Text to Image",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 3.5",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD3.5, default",
"notes": "",
"exposedFields": [
{
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"fieldName": "model"
},
{
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"fieldName": "prompt"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
"nodes": [
{
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "invocation",
"data": {
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "sd3_model_loader",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"model": {
"name": "model",
"label": "",
"value": {
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
"hash": "placeholder",
"name": "stable-diffusion-3.5-medium",
"base": "sd-3",
"type": "main"
}
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_l_model": {
"name": "clip_l_model",
"label": ""
},
"clip_g_model": {
"name": "clip_g_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": -55.58689609637031,
"y": -111.53602444662268
}
},
{
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "invocation",
"data": {
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"nodePack": "invokeai",
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 470.45870147220353,
"y": 350.3141781644303
}
},
{
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "invocation",
"data": {
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "sd3_l2i",
"version": "1.3.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1192.3097009334897,
"y": -366.0994675072209
}
},
{
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "invocation",
"data": {
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 408.16054647924784,
"y": 65.06415352118786
}
},
{
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "invocation",
"data": {
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 378.9283412440941,
"y": -302.65777497352553
}
},
{
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "invocation",
"data": {
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "sd3_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"value": 30
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 813.7814762740603,
"y": -142.20529727605867
}
}
],
"edges": [
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
"type": "default",
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
"type": "default",
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
"type": "default",
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
"type": "default",
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
}
]
}
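
As a sanity check, the workflow above wires the model loader into both text encoders, the denoiser, and the latents-to-image node. A small sketch that validates the wiring (the file path is illustrative):

import json

with open("sd3_5_text_to_image.json") as f:
    wf = json.load(f)

node_ids = {node["id"] for node in wf["nodes"]}
# Every edge must reference known nodes.
assert all(edge["source"] in node_ids and edge["target"] in node_ids for edge in wf["edges"])
print(f"{len(wf['nodes'])} nodes, {len(wf['edges'])} edges")  # 6 nodes, 12 edges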

View File

@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
[-0.1307, -0.1874, -0.7445], # L4
]
SD3_5_LATENT_RGB_FACTORS = [
[-0.05240681, 0.03251581, 0.0749016],
[-0.0580572, 0.00759826, 0.05729818],
[0.16144888, 0.01270368, -0.03768577],
[0.14418615, 0.08460266, 0.15941818],
[0.04894035, 0.0056485, -0.06686988],
[0.05187166, 0.19222395, 0.06261094],
[0.1539433, 0.04818359, 0.07103094],
[-0.08601796, 0.09013458, 0.10893912],
[-0.12398469, -0.06766567, 0.0033688],
[-0.0439737, 0.07825329, 0.02258823],
[0.03101129, 0.06382551, 0.07753657],
[-0.01315361, 0.08554491, -0.08772475],
[0.06464487, 0.05914605, 0.13262741],
[-0.07863674, -0.02261737, -0.12761454],
[-0.09923835, -0.08010759, -0.06264447],
[-0.03392309, -0.0804029, -0.06078822],
]
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
elif base_model == BaseModelType.StableDiffusion3:
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
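
sample_to_lowres_estimated_image is not shown in this hunk; a minimal sketch of the projection it performs, assuming latents of shape [C, H, W] and that the [16, 3] SD3_5_LATENT_RGB_FACTORS list above is in scope:

import torch

def latents_to_rgb_preview(latents: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
    # Project each of the C latent channels onto R, G, B: [C, H, W] x [C, 3] -> [3, H, W].
    rgb = torch.einsum("chw,cr->rhw", latents, factors)
    # Normalize to 0..255 for a rough progress preview.
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)
    return (rgb * 255).to(torch.uint8)

factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS)  # [16, 3]
preview = latents_to_rgb_preview(torch.randn(16, 64, 64), factors)
print(preview.shape)  # torch.Size([3, 64, 64])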

View File

@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -92,6 +95,13 @@ class SubModelType(str, Enum):
SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum):
"""Variant type."""
@@ -147,6 +157,17 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant = None
model_config = ConfigDict(protected_namespaces=())
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +214,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):
@@ -335,7 +359,7 @@ class MainConfigBase(ModelConfigBase):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@@ -419,12 +443,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-G Embeddings."""
variant: ClipVariantType = ClipVariantType.G
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@@ -501,6 +546,8 @@ AnyModelConfig = Annotated[
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
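
The tagged union above relies on pydantic v2's Tag/Discriminator; a self-contained sketch of how a callable discriminator routes CLIP-L vs. CLIP-G configs (field set simplified):

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class CLIPLConfig(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    variant: Literal["large"] = "large"

class CLIPGConfig(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    variant: Literal["gigantic"] = "gigantic"

def discriminator(v) -> str:
    data = v if isinstance(v, dict) else v.__dict__
    return f"{data['type']}.{data['variant']}"

AnyClipConfig = Annotated[
    Union[
        Annotated[CLIPLConfig, Tag("clip_embed.large")],
        Annotated[CLIPGConfig, Tag("clip_embed.gigantic")],
    ],
    Discriminator(discriminator),
]

cfg = TypeAdapter(AnyClipConfig).validate_python({"type": "clip_embed", "variant": "gigantic"})
print(type(cfg).__name__)  # CLIPGConfig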

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(
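
Both loaders now treat the SD3 third-slot submodels as aliases for the FLUX-style tokenizer_2 / text_encoder_2 folders. A minimal sketch of the alternation pattern used in the match statements:

from enum import Enum

class SubModelType(str, Enum):
    Tokenizer2 = "tokenizer_2"
    Tokenizer3 = "tokenizer_3"

def tokenizer_folder(submodel: SubModelType) -> str:
    match submodel:
        # Either member selects the same on-disk folder.
        case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
            return "tokenizer_2"
    raise ValueError(f"Unsupported submodel: {submodel}")

print(tokenizer_folder(SubModelType.Tokenizer3))  # tokenizer_2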

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,

View File

@@ -1,7 +1,7 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union
import safetensors.torch
import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
AnyVariant,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
@@ -33,8 +34,15 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import (
get_clip_variant_type,
lora_token_vector_length,
read_checkpoint_meta,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -112,6 +120,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -122,8 +131,12 @@ class ModelProbe(object):
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -170,7 +183,10 @@ class ModelProbe(object):
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
fields["variant"] = (
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
)
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -217,6 +233,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -747,18 +767,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e. SD 1.x, SD 2, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +805,23 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
variant=variant_func and variant_func((self.model_path / key).as_posix()),
)
return submodels
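
get_submodels walks the pipeline's model_index.json, whose non-underscore entries are [library, class] pairs. An illustrative (not verbatim) SD3-style index run through the same filtering loop:

# Hypothetical model_index.json contents for an SD3 pipeline folder.
config = {
    "_class_name": "StableDiffusion3Pipeline",
    "_diffusers_version": "0.30.0",
    "text_encoder": ["transformers", "CLIPTextModelWithProjection"],
    "text_encoder_3": ["transformers", "T5EncoderModel"],
    "transformer": ["diffusers", "SD3Transformer2DModel"],
    "vae": ["diffusers", "AutoencoderKL"],
}

for key, value in config.items():
    # Keys starting with "_" are metadata; real entries are [library, class] pairs.
    if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
        continue
    print(key, "->", value[1])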
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the

View File

@@ -140,6 +140,22 @@ flux_dev = StarterModel(
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-medium",
description="Medium SD3.5 Model: ~15GB",
type=ModelType.Main,
dependencies=[],
)
sd35_large = StarterModel(
name="SD3.5 Large",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-large",
description="Large SD3.5 Model: ~19G",
type=ModelType.Main,
dependencies=[],
)
cyberrealistic_sd1 = StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@@ -570,6 +586,8 @@ STARTER_MODELS: list[StarterModel] = [
flux_dev_quantized,
flux_schnell,
flux_dev,
sd35_medium,
sd35_large,
cyberrealistic_sd1,
rev_animated_sd1,
dreamshaper_8_sd1,

View File

@@ -8,6 +8,7 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -165,3 +166,25 @@ def convert_bundle_to_flux_transformer_checkpoint(
del transformer_state_dict[k]
return original_state_dict
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
try:
path = Path(location)
config_path = path / "config.json"
if not config_path.exists():
config_path = path / "text_encoder" / "config.json"
if not config_path.exists():
return ClipVariantType.L
with open(config_path) as file:
clip_conf = json.load(file)
hidden_size = clip_conf.get("hidden_size", -1)
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return ClipVariantType.L
except Exception:
return ClipVariantType.L
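
A quick exercise of the probe above, assuming get_clip_variant_type is importable from invokeai.backend.model_manager.util.model_util; CLIP-L encoders report hidden_size 768, CLIP-G/bigG encoders 1280:

import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    enc = Path(tmp) / "text_encoder"
    enc.mkdir()
    (enc / "config.json").write_text(json.dumps({"hidden_size": 1280}))
    print(get_clip_variant_type(tmp))  # ClipVariantType.G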

View File

@@ -85,6 +85,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result: set[Path] = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
safetensors_detected = False
for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
@@ -119,19 +120,27 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
if safetensors_detected and path.suffix == ".bin":
continue
parent = path.parent
score = 0
if path.suffix == ".safetensors":
safetensors_detected = True
if parent in subfolder_weights:
subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
):
if (
variant is not ModelRepoVariant.Default
and candidate_variant_label
and candidate_variant_label.startswith(f".{variant.value}")
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
score += 1
if parent not in subfolder_weights:
@@ -146,7 +155,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
at_least_one_fp16 = True
break
@@ -162,7 +171,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
match = re.match(pattern, highest_score_candidate.path.as_posix())
if match:
for candidate in candidate_list:
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
match.group(2)
):
result.add(candidate.path)
else:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list

View File

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass
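
A minimal usage sketch of the new dataclass, assuming it is importable and using illustrative shapes (77-token sequences; CLIP-L is 768-dim, CLIP-G is 1280-dim):

import torch

info = SD3ConditioningInfo(
    clip_l_pooled_embeds=torch.zeros(1, 768),
    clip_l_embeds=torch.zeros(1, 77, 768),
    clip_g_pooled_embeds=torch.zeros(1, 1280),
    clip_g_embeds=torch.zeros(1, 77, 1280),
    t5_embeds=None,  # T5 conditioning is optional
)
# Moves every tensor in one call; a None t5_embeds is skipped.
info = info.to(device=torch.device("cpu"), dtype=torch.float16)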

View File

@@ -9,6 +9,7 @@ const config: KnipConfig = {
'src/services/api/schema.ts',
'src/features/nodes/types/v1/**',
'src/features/nodes/types/v2/**',
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// TODO(psyche): restore HRF functionality?

View File

@@ -52,11 +52,11 @@
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.1.0",
"@invoke-ai/ui-library": "^0.0.43",
"@nanostores/react": "^0.7.3",

View File

@@ -5,21 +5,21 @@ settings:
excludeLinksFromLockfile: false
dependencies:
'@atlaskit/pragmatic-drag-and-drop':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-auto-scroll':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-hitbox':
specifier: ^1.0.3
version: 1.0.3
'@dagrejs/dagre':
specifier: ^1.1.4
version: 1.1.4
'@dagrejs/graphlib':
specifier: ^2.2.4
version: 2.2.4
'@dnd-kit/core':
specifier: ^6.1.0
version: 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/sortable':
specifier: ^8.0.0
version: 8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1)
'@dnd-kit/utilities':
specifier: ^3.2.2
version: 3.2.2(react@18.3.1)
'@fontsource-variable/inter':
specifier: ^5.1.0
version: 5.1.0
@@ -319,6 +319,28 @@ packages:
'@jridgewell/trace-mapping': 0.3.25
dev: true
/@atlaskit/pragmatic-drag-and-drop-auto-scroll@1.4.0:
resolution: {integrity: sha512-5GoikoTSW13UX76F9TDeWB8x3jbbGlp/Y+3aRkHe1MOBMkrWkwNpJ42MIVhhX/6NSeaZiPumP0KbGJVs2tOWSQ==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop-hitbox@1.0.3:
resolution: {integrity: sha512-/Sbu/HqN2VGLYBhnsG7SbRNg98XKkbF6L7XDdBi+izRybfaK1FeMfodPpm/xnBHPJzwYMdkE0qtLyv6afhgMUA==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop@1.4.0:
resolution: {integrity: sha512-qRY3PTJIcxfl/QB8Gwswz+BRvlmgAC5pB+J2hL6dkIxgqAgVwOhAamMUKsrOcFU/axG2Q7RbNs1xfoLKDuhoPg==}
dependencies:
'@babel/runtime': 7.25.7
bind-event-listener: 3.0.0
raf-schd: 4.0.3
dev: false
/@babel/code-frame@7.25.7:
resolution: {integrity: sha512-0xZJFNE5XMpENsgfHYTw8FbX4kv53mFLn2i3XPoq69LyhYSCBJtitaHx9QnsVTrsogI4Z3+HtEfZ2/GFPOtf5g==}
engines: {node: '>=6.9.0'}
@@ -980,49 +1002,6 @@ packages:
engines: {node: '>17.0.0'}
dev: false
/@dnd-kit/accessibility@3.1.0(react@18.3.1):
resolution: {integrity: sha512-ea7IkhKvlJUv9iSHJOnxinBcoOI3ppGnnL+VDJ75O45Nss6HtZd8IdN8touXPDtASfeI2T2LImb8VOZcL47wjQ==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/core@6.1.0(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-J3cQBClB4TVxwGo3KEjssGEXNJqGVWx17aRTZ1ob0FliR5IjYgTxl5YJbKTzA6IzrtelotH19v6y7uoIRUZPSg==}
peerDependencies:
react: '>=16.8.0'
react-dom: '>=16.8.0'
dependencies:
'@dnd-kit/accessibility': 3.1.0(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
tslib: 2.7.0
dev: false
/@dnd-kit/sortable@8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1):
resolution: {integrity: sha512-U3jk5ebVXe1Lr7c2wU7SBZjcWdQP+j7peHJfCspnA81enlu88Mgd7CC8Q+pub9ubP7eKVETzJW+IBAhsqbSu/g==}
peerDependencies:
'@dnd-kit/core': ^6.1.0
react: '>=16.8.0'
dependencies:
'@dnd-kit/core': 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/utilities@3.2.2(react@18.3.1):
resolution: {integrity: sha512-+MKAJEOfaBe5SmV6t34p80MMKhjvUz0vRrvVJbPT0WElzaOJ/1xs+D+KDv+tD/NE5ujfrChEcshd4fLn0wpiqg==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@emotion/babel-plugin@11.12.0:
resolution: {integrity: sha512-y2WQb+oP8Jqvvclh8Q55gLUyb7UFvgv7eJfsj7td5TToBrIUtPay2kMrZi4xjq9qw2vD0ZR5fSho0yqoFgX7Rw==}
dependencies:
@@ -4313,6 +4292,10 @@ packages:
open: 8.4.2
dev: true
/bind-event-listener@3.0.0:
resolution: {integrity: sha512-PJvH288AWQhKs2v9zyfYdPzlPqf5bXbGMmhmUIY9x4dAUGIWgomO771oBQNwJnMQSnUIXhKu6sgzpBRXTlvb8Q==}
dev: false
/bl@4.1.0:
resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
dependencies:
@@ -7557,6 +7540,10 @@ packages:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
dev: true
/raf-schd@4.0.3:
resolution: {integrity: sha512-tQkJl2GRWh83ui2DiPTJz9wEiMN20syf+5oKfB03yYP7ioZcJwsIK8FjrtLwH1m7C7e+Tt2yYBlrOpdT+dyeIQ==}
dev: false
/raf-throttle@2.0.6:
resolution: {integrity: sha512-C7W6hy78A+vMmk5a/B6C5szjBHrUzWJkVyakjKCK59Uy2CcA7KhO1JUvvH32IXYFIcyJ3FMKP3ZzCc2/71I6Vg==}
dev: false

Binary file not shown (image; after: 895 KiB).

View File

@@ -997,6 +997,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1412,8 +1413,9 @@
"paramDenoisingStrength": {
"heading": "Denoising Strength",
"paragraphs": [
"How much noise is added to the input image.",
"0 will result in an identical image, while 1 will result in a completely new image."
"Controls how much the generated image varies from the raster layer(s).",
"Lower strength stays closer to the combined visible raster layers. Higher strength relies more on the global prompt.",
"When there are no raster layers with visible content, this setting is ignored."
]
},
"paramHeight": {
@@ -1662,6 +1664,7 @@
"mergeDown": "Merge Down",
"mergeVisibleOk": "Merged layers",
"mergeVisibleError": "Error merging layers",
"mergingLayers": "Merging layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"resetCanvas": "Reset Canvas",
@@ -1774,9 +1777,10 @@
"newCanvasSession": "New Canvas Session",
"newCanvasSessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be staged on the canvas.",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
"controlMode": {
"controlMode": "Control Mode",
"balanced": "Balanced",
"balanced": "Balanced (recommended)",
"prompt": "Prompt",
"control": "Control",
"megaControl": "Mega Control"
@@ -1815,6 +1819,9 @@
"process": "Process",
"apply": "Apply",
"cancel": "Cancel",
"advanced": "Advanced",
"processingLayerWith": "Processing layer with the {{type}} filter.",
"forMoreControl": "For more control, click Advanced below.",
"spandrel_filter": {
"label": "Image-to-Image Model",
"description": "Run an image-to-image model on the selected layer.",
@@ -2095,9 +2102,10 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"line1": "<ItalicComponent>Select Object</ItalicComponent> tool for precise object selection and editing",
"line2": "Expanded Flux support, now with Global Reference Images",
"line3": "Improved tooltips and context menus",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: Support for Text-to-Image in Workflows with SD 3.5 Medium and Large.",
"<StrongComponent>Canvas</StrongComponent>: Streamlined Control Layer processing and improved default Control settings."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -8,10 +8,8 @@ import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import {
@@ -19,6 +17,7 @@ import {
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
@@ -62,8 +61,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
@@ -92,19 +89,8 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box
id="invoke-app-wrapper"
w="100dvw"
h="100dvh"
position="relative"
overflow="hidden"
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
)}
</Box>
<DeleteImageModal />
<ChangeBoardModal />
@@ -121,6 +107,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
<FullscreenDropzone />
</ErrorBoundary>
);
};

View File

@@ -1,4 +1,3 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
@@ -8,13 +7,11 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const GlobalImageHotkeys = memo(() => {
useAssertSingleton('GlobalImageHotkeys');
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken);
const imageDTO = useAppSelector(selectLastSelectedImage);
if (!imageDTO) {
return null;

View File

@@ -19,7 +19,6 @@ import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
@@ -237,9 +236,7 @@ const InvokeAIUI = ({
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} studioInitAction={studioInitAction} />
</AppDndContext>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
</React.Suspense>
</Provider>

View File

@@ -17,6 +17,7 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
'config',
'dnd',
'events',
'gallery',
'generation',

View File

@@ -1,4 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
/** @knipignore */
export const EMPTY_OBJECT = {};

View File

@@ -16,7 +16,6 @@ import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMi
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
import { addImageDroppedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
@@ -93,9 +92,6 @@ addGetOpenAPISchemaListener(startAppListening);
addWorkflowLoadRequestedListener(startAppListening);
addUpdateAllNodesRequestedListener(startAppListening);
// DND
addImageDroppedListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);

View File

@@ -1,12 +1,12 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
log.debug({ enqueueResult } as JsonObject, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg } as JsonObject, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

View File

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -17,7 +17,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
effect: (action) => {
const enqueueResult = action.payload;
const arg = action.meta.arg.originalArgs;
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
log.debug({ enqueueResult } as JsonObject, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
@@ -45,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig } as JsonObject, t('queue.batchFailedToQueue'));
return;
}
@@ -71,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig, error: serializeError(response) } as JsonObject, t('queue.batchFailedToQueue'));
},
});
};

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
@@ -10,10 +10,12 @@ import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGr
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
import { assert, AssertionError } from 'tsafe';
import type { JsonObject } from 'type-fest';
const log = logger('generation');
@@ -57,7 +59,17 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
}
if (buildGraphResult.isErr()) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status: 'error',
title: 'Failed to build graph',
description,
});
return;
}
@@ -88,7 +100,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
log.debug({ batchConfig: prepareBatchResult.value } as JsonObject, 'Enqueued batch');
},
});
};

View File

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { JsonObject } from 'type-fest';
const log = logger('system');
@@ -16,12 +16,12 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
effect: (action, { getState }) => {
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) } as JsonObject, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates } as JsonObject, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},

View File

@@ -1,334 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
entityRasterized,
entitySelected,
inpaintMaskAdded,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
activeData: TypesafeDraggableData;
}>('dnd/dndDropped');
const log = logger('system');
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: dndDropped,
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug({ activeData, overData }, 'Node field dropped');
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/**
* Image dropped on IP Adapter Layer
*/
if (
overData.actionType === 'SET_IPA_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
dispatch(
referenceImageIPAdapterImageChanged({
entityIdentifier: { id, type: 'reference_image' },
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on RG Layer IP Adapter
*/
if (
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id, referenceImageId } = overData.context;
dispatch(
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
referenceImageId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on Raster layer
*/
if (
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Inpaint Mask
*/
if (
overData.actionType === 'ADD_INPAINT_MASK_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasInpaintMaskState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Regional Guidance
*/
if (
overData.actionType === 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRegionalGuidanceState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Control layer
*/
if (
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
const defaultControlAdapter = selectDefaultControlAdapter(state);
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
controlAdapter: defaultControlAdapter,
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasReferenceImageState> = {
ipAdapter,
};
dispatch(referenceImageAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on layer, replacing its content
*/
if (overData.actionType === 'REPLACE_LAYER_WITH_IMAGE' && activeData.payloadType === 'IMAGE_DTO') {
const state = getState();
const { entityIdentifier } = overData.context;
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}
/**
* Image dropped on node image field
*/
if (
overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image selected for compare
*/
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
return;
}
/**
* Image dropped on user board
*/
if (
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on upscale initial image
*/
if (
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(upscaleInitialImageChanged(imageDTO));
return;
}
/**
* Multiple images dropped on user board
*/
if (overData.actionType === 'ADD_TO_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImagesToBoard.initiate({
imageDTOs,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Multiple images dropped on 'none' board
*/
if (overData.actionType === 'REMOVE_FROM_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
dispatch(
imagesApi.endpoints.removeImagesFromBoard.initiate({
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},
});
};

View File

@@ -1,18 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import {
entityRasterized,
entitySelected,
referenceImageIPAdapterImageChanged,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
@@ -51,93 +41,45 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
log.debug({ imageDTO }, 'Image uploaded');
const { postUploadAction } = action.meta.arg.originalArgs;
const boardId = imageDTO.board_id ?? 'none';
if (!postUploadAction) {
return;
}
if (action.meta.arg.originalArgs.withToast) {
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;
// default action - just upload and alert user
if (postUploadAction.type === 'TOAST') {
const boardId = imageDTO.board_id ?? 'none';
// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: postUploadAction.title || DEFAULT_UPLOADED_TOAST.title,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
return;
}
if (postUploadAction.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
...DEFAULT_UPLOADED_TOAST,
description: 'set as upscale initial image',
});
return;
}
if (postUploadAction.type === 'SET_IPA_IMAGE') {
const { id } = postUploadAction;
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, referenceImageId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
return;
}
if (postUploadAction.type === 'REPLACE_LAYER_WITH_IMAGE') {
const { entityIdentifier } = postUploadAction;
const state = getState();
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
},
});
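// The TOAST branch above implements a singleton toast: repeat uploads within
// 3 seconds reuse the same toast id and just reset its close timer, instead of
// stacking one toast per image. The same pattern in isolation - a sketch
// against the toast API used above, which returns a handle with a close()
// method:
let singletonToastTimeout: number | null = null;

const showUploadToast = (title: string) => {
  if (singletonToastTimeout !== null) {
    window.clearTimeout(singletonToastTimeout);
  }
  const toastApi = toast({ id: 'IMAGE_UPLOADED', title, status: 'success', duration: null });
  singletonToastTimeout = window.setTimeout(() => {
    toastApi.close();
  }, 3000);
};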

View File

@@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import {
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
@@ -41,6 +40,7 @@ import {
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
} from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('models');
@@ -85,7 +85,7 @@ type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<SerializableObject>
log: Logger<JsonObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
@@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
// We have a VAE selected, need to check if it is available
// Grab just the VAE models
const vaeModels = models.filter(isNonFluxVAEModelConfig);
const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m));
// If the current VAE model is available, we don't need to do anything
if (vaeModels.some((m) => m.key === selectedVAEModel.key)) {
@@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
@@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedFLUXVAEModel = state.params.fluxVAE;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) {

View File

@@ -4,8 +4,10 @@ import { atom } from 'nanostores';
/**
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
* in a hook or component.
*
* @knipignore
*/
// export const $false: ReadableAtom<boolean> = atom(false);
export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
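// A sketch of the conditional-atom pattern $false exists for, using nanostores'
// useStore hook from '@nanostores/react'. useOptionalCanvasManager is a
// hypothetical stand-in for any hook that only sometimes has an atom to offer:
import { useStore } from '@nanostores/react';
import type { ReadableAtom } from 'nanostores';

declare function useOptionalCanvasManager(): { $isBusy: ReadableAtom<boolean> } | null;

const useIsBusySafe = (): boolean => {
  const manager = useOptionalCanvasManager();
  // Hooks must be called unconditionally, so when no manager (and thus no
  // atom) is available, subscribe to the stable $false defined above instead.
  return useStore(manager?.$isBusy ?? $false);
};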

View File

@@ -3,7 +3,6 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -37,6 +36,7 @@ import undoable from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
@@ -139,7 +139,7 @@ const unserialize: UnserializeFunction = (data, key) => {
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as SerializableObject, // this is always serializable
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);

View File

@@ -1,251 +0,0 @@
import type { ChakraProps, FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const baseStyles: SystemStyleObject = {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
};
const sx: SystemStyleObject = {
...baseStyles,
'.gallery-image-container::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
};
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
withMetadataOverlay?: boolean;
isDragDisabled?: boolean;
isDropDisabled?: boolean;
isUploadDisabled?: boolean;
minSize?: number;
postUploadAction?: PostUploadAction;
imageSx?: ChakraProps['sx'];
fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
dataTestId?: string;
};
const IAIDndImage = (props: IAIDndImageProps) => {
const {
imageDTO,
onError,
onClick,
withMetadataOverlay = false,
isDropDisabled = false,
isDragDisabled = false,
isUploadDisabled = false,
minSize = 24,
postUploadAction,
imageSx,
fitContainer = false,
droppableData,
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
useThumbailFallback,
withHoverOverlay = false,
children,
dataTestId,
...rest
} = props;
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
return;
}
if (e.button !== 1) {
return;
}
window.open(imageDTO.image_url, '_blank');
},
[imageDTO]
);
const ref = useRef<HTMLDivElement>(null);
useImageContextMenu(imageDTO, ref);
return (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
);
};
export default memo(IAIDndImage);
const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);
UploadButton.displayName = 'UploadButton';

View File

@@ -1,38 +0,0 @@
import type { BoxProps } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
{...attributes}
{...listeners}
{...rest}
/>
);
};
export default memo(IAIDraggable);
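// IAIDraggable wrapped @dnd-kit's useDraggable. Its replacement in this PR
// attaches pragmatic-drag-and-drop directly to a DOM node; a minimal sketch of
// that pattern (useDraggableExample and its payload are illustrative only):
import { draggable } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { useEffect, useRef } from 'react';

const useDraggableExample = () => {
  const ref = useRef<HTMLDivElement>(null);
  useEffect(() => {
    const element = ref.current;
    if (!element) {
      return;
    }
    // draggable() returns a cleanup function that detaches the behaviour,
    // which doubles as the effect's cleanup.
    return draggable({
      element,
      getInitialData: () => ({ kind: 'example-drag-payload' }),
    });
  }, []);
  return ref;
};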

View File

@@ -1,64 +0,0 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
type Props = {
isOver: boolean;
label?: string;
withBackdrop?: boolean;
};
const IAIDropOverlay = (props: Props) => {
const { isOver, label, withBackdrop = true } = props;
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
<Flex
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
bg={withBackdrop ? 'base.900' : 'transparent'}
opacity={0.7}
borderRadius="base"
alignItems="center"
justifyContent="center"
transitionProperty="common"
transitionDuration="0.1s"
/>
<Flex
position="absolute"
top={0.5}
right={0.5}
bottom={0.5}
left={0.5}
opacity={1}
borderWidth={1.5}
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
>
{label && (
<Text
fontSize="lg"
fontWeight="semibold"
color={isOver ? 'invokeYellow.300' : 'base.500'}
transitionProperty="common"
transitionDuration="0.1s"
textAlign="center"
>
{label}
</Text>
)}
</Flex>
</Flex>
);
};
export default memo(IAIDropOverlay);

View File

@@ -1,46 +0,0 @@
import { Box } from '@invoke-ai/ui-library';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
type IAIDroppableProps = {
dropLabel?: string;
disabled?: boolean;
data?: TypesafeDroppableData;
};
const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);
};
export default memo(IAIDroppable);

View File

@@ -1,24 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Skeleton } from '@invoke-ai/ui-library';
import { memo } from 'react';
const skeletonStyles: SystemStyleObject = {
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
};
const IAIFillSkeleton = () => {
return (
<Skeleton sx={skeletonStyles}>
<Box position="absolute" top={0} insetInlineStart={0} height="full" width="full" />
</Skeleton>
);
};
export default memo(IAIFillSkeleton);

View File

@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
type Props = { image: ImageDTO | undefined };
export const IAILoadingImageFallback = memo((props: Props) => {
const IAILoadingImageFallback = memo((props: Props) => {
if (props.image) {
return (
<Skeleton

View File

@@ -1,28 +0,0 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
type ImageMetadataOverlayProps = {
imageDTO: ImageDTO;
};
const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
return (
<Flex
pointerEvents="none"
flexDirection="column"
position="absolute"
top={0}
insetInlineStart={0}
p={2}
alignItems="flex-start"
gap={2}
>
<Badge variant="solid" colorScheme="base">
{imageDTO.width} × {imageDTO.height}
</Badge>
</Flex>
);
};
export default memo(ImageMetadataOverlay);

View File

@@ -1,89 +0,0 @@
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { memo } from 'react';
import type { DropzoneState } from 'react-dropzone';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useBoardName } from 'services/api/hooks/useBoardName';
type ImageUploadOverlayProps = {
dropzone: DropzoneState;
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
};
const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const { dropzone, setIsHandlingUpload } = props;
useHotkeys(
'esc',
() => {
setIsHandlingUpload(false);
},
[setIsHandlingUpload]
);
return (
<Box position="absolute" top={0} right={0} bottom={0} left={0} zIndex={999} backdropFilter="blur(20px)">
<Flex position="absolute" top={0} right={0} bottom={0} left={0} bg="base.900" opacity={0.7} />
<Flex
position="absolute"
flexDir="column"
gap={4}
top={2}
right={2}
bottom={2}
left={2}
opacity={1}
borderWidth={2}
borderColor={dropzone.isDragAccept ? 'invokeYellow.300' : 'error.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
color={dropzone.isDragReject ? 'error.300' : undefined}
>
{dropzone.isDragAccept && <DragAcceptMessage />}
{!dropzone.isDragAccept && <DragRejectMessage />}
</Flex>
</Box>
);
};
export default memo(ImageUploadOverlay);
const DragAcceptMessage = () => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector(selectSelectedBoardId);
const boardName = useBoardName(selectedBoardId);
return (
<>
<Heading size="lg">{t('gallery.dropToUpload')}</Heading>
<Heading size="md">{t('toast.imagesWillBeAddedTo', { boardName })}</Heading>
</>
);
};
const DragRejectMessage = () => {
const { t } = useTranslation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
if (maxImageUploadCount === undefined) {
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc')}</Heading>
</>
);
}
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount })}</Heading>
</>
);
};

View File

@@ -1,5 +1,6 @@
import type { PopoverProps } from '@invoke-ai/ui-library';
import commercialLicenseBg from 'public/assets/images/commercial-license-bg.png';
import denoisingStrength from 'public/assets/images/denoising-strength.png';
export type Feature =
| 'clipSkip'
@@ -125,7 +126,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
},
infillMethod: {
href: 'https://support.invoke.ai/support/solutions/articles/151000158841-infill-and-scaling',
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
},
scaleBeforeProcessing: {
href: 'https://support.invoke.ai/support/solutions/articles/151000158841',
@@ -138,6 +139,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
},
paramDenoisingStrength: {
href: 'https://support.invoke.ai/support/solutions/articles/151000094998-image-to-image',
image: denoisingStrength,
},
paramHrf: {
href: 'https://support.invoke.ai/support/solutions/articles/151000096700-how-can-i-get-larger-images-what-does-upscaling-do-',

View File

@@ -1,9 +1,13 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { autoScrollForElements } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/element';
import { autoScrollForExternal } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/external';
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import type { OverlayScrollbarsComponentRef } from 'overlayscrollbars-react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useMemo } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
type Props = PropsWithChildren & {
maxHeight?: ChakraProps['maxHeight'];
@@ -11,17 +15,38 @@ type Props = PropsWithChildren & {
overflowY?: 'hidden' | 'scroll';
};
const styles: CSSProperties = { height: '100%', width: '100%' };
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const overlayscrollbarsOptions = useMemo(
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
[overflowX, overflowY]
);
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);
useEffect(() => {
const osInstance = os?.osInstance();
if (!osInstance) {
return;
}
const element = osInstance.elements().viewport;
// `pragmatic-drag-and-drop-auto-scroll` requires the element to have `overflow-y: scroll` or `overflow-y: auto`
// else it logs an ugly warning. In our case, using a custom scrollbar library, it will be 'hidden' by default.
// To prevent the spurious warning, we temporarily set overflow-y to 'scroll' and then revert it.
const overflowY = element.style.overflowY; // starts 'hidden'
element.style.setProperty('overflow-y', 'scroll', 'important');
const cleanup = combine(autoScrollForElements({ element }), autoScrollForExternal({ element }));
element.style.setProperty('overflow-y', overflowY);
return cleanup;
}, [os]);
return (
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
<OverlayScrollbarsComponent defer style={styles} options={overlayscrollbarsOptions}>
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
{children}
</OverlayScrollbarsComponent>
</Box>

View File

@@ -0,0 +1,57 @@
type Props = {
/**
* The amplitude of the wave. 0 is a straight line, higher values create more pronounced waves.
*/
amplitude: number;
/**
* The number of segments in the line. More segments create a smoother wave.
*/
segments?: number;
/**
* The color of the wave.
*/
stroke: string;
/**
* The stroke width of the wave line.
*/
strokeWidth: number;
/**
* The width of the SVG.
*/
width: number;
/**
* The height of the SVG.
*/
height: number;
};
const WavyLine = ({ amplitude, stroke, strokeWidth, width, height, segments = 5 }: Props) => {
// Calculate the path dynamically based on the amplitude
const generatePath = () => {
if (amplitude === 0) {
// If the amplitude is 0, return a straight line
return `M0,${height / 2} L${width},${height / 2}`;
}
const clampedAmplitude = Math.min(height / 2, amplitude); // Cap amplitude to half the height
const segmentWidth = width / segments;
let path = `M0,${height / 2}`; // Start in the middle of the left edge
// Loop through each segment and alternate the y position to create waves
for (let i = 1; i <= segments; i++) {
const x = i * segmentWidth;
const y = height / 2 + (i % 2 === 0 ? clampedAmplitude : -clampedAmplitude);
path += ` Q${x - segmentWidth / 2},${y} ${x},${height / 2}`;
}
return path;
};
return (
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} xmlns="http://www.w3.org/2000/svg">
<path d={generatePath()} fill="none" stroke={stroke} strokeWidth={strokeWidth} />
</svg>
);
};
export default WavyLine;
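// Worked example: with width={100}, height={20}, segments={4} and
// amplitude={5}, the midline is y=10 and generatePath alternates control
// points 5px either side of it:
//
//   M0,10 Q12.5,5 25,10 Q37.5,15 50,10 Q62.5,5 75,10 Q87.5,15 100,10
//
// With amplitude={0} the same props yield the straight line "M0,10 L100,10".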

View File

@@ -1,124 +0,0 @@
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
const log = logger('gallery');
const accept: Accept = {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg'],
};
export const useFullscreenDropzone = () => {
useAssertSingleton('useFullscreenDropzone');
const { t } = useTranslation();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const [uploadImage] = useUploadImageMutation();
const activeTabName = useAppSelector(selectActiveTab);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const getPostUploadAction = useCallback((): PostUploadAction => {
if (activeTabName === 'upscaling') {
return { type: 'SET_UPSCALE_INITIAL_IMAGE' };
} else {
return { type: 'TOAST' };
}
}, [activeTabName]);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 0) {
const errors = fileRejections.map((rejection) => ({
errors: rejection.errors.map(({ message }) => message),
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
setIsHandlingUpload(false);
return;
}
for (const [i, file] of acceptedFiles.entries()) {
uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: getPostUploadAction(),
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// The `imageUploaded` listener does some extra logic, like switching to the asset view on upload on the
// first upload of a "batch".
isFirstUploadOfBatch: i === 0,
});
}
setIsHandlingUpload(false);
},
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId]
);
const onDragOver = useCallback(() => {
setIsHandlingUpload(true);
}, []);
const onDragLeave = useCallback(() => {
setIsHandlingUpload(false);
}, []);
const dropzone = useDropzone({
accept,
noClick: true,
onDrop,
onDragOver,
onDragLeave,
noKeyboard: true,
multiple: maxImageUploadCount === undefined || maxImageUploadCount > 1,
maxFiles: maxImageUploadCount,
});
useEffect(() => {
// This is a hack to allow pasting images into the uploader
const handlePaste = (e: ClipboardEvent) => {
if (!dropzone.inputRef.current) {
return;
}
if (e.clipboardData?.files) {
// Set the files on the dropzone.inputRef
dropzone.inputRef.current.files = e.clipboardData.files;
// Dispatch the change event, dropzone catches this and we get to use its own validation
dropzone.inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
}
};
// Add the paste event listener
document.addEventListener('paste', handlePaste);
return () => {
document.removeEventListener('paste', handlePaste);
};
}, [dropzone.inputRef]);
return { dropzone, isHandlingUpload, setIsHandlingUpload };
};
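// The paste hack above leans on the fact that a FileList cannot be constructed
// directly, but one obtained from a DataTransfer can be assigned to an
// <input type="file">. The same trick in isolation - a sketch for, e.g., a
// test helper (assumes a DOM environment; setFilesOnInput is hypothetical):
const setFilesOnInput = (input: HTMLInputElement, files: File[]) => {
  const dataTransfer = new DataTransfer();
  for (const file of files) {
    dataTransfer.items.add(file);
  }
  input.files = dataTransfer.files;
  // Fire the event react-dropzone listens for, so its own validation runs.
  input.dispatchEvent(new Event('change', { bubbles: true }));
};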

View File

@@ -1,3 +1,5 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
@@ -7,14 +9,23 @@ import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
import { PiUploadBold } from 'react-icons/pi';
import { uploadImages, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
type UseImageUploadButtonArgs = {
postUploadAction?: PostUploadAction;
isDisabled?: boolean;
allowMultiple?: boolean;
};
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
};
const log = logger('gallery');
@@ -37,30 +48,46 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({
postUploadAction,
isDisabled,
allowMultiple = false,
}: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage] = useUploadImageMutation();
const [uploadImage, request] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const { t } = useTranslation();
const onDropAccepted = useCallback(
(files: File[]) => {
for (const [i, file] of files.entries()) {
uploadImage({
async (files: File[]) => {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
});
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
// Multiple files allowed - upload them all via the uploadImages helper
const imageDTOs = await uploadImages(
files.map((file) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}))
);
if (onUpload) {
onUpload(imageDTOs);
}
}
},
[autoAddBoardId, postUploadAction, uploadImage]
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
);
const onDropRejected = useCallback(
@@ -103,5 +130,42 @@ export const useImageUploadButton = ({
maxFiles: maxImageUploadCount,
});
return { getUploadButtonProps, getUploadInputProps, openUploader };
return { getUploadButtonProps, getUploadInputProps, openUploader, request };
};
const sx = {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 0,
borderRadius: 'base',
'&[data-error=true]': {
borderWidth: 1,
},
} satisfies SystemStyleObject;
export const UploadImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTO: ImageDTO) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="ghost"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
};
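// A sketch of the reworked hook in use. The allowMultiple flag now
// discriminates onUpload's argument type at compile time: a single ImageDTO
// when false, ImageDTO[] when true. (SingleUploadExample is illustrative,
// not part of this PR.)
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';

const SingleUploadExample = () => {
  const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
    allowMultiple: false,
    onUpload: (imageDTO) => {
      // Narrowed to a single ImageDTO because allowMultiple is false.
      console.log('uploaded image:', imageDTO.image_name);
    },
  });
  return (
    <>
      <button {...getUploadButtonProps()}>Upload</button>
      <input {...getUploadInputProps()} />
    </>
  );
};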

View File

@@ -1,14 +0,0 @@
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
export const useNanoid = (prefix?: string) => {
const id = useMemo(() => {
if (prefix) {
return getPrefixedId(prefix);
} else {
return nanoid();
}
}, [prefix]);
return id;
};

View File

@@ -1,12 +0,0 @@
type SerializableValue =
| string
| number
| boolean
| null
| undefined
| SerializableValue[]
| readonly SerializableValue[]
| SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};
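// This hand-rolled type is superseded by type-fest's JsonObject throughout the
// PR. The two shapes are nearly equivalent; type-fest's version is (roughly):
//
//   type JsonValue = string | number | boolean | null | JsonValue[] | JsonObject;
//   type JsonObject = { [key: string]: JsonValue };
//
// One practical difference: type-fest's JsonValue has no `undefined` member,
// so optional values must be omitted or serialized before being logged.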

View File

@@ -0,0 +1,6 @@
import type { AssertionError } from 'tsafe';
export function extractMessageFromAssertionError(error: AssertionError): string | null {
const match = error.message.match(/Wrong assertion encountered: "(.*)"/);
return match ? (match[1] ?? null) : null;
}
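// A sketch of the helper in use. tsafe's assert throws an AssertionError whose
// message embeds the caller's message in the quoted format the regex above
// targets; this assumes AssertionError is also available as a runtime export
// (it is imported above as a type only):
import { assert, AssertionError } from 'tsafe';

try {
  assert(false, 'Graph has no nodes');
} catch (error) {
  if (error instanceof AssertionError) {
    extractMessageFromAssertionError(error); // => 'Graph has no nodes'
  }
}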

View File

@@ -0,0 +1,15 @@
import type { CSSProperties } from 'react';
/**
* Chakra's Tooltip's method of finding the nearest scroll parent has a problem - it assumes the first parent with
* `overflow: hidden` is the scroll parent. In this case, the Collapse component has that style, but isn't scrollable
* itself. The result is that the tooltip does not close on scroll, because the scrolling happens higher up in the DOM.
*
* As a hacky workaround, we can set the overflow to `visible`, which allows the scroll parent search to continue up to
* the actual scroll parent (in this case, the OverlayScrollbarsComponent in BoardsListWrapper).
*
* See: https://github.com/chakra-ui/chakra-ui/issues/7871#issuecomment-2453780958
*/
export const fixTooltipCloseOnScrollStyles: CSSProperties = {
overflow: 'visible',
};

View File

@@ -1,38 +1,26 @@
import { Grid, GridItem } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type {
AddControlLayerFromImageDropData,
AddGlobalReferenceImageFromImageDropData,
AddRasterLayerFromImageDropData,
AddRegionalReferenceImageFromImageDropData,
} from 'features/dnd/types';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { newCanvasEntityFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
id: 'add-raster-layer-from-image-drop-data',
actionType: 'ADD_RASTER_LAYER_FROM_IMAGE',
};
const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
};
const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addRasterLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({ type: 'raster_layer' });
const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
});
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'reference_image',
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
if (imageViewer.isOpen) {
return null;
@@ -51,28 +39,36 @@ export const CanvasDropArea = memo(() => {
pointerEvents="none"
>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRasterLayer')}
data={addRasterLayerFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRasterLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRasterLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newControlLayer')}
data={addControlLayerFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
data={addRegionalReferenceImageFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRegionalGuidanceReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
data={addGlobalReferenceImageFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>

View File

@@ -0,0 +1,59 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasEntityListDnd } from 'features/controlLayers/components/CanvasEntityList/useCanvasEntityListDnd';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
import { DndListDropIndicator } from 'features/dnd/DndListDropIndicator';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useRef } from 'react';
const sx = {
position: 'relative',
flexDir: 'column',
w: 'full',
bg: 'base.850',
borderRadius: 'base',
'&[data-selected=true]': {
bg: 'base.800',
},
'&[data-is-dragging=true]': {
opacity: 0.3,
},
transitionProperty: 'common',
} satisfies SystemStyleObject;
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const isSelected = useEntityIsSelected(entityIdentifier);
const onClick = useCallback(() => {
if (isSelected) {
return;
}
dispatch(entitySelected({ entityIdentifier }));
}, [dispatch, entityIdentifier, isSelected]);
const ref = useRef<HTMLDivElement>(null);
const [dndListState, isDragging] = useCanvasEntityListDnd(ref, entityIdentifier);
return (
<Box position="relative">
<Flex
// This is used to trigger the post-move flash animation
data-entity-id={entityIdentifier.id}
data-selected={isSelected}
data-is-dragging={isDragging}
ref={ref}
onClick={onClick}
sx={sx}
>
{props.children}
</Flex>
<DndListDropIndicator dndState={dndListState} />
</Box>
);
});
CanvasEntityContainer.displayName = 'CanvasEntityContainer';

View File

@@ -0,0 +1,181 @@
import { monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import { reorderWithEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/util/reorder-with-edge';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { triggerPostMoveFlash } from 'features/dnd/util';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { flushSync } from 'react-dom';
import { PiCaretDownBold } from 'react-icons/pi';
type Props = PropsWithChildren<{
isSelected: boolean;
type: CanvasEntityIdentifier['type'];
entityIdentifiers: CanvasEntityIdentifier[];
}>;
export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityIdentifiers }: Props) => {
const title = useEntityTypeTitle(type);
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
const collapse = useBoolean(true);
const dispatch = useAppDispatch();
useEffect(() => {
return monitorForElements({
canMonitor({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== type) {
return false;
}
return true;
},
onDrop({ location, source }) {
const target = location.current.dropTargets[0];
if (!target) {
return;
}
const sourceData = source.data;
const targetData = target.data;
if (!singleCanvasEntityDndSource.typeGuard(sourceData) || !singleCanvasEntityDndSource.typeGuard(targetData)) {
return;
}
const indexOfSource = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === sourceData.payload.entityIdentifier.id
);
const indexOfTarget = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === targetData.payload.entityIdentifier.id
);
if (indexOfTarget < 0 || indexOfSource < 0) {
return;
}
// Don't move if the source and target are the same index, meaning same position in the list
if (indexOfSource === indexOfTarget) {
return;
}
const closestEdgeOfTarget = extractClosestEdge(targetData);
// It's possible that the indices are different, but refer to the same position. For example, if the source is
// at 2 and the target is at 3, but the target edge is 'top', then the entity is already in the correct position.
// We should bail if this is the case.
let edgeIndexDelta = 0;
if (closestEdgeOfTarget === 'bottom') {
edgeIndexDelta = 1;
} else if (closestEdgeOfTarget === 'top') {
edgeIndexDelta = -1;
}
// If the source is already in the correct position, we don't need to move it.
if (indexOfSource === indexOfTarget + edgeIndexDelta) {
return;
}
// Using `flushSync` so we can query the DOM straight after this line
flushSync(() => {
dispatch(
entitiesReordered({
type,
entityIdentifiers: reorderWithEdge({
list: entityIdentifiers,
startIndex: indexOfSource,
indexOfTarget,
closestEdgeOfTarget,
axis: 'vertical',
}),
})
);
});
// Flash the element that was moved
const element = document.querySelector(`[data-entity-id="${sourceData.payload.entityIdentifier.id}"]`);
if (element instanceof HTMLElement) {
triggerPostMoveFlash(element, colorTokenToCssVar('base.700'));
}
},
});
}, [dispatch, entityIdentifiers, type]);
return (
<Flex flexDir="column" w="full">
<Flex w="full">
<Flex
flexGrow={1}
as={Button}
onClick={collapse.toggle}
justifyContent="space-between"
alignItems="center"
gap={3}
variant="unstyled"
p={0}
h={8}
>
<Icon
boxSize={4}
as={PiCaretDownBold}
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
fill={isSelected ? 'base.200' : 'base.500'}
transitionProperty="common"
transitionDuration="fast"
/>
{informationalPopoverFeature ? (
<InformationalPopover feature={informationalPopoverFeature}>
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
</InformationalPopover>
) : (
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
)}
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
<Flex flexDir="column" gap={2} pt={2}>
{children}
</Flex>
</Collapse>
</Flex>
);
});
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';
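// Worked example of the closest-edge bail-out above: with entities [A, B, C, D],
// dragging A (index 0) onto the *top* edge of B (index 1) gives
// closestEdgeOfTarget === 'top', so edgeIndexDelta === -1 and
// indexOfTarget + edgeIndexDelta === 0 === indexOfSource: A is already where
// the drop would place it, so no reorder is dispatched. Dropping A onto the
// *bottom* edge of B instead gives 1 + 1 === 2 !== 0, and the reorder proceeds.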

View File

@@ -0,0 +1,83 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { type DndListTargetState, idle } from 'features/dnd/types';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';
export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdentifier: CanvasEntityIdentifier) => {
const [dndListState, setDndListState] = useState<DndListTargetState>(idle);
const [isDragging, setIsDragging] = useState(false);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
return combine(
draggable({
element,
getInitialData() {
return singleCanvasEntityDndSource.getData({ entityIdentifier });
},
onDragStart() {
setDndListState({ type: 'is-dragging' });
setIsDragging(true);
},
onDrop() {
setDndListState(idle);
setIsDragging(false);
},
}),
dropTargetForElements({
element,
canDrop({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== entityIdentifier.type) {
return false;
}
return true;
},
getData({ input }) {
const data = singleCanvasEntityDndSource.getData({ entityIdentifier });
return attachClosestEdge(data, {
element,
input,
allowedEdges: ['top', 'bottom'],
});
},
getIsSticky() {
return true;
},
onDragEnter({ self }) {
const closestEdge = extractClosestEdge(self.data);
setDndListState({ type: 'is-dragging-over', closestEdge });
},
onDrag({ self }) {
const closestEdge = extractClosestEdge(self.data);
// Only update React state if the closest edge has actually changed.
// Prevents unnecessary re-renders.
setDndListState((current) => {
if (current.type === 'is-dragging-over' && current.closestEdge === closestEdge) {
return current;
}
return { type: 'is-dragging-over', closestEdge };
});
},
onDragLeave() {
setDndListState(idle);
},
onDrop() {
setDndListState(idle);
},
})
);
}, [entityIdentifier, ref]);
return [dndListState, isDragging] as const;
};

View File

@@ -7,6 +7,8 @@ import { EntityListSelectedEntityActionBar } from 'features/controlLayers/compon
import { selectHasEntities } from 'features/controlLayers/store/selectors';
import { memo, useRef } from 'react';
import { ParamDenoisingStrength } from './ParamDenoisingStrength';
export const CanvasLayersPanelContent = memo(() => {
const hasEntities = useAppSelector(selectHasEntities);
const layersPanelFocusRef = useRef<HTMLDivElement>(null);
@@ -16,6 +18,8 @@ export const CanvasLayersPanelContent = memo(() => {
<Flex ref={layersPanelFocusRef} flexDir="column" gap={2} w="full" h="full">
<EntityListSelectedEntityActionBar />
<Divider py={0} />
<ParamDenoisingStrength />
<Divider py={0} />
{!hasEntities && <CanvasAddEntityButtons />}
{hasEntities && <CanvasEntityList />}
</Flex>

View File

@@ -109,7 +109,9 @@ export const CanvasMainPanelContent = memo(() => {
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasDropArea />
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
<GatedImageViewer />
</Flex>
);

View File

@@ -1,16 +1,20 @@
import { useDndContext } from '@dnd-kit/core';
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { dropTargetForExternal, monitorForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
import { Box, Button, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { CanvasLayersPanelContent } from 'features/controlLayers/components/CanvasLayersPanelContent';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectEntityCountActive } from 'features/controlLayers/store/selectors';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasRightPanel = memo(() => {
@@ -79,37 +83,13 @@ CanvasRightPanel.displayName = 'CanvasRightPanel';
const PanelTabs = memo(() => {
const { t } = useTranslation();
const activeTab = useAppSelector(selectActiveTabCanvasRightPanel);
const store = useAppStore();
const activeEntityCount = useAppSelector(selectEntityCountActive);
const tabTimeout = useRef<number | null>(null);
const dndCtx = useDndContext();
const dispatch = useAppDispatch();
const [mouseOverTab, setMouseOverTab] = useState<'layers' | 'gallery' | null>(null);
const onOnMouseOverLayersTab = useCallback(() => {
setMouseOverTab('layers');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('layers'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onOnMouseOverGalleryTab = useCallback(() => {
setMouseOverTab('gallery');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onMouseOut = useCallback(() => {
setMouseOverTab(null);
if (tabTimeout.current) {
clearTimeout(tabTimeout.current);
}
}, []);
const [layersTabDndState, setLayersTabDndState] = useState<DndTargetState>('idle');
const [galleryTabDndState, setGalleryTabDndState] = useState<DndTargetState>('idle');
const layersTabRef = useRef<HTMLDivElement>(null);
const galleryTabRef = useRef<HTMLDivElement>(null);
const timeoutRef = useRef<number | null>(null);
const layersTabLabel = useMemo(() => {
if (activeEntityCount === 0) {
@@ -118,23 +98,172 @@ const PanelTabs = memo(() => {
return `${t('controlLayers.layer_other')} (${activeEntityCount})`;
}, [activeEntityCount, t]);
useEffect(() => {
if (!layersTabRef.current) {
return;
}
const getIsOnLayersTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'layers';
const onDragEnter = () => {
// If we are already on the layers tab, do nothing
if (getIsOnLayersTab()) {
return;
}
// Else set the state to 'over' and switch to the layers tab after a timeout
setLayersTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('layers'));
// When we switch tabs, the other tab becomes a potential drop target
setLayersTabDndState('idle');
setGalleryTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to 'idle' or 'potential' depending on the current tab
if (getIsOnLayersTab()) {
setLayersTabDndState('idle');
} else {
setLayersTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to 'potential' when a drag starts
setLayersTabDndState('potential');
};
return combine(
dropTargetForElements({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the layers tab
return !getIsOnLayersTab();
},
onDragStart,
}),
dropTargetForExternal({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnLayersTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
if (!galleryTabRef.current) {
return;
}
const getIsOnGalleryTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'gallery';
const onDragEnter = () => {
// If we are already on the gallery tab, do nothing
if (getIsOnGalleryTab()) {
return;
}
// Else set the state to 'over' and switch to the gallery tab after a timeout
setGalleryTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
// When we switch tabs, the other tab should be pending
setGalleryTabDndState('idle');
setLayersTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or pending depending on the current tab
if (getIsOnGalleryTab()) {
setGalleryTabDndState('idle');
} else {
setGalleryTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to pending when a drag starts
setGalleryTabDndState('potential');
};
return combine(
dropTargetForElements({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the gallery tab
return !getIsOnGalleryTab();
},
onDragStart,
}),
dropTargetForExternal({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnGalleryTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
const onDrop = () => {
// Reset the dnd state when a drop happens
setGalleryTabDndState('idle');
setLayersTabDndState('idle');
};
const cleanup = combine(monitorForElements({ onDrop }), monitorForExternal({ onDrop }));
return () => {
cleanup();
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
}, []);
return (
<>
<Tab position="relative" onMouseOver={onOnMouseOverLayersTab} onMouseOut={onMouseOut} w={32}>
<Tab ref={layersTabRef} position="relative" w={32}>
<Box as="span" w="full">
{layersTabLabel}
</Box>
{dndCtx.active && activeTab !== 'layers' && (
<IAIDropOverlay isOver={mouseOverTab === 'layers'} withBackdrop={false} />
)}
<DndDropOverlay dndState={layersTabDndState} withBackdrop={false} />
</Tab>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut} w={32}>
<Tab ref={galleryTabRef} position="relative" w={32}>
<Box as="span" w="full">
{t('gallery.gallery')}
</Box>
{dndCtx.active && activeTab !== 'gallery' && (
<IAIDropOverlay isOver={mouseOverTab === 'gallery'} withBackdrop={false} />
)}
<DndDropOverlay dndState={galleryTabDndState} withBackdrop={false} />
</Tab>
</>
);
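The tab hover-switch logic above follows a "spring-loaded" drop-target pattern: dragging over an inactive tab arms a timer, and the switch fires only if the drag is still over the tab 300ms later. A minimal standalone sketch of the same idea, assuming pragmatic-drag-and-drop's element adapter (the helper name armSpringLoadedTab is hypothetical):

import { dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';

// Hypothetical helper: call switchTab only if a drag lingers over element for delayMs.
const armSpringLoadedTab = (element: HTMLElement, switchTab: () => void, delayMs = 300) => {
  let timeoutId: number | null = null;
  return dropTargetForElements({
    element,
    onDragEnter: () => {
      timeoutId = window.setTimeout(() => {
        timeoutId = null;
        switchTab();
      }, delayMs);
    },
    onDragLeave: () => {
      // The drag left before the delay elapsed - abort the pending switch.
      if (timeoutId !== null) {
        clearTimeout(timeoutId);
        timeoutId = null;
      }
    },
  });
};

The returned value is the drop target's cleanup function, matching how the effects above return combine(...) for teardown.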

View File

@@ -1,17 +1,19 @@
import { Spacer } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { ControlLayerBadges } from 'features/controlLayers/components/ControlLayer/ControlLayerBadges';
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -21,14 +23,16 @@ type Props = {
export const ControlLayer = memo(({ id }: Props) => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useMemo<CanvasEntityIdentifier<'control_layer'>>(
() => ({ id, type: 'control_layer' }),
[id]
);
const dropData = useMemo<ReplaceLayerImageDropData>(
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
[id, entityIdentifier]
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
[entityIdentifier]
);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<ControlLayerAdapterGate>
@@ -41,9 +45,14 @@ export const ControlLayer = memo(({ id }: Props) => {
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<CanvasEntitySettingsWrapper>
<ControlLayerControlAdapter />
<ControlLayerSettings />
</CanvasEntitySettingsWrapper>
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
<DndDropTarget
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.replaceLayer')}
isDisabled={isBusy}
/>
</CanvasEntityContainer>
</ControlLayerAdapterGate>
</EntityIdentifierContext.Provider>
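The features/dnd/dnd module is not included in this diff, but the call sites imply each dnd target bundles a getData builder and a typeGuard keyed on a unique discriminant. A plausible minimal shape, assuming a buildDndTarget factory (the factory name and payload wiring are illustrative, not the actual implementation):

// Illustrative sketch only - the real module's shape may differ.
const buildDndTarget = <Key extends string, Payload extends Record<string, unknown>>(key: Key) => {
  type Data = Payload & Record<Key, true> & { dndId: string | undefined };
  return {
    key,
    // Attach the discriminant so drops can be routed to the right handler.
    getData: (payload: Payload, dndId?: string): Data => ({ ...payload, [key]: true, dndId }) as Data,
    // Narrow unknown drag data back to this target's payload type.
    typeGuard: (data: Record<string | symbol, unknown>): data is Data => data[key] === true,
  };
};

const replaceTarget = buildDndTarget<'replace-canvas-entity-objects-with-image', { entityIdentifier: { id: string } }>(
  'replace-canvas-entity-objects-with-image'
);
const exampleData = replaceTarget.getData({ entityIdentifier: { id: 'layer-1' } }, 'layer-1');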

View File

@@ -1,11 +1,13 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
@@ -16,13 +18,15 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { getFilterForModel } from 'features/controlLayers/store/filters';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
@@ -39,11 +43,12 @@ const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<
export const ControlLayerControlAdapter = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const adapter = useEntityAdapterContext('control_layer');
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -69,17 +74,58 @@ export const ControlLayerControlAdapter = memo(() => {
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.
const isFiltering = adapter.filterer.$isFiltering.get();
const isSimple = adapter.filterer.$simple.get();
// If we are filtering and _not_ in simple mode, that means the user has clicked Advanced. They want to be in control
// of the settings. Bail early without doing anything else.
if (isFiltering && !isSimple) {
return;
}
// Else, we are in simple mode and will take care of some things for the user.
// First, check if the newly-selected model has a default filter. It may not - for example, Tile controlnet models
// don't have a default filter.
const defaultFilterForNewModel = getFilterForModel(modelConfig);
if (!defaultFilterForNewModel) {
// The user has chosen a model that doesn't have a default filter - cancel any in-progress filtering and bail.
if (isFiltering) {
adapter.filterer.cancel();
}
return;
}
// At this point, we know the user has selected a model that has a default filter. We need to either start filtering
// with that default filter, or update the existing filter config to match the new model's default filter.
const filterConfig = defaultFilterForNewModel.buildDefaults();
if (isFiltering) {
adapter.filterer.$filterConfig.set(filterConfig);
} else {
adapter.filterer.start(filterConfig);
}
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
// no-op if auto-processing is already enabled, because the process method is debounced.
adapter.filterer.process();
},
[dispatch, entityIdentifier]
[adapter.filterer, dispatch, entityIdentifier]
);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
const uploadOptions = useMemo(
() =>
({
onUpload: (imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ entityIdentifier, imageDTO, dispatch, getState });
},
allowMultiple: false,
}) as const,
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const uploadApi = useImageUploadButton(uploadOptions);
return (
<Flex flexDir="column" gap={3} position="relative" w="full">

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const ControlLayerEntityList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const layerIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (layerIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (layerIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="control_layer" isSelected={isSelected}>
{layerIds.map((id) => (
<ControlLayer key={id} id={id} />
<CanvasEntityGroupList type="control_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<ControlLayer key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);
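Note the switch from .reverse() to .toReversed() in this selector: .reverse() mutates the array in place, which is hazardous inside a memoized selector because it silently reorders the cached input, while .toReversed() (ES2023) returns a new copy and leaves the source array untouched. A quick illustration:

const ids = ['a', 'b', 'c'];
const copy = ids.toReversed(); // ['c', 'b', 'a']; ids is still ['a', 'b', 'c']
const same = ids.reverse(); // ['c', 'b', 'a']; ids itself is now mutated too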

View File

@@ -0,0 +1,18 @@
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { ControlLayerSettingsEmptyState } from 'features/controlLayers/components/ControlLayer/ControlLayerSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { memo } from 'react';
export const ControlLayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const isEmpty = useEntityIsEmpty(entityIdentifier);
if (isEmpty) {
return <ControlLayerSettingsEmptyState />;
}
return <ControlLayerControlAdapter />;
});
ControlLayerSettings.displayName = 'ControlLayerSettings';

View File

@@ -0,0 +1,53 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { Trans } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
export const ControlLayerSettingsEmptyState = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const { dispatch, getState } = useAppStore();
const isBusy = useCanvasIsBusy();
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ imageDTO, entityIdentifier, dispatch, getState });
},
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
return (
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.controlLayerEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
/>
</Text>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
);
});
ControlLayerSettingsEmptyState.displayName = 'ControlLayerSettingsEmptyState';
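As used above, useImageUploadButton follows the props-getter pattern: the hook owns the upload flow and hands back prop bundles for a trigger button and a hidden file input. A usage sketch inferred from the call sites in this diff (the hook's full option surface is not shown here):

import { Button } from '@invoke-ai/ui-library';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { ImageDTO } from 'services/api/types';

const UploadExample = () => {
  const uploadApi = useImageUploadButton({
    // Called with the uploaded image's DTO once the upload completes.
    onUpload: (imageDTO: ImageDTO) => console.log('uploaded', imageDTO.image_name),
    allowMultiple: false,
  });
  return (
    <>
      <Button {...uploadApi.getUploadButtonProps()}>Upload</Button>
      {/* The hidden input must be rendered for the button to trigger it. */}
      <input {...uploadApi.getUploadInputProps()} />
    </>
  );
};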

View File

@@ -9,6 +9,7 @@ import {
MenuList,
Spacer,
Spinner,
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
@@ -28,13 +29,10 @@ import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
const FilterContent = memo(
const FilterContentAdvanced = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const config = useStore(adapter.filterer.$filterConfig);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const autoProcess = useAppSelector(selectAutoProcess);
@@ -73,36 +71,8 @@ const FilterContent = memo(
adapter.filterer.saveAs('control_layer');
}, [adapter.filterer]);
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
@@ -169,12 +139,67 @@ const FilterContent = memo(
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</Flex>
</>
);
}
);
FilterContent.displayName = 'FilterContent';
FilterContentAdvanced.displayName = 'FilterContentAdvanced';
const FilterContentSimple = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const config = useStore(adapter.filterer.$filterConfig);
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const isValid = useMemo(() => {
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
const onClickAdvanced = useCallback(() => {
adapter.filterer.$simple.set(false);
}, [adapter.filterer.$simple]);
return (
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
</Flex>
<Flex flexDir="column" w="full" gap={2} pb={2}>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.processingLayerWith', { type: t(`controlLayers.filter.${config.type}.label`) })}
</Text>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.forMoreControl')}
</Text>
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button variant="ghost" onClick={onClickAdvanced}>
{t('controlLayers.filter.advanced')}
</Button>
<Spacer />
<Button
onClick={adapter.filterer.apply}
loadingText={t('controlLayers.filter.apply')}
variant="ghost"
isDisabled={isProcessing || !isValid || !hasImageState}
>
{t('controlLayers.filter.apply')}
</Button>
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</>
);
}
);
FilterContentSimple.displayName = 'FilterContentSimple';
export const Filter = () => {
const canvasManager = useCanvasManager();
@@ -182,8 +207,54 @@ export const Filter = () => {
if (!adapter) {
return null;
}
return <FilterContent adapter={adapter} />;
};
Filter.displayName = 'Filter';
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const simplified = useStore(adapter.filterer.$simple);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
{simplified && <FilterContentSimple adapter={adapter} />}
{!simplified && <FilterContentAdvanced adapter={adapter} />}
</Flex>
);
}
);
FilterContent.displayName = 'FilterContent';
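The simple/advanced split hinges on the filterer's $simple nanostores atom: FilterContent subscribes with useStore and swaps the body, while the shared frame, focus region, and apply/cancel hotkeys stay mounted across the toggle. A minimal sketch of the same toggle, assuming a plain nanostores atom (SimpleView/AdvancedView are placeholders):

import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';
import type { FC } from 'react';

const $simple = atom(true); // defaults to simple mode

declare const SimpleView: FC<{ onAdvanced: () => void }>;
declare const AdvancedView: FC;

const FilterModeToggle = () => {
  const simple = useStore($simple);
  // Only the body swaps; anything rendered around this stays mounted.
  return simple ? <SimpleView onAdvanced={() => $simple.set(false)} /> : <AdvancedView />;
};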

View File

@@ -1,5 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';

View File

@@ -1,82 +1,80 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { useNanoid } from 'common/hooks/useNanoid';
import { UploadImageButton } from 'common/hooks/useImageUploadButton';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props = {
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
dndTarget: T;
dndTargetData: ReturnType<T['getData']>;
};
export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const isConnected = useStore($isConnected);
const dndId = useNanoid('ip_adapter_image_preview');
export const IPAdapterImagePreview = memo(
<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget>({
image,
onChangeImage,
dndTarget,
dndTargetData,
}: Props<T>) => {
const { t } = useTranslation();
const isConnected = useStore($isConnected);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: dndId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, dndId]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
},
[onChangeImage]
);
useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex
position="relative"
w="full"
h="full"
alignItems="center"
borderColor="error.500"
borderStyle="solid"
borderWidth={controlImage ? 0 : 1}
borderRadius="base"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>
{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
<UploadImageButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
onUpload={onUpload}
fontSize={36}
/>
</Flex>
)}
</Flex>
);
});
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
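The new Props<T> generic is what ties the two props together: dndTargetData is declared as ReturnType<T['getData']>, so passing setGlobalReferenceImageDndTarget as dndTarget forces the data prop to be exactly what that target's getData produces. A reduced sketch of the pattern (targetA/targetB are placeholders):

// Placeholder targets for illustration only.
declare const targetA: { getData: (arg: { a: string }) => { kind: 'a'; a: string } };
declare const targetB: { getData: (arg: { b: number }) => { kind: 'b'; b: number } };

type LinkedProps<T extends typeof targetA | typeof targetB> = {
  dndTarget: T;
  dndTargetData: ReturnType<T['getData']>;
};

const ok: LinkedProps<typeof targetA> = { dndTarget: targetA, dndTargetData: { kind: 'a', a: 'x' } };
// const bad: LinkedProps<typeof targetA> = { dndTarget: targetA, dndTargetData: { kind: 'b', b: 1 } }; // type error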

View File

@@ -2,14 +2,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const IPAdapterList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const ipaIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (ipaIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (ipaIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
{ipaIds.map((id) => (
<IPAdapter key={id} id={id} />
<CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<IPAdapter key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -19,11 +19,12 @@ import {
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { IPAImageDropData } from 'features/dnd/types';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig, IPALayerImagePostUploadAction } from 'services/api/types';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
@@ -80,13 +81,9 @@ export const IPAdapterSettings = memo(() => {
[dispatch, entityIdentifier]
);
const droppableData = useMemo<IPAImageDropData>(
() => ({ actionType: 'SET_IPA_IMAGE', context: { id: entityIdentifier.id }, id: entityIdentifier.id }),
[entityIdentifier.id]
);
const postUploadAction = useMemo<IPALayerImagePostUploadAction>(
() => ({ type: 'SET_IPA_IMAGE', id: entityIdentifier.id }),
[entityIdentifier.id]
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }, ipAdapter.image?.image_name),
[entityIdentifier, ipAdapter.image?.image_name]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
@@ -122,10 +119,10 @@ export const IPAdapterSettings = memo(() => {
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
image={ipAdapter.image}
onChangeImage={onChangeImage}
droppableData={droppableData}
postUploadAction={postUploadAction}
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
/>
</Flex>
</Flex>

View File

@@ -1,5 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { InpaintMask } from 'features/controlLayers/components/InpaintMask/InpaintMask';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.inpaintMasks.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.inpaintMasks.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const InpaintMaskList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const entityIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (entityIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (entityIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="inpaint_mask" isSelected={isSelected}>
{entityIds.map((id) => (
<InpaintMask key={id} id={id} />
<CanvasEntityGroupList type="inpaint_mask" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<InpaintMask key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -0,0 +1,82 @@
import {
Badge,
CompositeNumberInput,
CompositeSlider,
Flex,
FormControl,
FormLabel,
useToken,
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import WavyLine from 'common/components/WavyLine';
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const isEnabled = useAppSelector(selectIsEnabled);
const onChange = useCallback(
(v: number) => {
dispatch(setImg2imgStrength(v));
},
[dispatch]
);
const config = useAppSelector(selectImg2imgStrengthConfig);
const { t } = useTranslation();
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
return (
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
</InformationalPopover>
{isEnabled && (
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{isEnabled ? (
<>
<CompositeSlider
step={config.coarseStep}
fineStep={config.fineStep}
min={config.sliderMin}
max={config.sliderMax}
defaultValue={config.initial}
onChange={onChange}
value={img2imgStrength}
/>
<CompositeNumberInput
step={config.coarseStep}
fineStep={config.fineStep}
min={config.numberInputMin}
max={config.numberInputMax}
defaultValue={config.initial}
onChange={onChange}
value={img2imgStrength}
variant="outline"
/>
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">
{t('common.disabled')} - {t('parameters.noRasterLayers')}
</Badge>
</Flex>
)}
</FormControl>
);
});
ParamDenoisingStrength.displayName = 'ParamDenoisingStrength';
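A detail worth noting in this new component: selectIsEnabled derives a primitive boolean via createSelector, so subscribers re-render only when the raster-layer list crosses the empty/non-empty boundary, not on every change to the entities array. The same pattern in isolation:

import { createSelector } from '@reduxjs/toolkit';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';

// The derived boolean is referentially stable until it actually flips,
// so useAppSelector subscribers skip re-renders on unrelated entity changes.
const selectHasRasterLayers = createSelector(
  selectActiveRasterLayerEntities,
  (entities) => entities.length > 0
);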

View File

@@ -1,14 +1,16 @@
import { Spacer } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -18,10 +20,11 @@ type Props = {
export const RasterLayer = memo(({ id }: Props) => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useMemo<CanvasEntityIdentifier<'raster_layer'>>(() => ({ id, type: 'raster_layer' }), [id]);
const dropData = useMemo<ReplaceLayerImageDropData>(
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
[id, entityIdentifier]
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
[entityIdentifier]
);
return (
@@ -34,7 +37,12 @@ export const RasterLayer = memo(({ id }: Props) => {
<Spacer />
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
<DndDropTarget
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.replaceLayer')}
isDisabled={isBusy}
/>
</CanvasEntityContainer>
</RasterLayerAdapterGate>
</EntityIdentifierContext.Provider>

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { RasterLayer } from 'features/controlLayers/components/RasterLayer/RasterLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.rasterLayers.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.rasterLayers.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'raster_layer';
@@ -16,17 +16,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const RasterLayerEntityList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const layerIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (layerIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (layerIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="raster_layer" isSelected={isSelected}>
{layerIds.map((id) => (
<RasterLayer key={id} id={id} />
<CanvasEntityGroupList type="raster_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<RasterLayer key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -1,8 +1,8 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { deepClone } from 'common/util/deepClone';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
import {
@@ -10,6 +10,7 @@ import {
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { initialControlNet } from 'features/controlLayers/store/util';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isBusy = useCanvasIsBusy();
const isLocked = useEntityIsLocked(entityIdentifier);
@@ -37,10 +37,10 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
rasterLayerConvertedToControlLayer({
entityIdentifier,
replace: true,
overrides: { controlAdapter: defaultControlAdapter },
overrides: { controlAdapter: deepClone(initialControlNet) },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />} isDisabled={isBusy || isLocked}>

View File

@@ -1,15 +1,16 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { initialControlNet } from 'features/controlLayers/store/util';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isBusy = useCanvasIsBusy();
const copyToInpaintMask = useCallback(() => {
@@ -35,10 +35,10 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
overrides: { controlAdapter: defaultControlAdapter },
overrides: { controlAdapter: deepClone(initialControlNet) },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />} isDisabled={isBusy}>

View File

@@ -1,5 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { RegionalGuidance } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidance';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.regionalGuidance.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.regionalGuidance.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'regional_guidance';
@@ -16,17 +16,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const RegionalGuidanceEntityList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const rgIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (rgIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (rgIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="regional_guidance" isSelected={isSelected}>
{rgIds.map((id) => (
<RegionalGuidance key={id} id={id} />
<CanvasEntityGroupList type="regional_guidance" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<RegionalGuidance key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -20,11 +20,12 @@ import {
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { RGIPAdapterImageDropData } from 'features/dnd/types';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiTrashSimpleFill } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig, RGIPAdapterImagePostUploadAction } from 'services/api/types';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type Props = {
@@ -91,18 +92,15 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
[dispatch, entityIdentifier, referenceImageId]
);
const droppableData = useMemo<RGIPAdapterImageDropData>(
() => ({
actionType: 'SET_RG_IP_ADAPTER_IMAGE',
context: { id: entityIdentifier.id, referenceImageId: referenceImageId },
id: entityIdentifier.id,
}),
[entityIdentifier.id, referenceImageId]
);
const postUploadAction = useMemo<RGIPAdapterImagePostUploadAction>(
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, referenceImageId: referenceImageId }),
[entityIdentifier.id, referenceImageId]
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
() =>
setRegionalGuidanceReferenceImageDndTarget.getData(
{ entityIdentifier, referenceImageId },
ipAdapter.image?.image_name
),
[entityIdentifier, ipAdapter.image?.image_name, referenceImageId]
);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
const isBusy = useCanvasIsBusy();
@@ -151,10 +149,10 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
image={ipAdapter.image}
onChangeImage={onChangeImage}
droppableData={droppableData}
postUploadAction={postUploadAction}
dndTarget={setRegionalGuidanceReferenceImageDndTarget}
dndTargetData={dndTargetData}
/>
</Flex>
</Flex>

View File

@@ -1,38 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
import { useEntitySelectionColor } from 'features/controlLayers/hooks/useEntitySelectionColor';
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
import type { PropsWithChildren } from 'react';
import { memo, useCallback } from 'react';
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const isSelected = useEntityIsSelected(entityIdentifier);
const selectionColor = useEntitySelectionColor(entityIdentifier);
const onClick = useCallback(() => {
if (isSelected) {
return;
}
dispatch(entitySelected({ entityIdentifier }));
}, [dispatch, entityIdentifier, isSelected]);
return (
<Flex
position="relative"
flexDir="column"
w="full"
bg={isSelected ? 'base.800' : 'base.850'}
onClick={onClick}
borderInlineStartWidth={5}
borderColor={isSelected ? selectionColor : 'base.800'}
borderRadius="base"
>
{props.children}
</Flex>
);
});
CanvasEntityContainer.displayName = 'CanvasEntityContainer';

View File

@@ -1,90 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { type CanvasEntityIdentifier, isRenderableEntityType } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
import { PiCaretDownBold } from 'react-icons/pi';
type Props = PropsWithChildren<{
isSelected: boolean;
type: CanvasEntityIdentifier['type'];
}>;
const _hover: SystemStyleObject = {
opacity: 1,
};
export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props) => {
const title = useEntityTypeTitle(type);
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
const collapse = useBoolean(true);
return (
<Flex flexDir="column" w="full">
<Flex w="full">
<Flex
flexGrow={1}
as={Button}
onClick={collapse.toggle}
justifyContent="space-between"
alignItems="center"
gap={3}
variant="unstyled"
p={0}
h={8}
>
<Icon
boxSize={4}
as={PiCaretDownBold}
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
fill={isSelected ? 'base.200' : 'base.500'}
transitionProperty="common"
transitionDuration="fast"
/>
{informationalPopoverFeature ? (
<InformationalPopover feature={informationalPopoverFeature}>
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
</InformationalPopover>
) : (
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
)}
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue}>
<Flex flexDir="column" gap={2} pt={2}>
{children}
</Flex>
</Collapse>
</Flex>
);
});
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';

View File

@@ -1,4 +1,4 @@
import { Box, chakra, Flex } from '@invoke-ai/ui-library';
import { Box, chakra, Flex, Tooltip } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { rgbColorToString } from 'common/util/colorCodeTransformers';
@@ -86,13 +86,63 @@ export const CanvasEntityPreviewImage = memo(() => {
useEffect(updatePreview, [updatePreview, canvasCache, nodeRect, pixelRect]);
return (
<Tooltip label={<TooltipContent canvasRef={canvasRef} />} p={2} closeOnScroll>
<Flex
position="relative"
alignItems="center"
justifyContent="center"
w={CONTAINER_WIDTH_PX}
h={CONTAINER_WIDTH_PX}
borderRadius="sm"
borderWidth={1}
bg="base.900"
flexShrink={0}
>
<Box
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
bgSize="5px"
/>
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
</Flex>
</Tooltip>
);
});
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
const TooltipContent = ({ canvasRef }: { canvasRef: React.RefObject<HTMLCanvasElement> }) => {
const canvasRef2 = useRef<HTMLCanvasElement>(null);
useEffect(() => {
if (!canvasRef2.current || !canvasRef.current) {
return;
}
const ctx = canvasRef2.current.getContext('2d');
if (!ctx) {
return;
}
canvasRef2.current.width = canvasRef.current.width;
canvasRef2.current.height = canvasRef.current.height;
ctx.clearRect(0, 0, canvasRef2.current.width, canvasRef2.current.height);
ctx.drawImage(canvasRef.current, 0, 0);
}, [canvasRef]);
return (
<Flex
position="relative"
alignItems="center"
justifyContent="center"
w={CONTAINER_WIDTH_PX}
h={CONTAINER_WIDTH_PX}
w={150}
h={150}
borderRadius="sm"
borderWidth={1}
bg="base.900"
@@ -105,11 +155,9 @@ export const CanvasEntityPreviewImage = memo(() => {
bottom={0}
left={0}
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
bgSize="5px"
bgSize="8px"
/>
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
<ChakraCanvas position="relative" ref={canvasRef2} objectFit="contain" maxW="full" maxH="full" />
</Flex>
);
});
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
};
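One subtlety in TooltipContent's mirroring effect: assigning width and height resets a canvas's backing store, so the destination is already blank before clearRect runs, making that call defensive rather than required. A condensed sketch of the mirror step:

// Copy a live preview canvas into a second (e.g. larger tooltip) canvas.
const mirrorCanvas = (source: HTMLCanvasElement, dest: HTMLCanvasElement) => {
  dest.width = source.width; // resizing also clears the destination
  dest.height = source.height;
  dest.getContext('2d')?.drawImage(source, 0, 0);
};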

View File

@@ -4,9 +4,10 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasEntityIdentifier, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useMemo, useSyncExternalStore } from 'react';
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
import { assert } from 'tsafe';
const EntityAdapterContext = createContext<
@@ -95,6 +96,17 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
});
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
type?: T
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
const adapter = useContext(EntityAdapterContext);
assert(adapter, 'useEntityAdapterContext must be used within an EntityAdapterContext.Provider');
if (type) {
assert(adapter.entityIdentifier.type === type, 'useEntityAdapterContext must be used with the correct type');
}
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
};
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';
export const useEntityAdapterSafe = (
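A usage sketch for the narrowed hook above: passing an entity type asserts and narrows the returned adapter, while omitting it accepts any renderable entity. The component is hypothetical and assumes it is rendered inside the matching adapter gate:

// Hypothetical consumer, in the same module scope as useEntityAdapterContext.
const ControlLayerId = () => {
  // Narrowed to CanvasEntityAdapterControlLayer; the assertion throws under any other gate.
  const adapter = useEntityAdapterContext('control_layer');
  return <span>{adapter.entityIdentifier.id}</span>;
};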

View File

@@ -2,11 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import {
bboxChangedFromCanvas,
controlLayerAdded,
inpaintMaskAdded,
rasterLayerAdded,
@@ -17,38 +14,22 @@ import {
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import {
selectBboxModelBase,
selectBboxRect,
selectCanvasSlice,
selectEntityOrThrow,
} from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import {
imageDTOToImageObject,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
/** @knipignore */
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -85,6 +66,9 @@ export const selectDefaultIPAdapter = createSelector(
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
if (model.base === 'flux') {
ipAdapter.clipVisionModel = 'ViT-L';
}
}
return ipAdapter;
}
@@ -92,11 +76,10 @@ export const selectDefaultIPAdapter = createSelector(
export const useAddControlLayer = () => {
const dispatch = useAppDispatch();
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const func = useCallback(() => {
const overrides = { controlAdapter: defaultControlAdapter };
const overrides = { controlAdapter: deepClone(initialControlNet) };
dispatch(controlLayerAdded({ isSelected: true, overrides }));
}, [defaultControlAdapter, dispatch]);
}, [dispatch]);
return func;
};
@@ -110,150 +93,6 @@ export const useAddRasterLayer = () => {
return func;
};
export const useNewRasterLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useNewControlLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasControlLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useNewInpaintMaskFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasInpaintMaskState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useNewRegionalGuidanceFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rgAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
/**
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
* - Reset the canvas
* - Resize the bbox to the image's aspect ratio at the optimal size for the selected model
* - Add the image as a raster layer
* - Resize the layer to fit the bbox using the 'fill' strategy
*
* This allows the user to immediately generate a new image from the given image without any additional steps.
*/
export const useNewCanvasFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const base = useAppSelector(selectBboxModelBase);
const func = useCallback(
(imageDTO: ImageDTO, type: CanvasRasterLayerState['type'] | CanvasControlLayerState['type']) => {
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
// The overrides need to include the layer's ID so we can transform the layer once it is initialized
let overrides: Partial<CanvasRasterLayerState> | Partial<CanvasControlLayerState>;
if (type === 'raster_layer') {
overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;
} else if (type === 'control_layer') {
overrides = {
id: getPrefixedId('control_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasControlLayerState>;
} else {
// Catch unhandled types
assert<Equals<typeof type, never>>(false);
}
CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
// Skip the callback if the adapter is not the one we are creating
if (adapter.id !== overrides.id) {
return false;
}
// Fit the layer to the bbox w/ fill strategy
await adapter.transformer.startTransform({ silent: true });
adapter.transformer.fitToBboxFill();
await adapter.transformer.applyTransform();
return true;
});
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
// The type casts are safe because the type is checked above
if (type === 'raster_layer') {
dispatch(rasterLayerAdded({ overrides: overrides as Partial<CanvasRasterLayerState>, isSelected: true }));
} else if (type === 'control_layer') {
dispatch(controlLayerAdded({ overrides: overrides as Partial<CanvasControlLayerState>, isSelected: true }));
} else {
// Catch unhandled types
assert<Equals<typeof type, never>>(false);
}
},
[base, bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
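The JSDoc above describes the flow; a usage sketch follows, assuming a caller that already holds an ImageDTO. The button component and its props are illustrative. The init callback registered inside the hook fires once the new layer's adapter mounts, fitting the layer to the freshly resized bbox before control returns to the user:

import { Button } from '@invoke-ai/ui-library';
import { useCallback } from 'react';
import type { ImageDTO } from 'services/api/types';

const NewCanvasFromImageButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
  const newCanvasFromImage = useNewCanvasFromImage();
  const onClick = useCallback(() => {
    // Resets the canvas, resizes the bbox, and adds the image as a raster layer.
    newCanvasFromImage(imageDTO, 'raster_layer');
  }, [imageDTO, newCanvasFromImage]);
  return <Button onClick={onClick}>New canvas from image</Button>;
};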
export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const func = useCallback(() => {

View File

@@ -1,10 +1,9 @@
import { logger } from 'app/logging/logger';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
@@ -25,12 +24,13 @@ import type {
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('canvas');
@@ -64,7 +64,7 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
return;
}
let metadata: SerializableObject | undefined = undefined;
let metadata: JsonObject | undefined = undefined;
if (withMetadata) {
metadata = selectCanvasMetadata(store.getState());
@@ -72,10 +72,16 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
const result = await withResultAsync(() => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
return canvasManager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
is_intermediate: !saveToGallery,
metadata,
});
return canvasManager.compositor.getCompositeImageDTO(
rasterAdapters,
rect,
{
is_intermediate: !saveToGallery,
metadata,
},
undefined,
true // force upload the image to ensure it gets added to the gallery
);
});
if (result.isOk()) {
@@ -223,13 +229,12 @@ export const useNewRasterLayerFromBbox = () => {
export const useNewControlLayerFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO, rect: Rect) => {
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageDTOToImageObject(imageDTO)],
controlAdapter: deepClone(defaultControlAdapter),
controlAdapter: deepClone(initialControlNet),
position: { x: rect.x, y: rect.y },
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
@@ -242,7 +247,7 @@ export const useNewControlLayerFromBbox = () => {
toastOk: t('controlLayers.newControlLayerOk'),
toastError: t('controlLayers.newControlLayerError'),
};
}, [defaultControlAdapter, dispatch, t]);
}, [dispatch, t]);
const func = useSaveCanvas(arg);
return func;
};
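The save flow above wraps its async work in withResultAsync rather than try/catch, then branches on the result. A minimal sketch of that pattern, consistent with how this diff consumes the API (isOk/isErr, value/error); the worker function is a hypothetical stand-in for getCompositeImageDTO:

import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import { serializeError } from 'serialize-error';

const log = logger('canvas');

// Hypothetical async worker; anything that may throw.
const doWork = async (): Promise<string> => 'image_name.png';

const save = async (): Promise<string | null> => {
  const result = await withResultAsync(doWork);
  if (result.isOk()) {
    return result.value; // the resolved value; no exception escaped
  }
  // result.error holds whatever was thrown; serialize it for structured logging.
  log.error({ error: serializeError(result.error) }, 'Failed to save');
  return null;
};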

View File

@@ -1,27 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { rgbColorToString } from 'common/util/colorCodeTransformers';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
export const useEntitySelectionColor = (entityIdentifier: CanvasEntityIdentifier) => {
const selectSelectionColor = useMemo(
() =>
createSelector(selectCanvasSlice, (canvas) => {
const entity = selectEntity(canvas, entityIdentifier);
if (!entity) {
return 'base.400';
} else if (entity.type === 'inpaint_mask') {
return rgbColorToString(entity.fill.color);
} else if (entity.type === 'regional_guidance') {
return rgbColorToString(entity.fill.color);
} else {
return 'base.400';
}
}),
[entityIdentifier]
);
const selectionColor = useAppSelector(selectSelectionColor);
return selectionColor;
};
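The deleted hook shows a pattern that recurs in this codebase: a parameterized selector built with createSelector and memoized per component instance via useMemo, so the selector's identity stays stable for a given argument. A hypothetical hook in the same shape (the name field is assumed to exist on entity states, as the adapters' keysToOmit lists suggest):

import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';

export const useEntityName = (entityIdentifier: CanvasEntityIdentifier) => {
  const selectName = useMemo(
    // A fresh selector is created only when the identifier changes.
    () => createSelector(selectCanvasSlice, (canvas) => selectEntity(canvas, entityIdentifier)?.name ?? ''),
    [entityIdentifier]
  );
  return useAppSelector(selectName);
};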

View File

@@ -5,14 +5,10 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import { canvasToBlob } from 'features/controlLayers/konva/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { uploadImage } from 'services/api/endpoints/images';
export const useSaveLayerToAssets = () => {
const { t } = useTranslation();
const [uploadImage] = useUploadImageMutation();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const saveLayerToAssets = useCallback(
@@ -27,30 +23,17 @@ export const useSaveLayerToAssets = () => {
if (!adapter) {
return;
}
try {
const canvas = adapter.getCanvas();
const blob = await canvasToBlob(canvas);
const file = new File([blob], `layer-${adapter.id}.png`, { type: 'image/png' });
await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
toast({
status: 'info',
title: t('toast.layerSavedToAssets'),
});
} catch (error) {
toast({
status: 'error',
title: t('toast.problemSavingLayer'),
});
}
const canvas = adapter.getCanvas();
const blob = await canvasToBlob(canvas);
const file = new File([blob], `layer-${adapter.id}.png`, { type: 'image/png' });
uploadImage({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
},
[t, autoAddBoardId, uploadImage]
[autoAddBoardId]
);
return saveLayerToAssets;
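The rework drops the hook's own try/catch and toasts, delegating user feedback to the uploadImage util (see the withToast flag elsewhere in this diff). The canvas-to-File step leans on a canvasToBlob helper; a plausible implementation using the standard HTMLCanvasElement.toBlob API is sketched below as an assumption, not the util's actual source:

// Promisified canvas.toBlob; rejects when the browser cannot produce a blob.
const toBlob = (canvas: HTMLCanvasElement): Promise<Blob> =>
  new Promise((resolve, reject) => {
    canvas.toBlob((blob) => {
      if (blob) {
        resolve(blob);
      } else {
        reject(new Error('Unable to create Blob from canvas'));
      }
    }, 'image/png');
  });

Wrapping the resulting blob in a File, as the hook does, gives the upload a stable name and MIME type.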

View File

@@ -1,4 +1,3 @@
import type { SerializableObject } from 'common/types';
import { withResultAsync } from 'common/util/result';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import type { CanvasEntityAdapter, CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
@@ -35,12 +34,13 @@ import { t } from 'i18next';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { UploadOptions } from 'services/api/endpoints/images';
import type { UploadImageArg } from 'services/api/endpoints/images';
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import type { JsonObject } from 'type-fest';
type CompositingOptions = {
/**
@@ -173,14 +173,14 @@ export class CanvasCompositorModule extends CanvasModuleBase {
return adapters as CanvasEntityAdapterFromType<T>[];
};
getCompositeHash = (adapters: CanvasEntityAdapter[], extra: SerializableObject): string => {
const adapterHashes: SerializableObject[] = [];
getCompositeHash = (adapters: CanvasEntityAdapter[], extra: JsonObject): string => {
const adapterHashes: JsonObject[] = [];
for (const adapter of adapters) {
adapterHashes.push(adapter.getHashableState());
}
const data: SerializableObject = {
const data: JsonObject = {
extra,
adapterHashes,
};
@@ -253,18 +253,20 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param rect The region to include in the rasterized image
* @param uploadOptions Options for uploading the image
* @param compositingOptions Options for compositing the entities
* @param forceUpload If true, the image is always re-uploaded, returning a new image DTO
* @returns A promise that resolves to the image DTO
*/
getCompositeImageDTO = async (
adapters: CanvasEntityAdapter[],
rect: Rect,
uploadOptions: Pick<UploadOptions, 'is_intermediate' | 'metadata'>,
compositingOptions?: CompositingOptions
uploadOptions: Pick<UploadImageArg, 'is_intermediate' | 'metadata'>,
compositingOptions?: CompositingOptions,
forceUpload?: boolean
): Promise<ImageDTO> => {
assert(rect.width > 0 && rect.height > 0, 'Unable to rasterize empty rect');
const hash = this.getCompositeHash(adapters, { rect });
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
const cachedImageName = forceUpload ? undefined : this.manager.cache.imageNameCache.get(hash);
let imageDTO: ImageDTO | null = null;
@@ -295,12 +297,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
this.$isUploading.set(true);
const uploadResult = await withResultAsync(() =>
uploadImage({
blob,
fileName: 'canvas-composite.png',
file: new File([blob], 'canvas-composite.png', { type: 'image/png' }),
image_category: 'general',
is_intermediate: uploadOptions.is_intermediate,
board_id: uploadOptions.is_intermediate ? undefined : selectAutoAddBoardId(this.manager.store.getState()),
metadata: uploadOptions.metadata,
withToast: false,
})
);
this.$isUploading.set(false);
@@ -327,6 +329,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
entityIdentifiers: T[],
deleteMergedEntities: boolean
): Promise<ImageDTO | null> => {
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergingLayers'), withCount: false });
if (entityIdentifiers.length <= 1) {
this.log.warn({ entityIdentifiers }, 'Cannot merge fewer than 2 entities');
return null;
@@ -349,7 +352,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
if (result.isErr()) {
this.log.error({ error: serializeError(result.error) }, 'Failed to merge selected entities');
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
toast({
id: 'MERGE_LAYERS_TOAST',
title: t('controlLayers.mergeVisibleError'),
status: 'error',
withCount: false,
});
return null;
}
@@ -381,7 +389,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
assert<Equals<typeof type, never>>(false, 'Unsupported type for merge');
}
toast({ title: t('controlLayers.mergeVisibleOk') });
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergeVisibleOk'), status: 'success', withCount: false });
return result.value;
};
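Two details in this file are easy to miss. First, getCompositeImageDTO normally short-circuits to a cached image when the composite hash matches, so forceUpload exists to bypass the name cache when the caller needs a fresh upload that lands in the gallery. Second, the merge flow posts a toast with a fixed id up front and reuses that id for the outcome, which updates the pending toast in place instead of stacking a second one. Condensed from the diff:

import { toast } from 'features/toast/toast';
import { t } from 'i18next';

// Announce the long-running merge immediately...
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergingLayers'), withCount: false });

// ...then reuse the same id so the result replaces the pending toast.
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergeVisibleOk'), status: 'success', withCount: false });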

View File

@@ -1,7 +1,6 @@
import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
@@ -37,6 +36,7 @@ import type { Logger } from 'roarr';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
import type { Jsonifiable, JsonObject } from 'type-fest';
// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
@@ -111,7 +111,7 @@ export abstract class CanvasEntityAdapterBase<
*
* This is used for caching.
*/
abstract getHashableState: () => SerializableObject;
abstract getHashableState: () => JsonObject;
/**
* Callbacks that are executed when the module is initialized.
@@ -566,7 +566,7 @@ export abstract class CanvasEntityAdapterBase<
* Gets a hash of the entity's state, as provided by `getHashableState`. If `extra` is provided, it will be included in
* the hash.
*/
hash = (extra?: SerializableObject): string => {
hash = (extra?: Jsonifiable): string => {
const arg = {
state: this.getHashableState(),
extra,
@@ -614,8 +614,8 @@ export abstract class CanvasEntityAdapterBase<
transformer: this.transformer.repr(),
renderer: this.renderer.repr(),
bufferRenderer: this.bufferRenderer.repr(),
segmentAnything: this.segmentAnything?.repr(),
filterer: this.filterer?.repr(),
segmentAnything: this.segmentAnything?.repr() ?? null,
filterer: this.filterer?.repr() ?? null,
hasCache: this.$canvasCache.get() !== null,
isLocked: this.$isLocked.get(),
isDisabled: this.$isDisabled.get(),
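getHashableState and hash feed stable-hash to produce cache keys for compositing. A minimal sketch of the idea, assuming stable-hash's default export returns a deterministic string for structurally equal inputs:

import stableHash from 'stable-hash';
import type { JsonObject, Jsonifiable } from 'type-fest';

// Deterministic key derived from the hashable state plus optional extra data (e.g. a rect).
const hashState = (state: JsonObject, extra?: Jsonifiable): string => stableHash({ state, extra });

// Structurally equal inputs yield equal keys, so the result can index a cache:
const imageNameCache = new Map<string, string>();
const key = hashState({ opacity: 1 }, { rect: { x: 0, y: 0, width: 512, height: 512 } });
imageNameCache.set(key, 'canvas-composite.png');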

View File

@@ -1,4 +1,3 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
@@ -9,6 +8,7 @@ import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/Canvas
import type { CanvasControlLayerState, CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import { omit } from 'lodash-es';
import type { JsonObject } from 'type-fest';
export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
CanvasControlLayerState,
@@ -77,7 +77,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
return canvas;
};
getHashableState = (): SerializableObject => {
getHashableState = (): JsonObject => {
const keysToOmit: (keyof CanvasControlLayerState)[] = [
'name',
'controlAdapter',

View File

@@ -1,4 +1,3 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
@@ -7,6 +6,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier, CanvasInpaintMaskState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import { omit } from 'lodash-es';
import type { JsonObject } from 'type-fest';
export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
CanvasInpaintMaskState,
@@ -69,7 +69,7 @@ export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
}
};
getHashableState = (): SerializableObject => {
getHashableState = (): JsonObject => {
const keysToOmit: (keyof CanvasInpaintMaskState)[] = ['fill', 'name', 'opacity', 'isLocked'];
return omit(this.state, keysToOmit);
};

View File

@@ -1,4 +1,3 @@
import type { SerializableObject } from 'common/types';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
@@ -9,6 +8,7 @@ import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/Canvas
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import { omit } from 'lodash-es';
import type { JsonObject } from 'type-fest';
export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
CanvasRasterLayerState,
@@ -70,7 +70,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
return canvas;
};
getHashableState = (): SerializableObject => {
getHashableState = (): JsonObject => {
const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked'];
return omit(this.state, keysToOmit);
};

Some files were not shown because too many files have changed in this diff.