Compare commits

112 Commits

Author SHA1 Message Date
psychedelicious
ceae1dc04f chore: bump version to v5.4.1 2024-11-15 11:21:24 +11:00
psychedelicious
4b390906bc fix(ui): multiple selection dnd sometimes doesn't get full selection
Turns out a gallery image's `imageDTO` object can actually be a different object by reference. I thought this was not possible thanks to how we have a quasi-normalized cache.

Need to check against image name instead of reference equality when deciding whether or not to use the single image or the gallery selection for the dnd payload.
2024-11-15 11:21:03 +11:00
psychedelicious
c5b8efe03b fix(ui): unable to use text inputs within draggable 2024-11-15 10:25:30 +11:00
psychedelicious
4d08d00ad8 chore(ui): knip 2024-11-14 13:38:40 -08:00
psychedelicious
9b0130262b fix(ui): use silent upload for single-image upload buttons 2024-11-14 13:38:40 -08:00
psychedelicious
878093f64e fix(ui): image uploading handling
Rework uploadImage and uploadImages helpers and the RTK listener, ensuring gallery view isn't changed unexpectedly and preventing extraneous toasts.

Fix staging area save to gallery button to essentially make a copy of the image, instead of changing its intermediate status.
2024-11-14 13:38:40 -08:00
psychedelicious
d5ff7ef250 feat(ui): update output only masked regions
- New name: "Output only Generated Regions"
- New default: true (this was the intention, but at some point the behaviour of the setting was inverted without the default being changed)
2024-11-14 13:35:55 -08:00
psychedelicious
f36583f866 feat(ui): tweak image selection/hover styling
The styling in gallery for selected vs hovered was very similar, leading users to think that the hovered image was also selected.

Reducing the borders for hovered images to a single pixel makes it easier to distinguish between selected and hovered.
2024-11-14 16:28:53 -05:00
psychedelicious
829bc1bc7d feat(ui): progress alert config setting
- Add `invocationProgressAlert` as a disable-able feature. Hide the alert and the setting in system settings when disabled.
- Fix merge conflict
2024-11-15 05:49:05 +11:00
Mary Hipp
17c7b57145 (ui): make detailed progress view a setting that can be hidden 2024-11-15 05:49:05 +11:00
psychedelicious
6a12189542 feat(ui): updated progress event display
- Tweak layout/styling of alerts for consistent spacing
- Add percentage to message if it has percentage
- Only show events if the destination is canvas (so workflows events are hidden for example)
2024-11-15 05:49:05 +11:00
psychedelicious
96a31a5563 feat(app): add more events when loading/running models 2024-11-15 05:49:05 +11:00
psychedelicious
067747eca9 feat(app): tweak model load events
- Pass in the `UtilInterface` to the `ModelsInterface` so we can call the simple `signal_progress` method instead of the complicated `emit_invocation_progress` method.
- Only emit load events when starting to load - not after.
- Add more detail to the messages, like submodel type
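
The idea behind the first bullet can be sketched roughly as follows. This is an illustrative stand-in, not the project's code: `FakeEvents`, `UtilSketch`, `ModelsSketch` and the argument list of `emit_invocation_progress` are all assumptions made for the example. The point is only that the util object binds the queue item and invocation once, so callers pass just a message.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeEvents:
    """Hypothetical events service that records what would be emitted."""

    emitted: list = field(default_factory=list)

    def emit_invocation_progress(
        self, queue_item: str, invocation: str, message: str, percentage: Optional[float] = None
    ) -> None:
        # Assumed signature: the "complicated" call needs the full context.
        self.emitted.append((queue_item, invocation, message, percentage))


class UtilSketch:
    """Hypothetical util interface: holds the context, exposes a simple call."""

    def __init__(self, events: FakeEvents, queue_item: str, invocation: str) -> None:
        self._events = events
        self._queue_item = queue_item
        self._invocation = invocation

    def signal_progress(self, message: str, percentage: Optional[float] = None) -> None:
        self._events.emit_invocation_progress(self._queue_item, self._invocation, message, percentage)


class ModelsSketch:
    """Hypothetical models interface that receives the util interface."""

    def __init__(self, util: UtilSketch) -> None:
        self._util = util

    def load(self, model_name: str, submodel_type: Optional[str] = None) -> str:
        # Emit only when starting to load, and include the submodel type if known.
        detail = f"{model_name}:{submodel_type}" if submodel_type else model_name
        self._util.signal_progress(f"Loading model {detail}")
        return f"<loaded {detail}>"
```

With this shape, a caller only writes `models.load("sd3", "vae")` and the progress event carries the queue context without the caller passing it.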
2024-11-15 05:49:05 +11:00
Mary Hipp
c7878fddc6 (pytest) mock emit_invocation_progress on events service 2024-11-15 05:49:05 +11:00
maryhipp
54c51e0a06 (worker) add progress images for downloading remote models 2024-11-15 05:49:05 +11:00
Mary Hipp
1640ea0298 (pytest) add missing arg for mocked context 2024-11-15 05:49:05 +11:00
Mary Hipp
0c32ae9775 (pytest) fix import 2024-11-15 05:49:05 +11:00
maryhipp
fdb8ca5165 (worker) use source if name is not available 2024-11-15 05:49:05 +11:00
Mary Hipp
571faf6d7c (pytest) add queue_item and invocation to data in context for test 2024-11-15 05:49:05 +11:00
Mary Hipp
bdbdb22b74 (ui) add Canvas Alert for invocation progress messages 2024-11-15 05:49:05 +11:00
maryhipp
9bbb5644af (worker) add invocation_progress events to model loading 2024-11-15 05:49:05 +11:00
Mary Hipp
e90ad19f22 (ui): update en string for full IP adapter 2024-11-14 10:07:42 -08:00
Ryan Dick
0ba11e8f73 SD3 Image-to-Image and Inpainting (#7295)
## Summary

Add support for SD3 image-to-image and inpainting. Similar to FLUX, the
implementation supports fractional denoise_start/denoise_end for more
fine-grained denoise strength control, and a gradient mask adjustment
schedule for smoother inpainting seams.
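
A minimal sketch of what fractional `denoise_start`/`denoise_end` can mean for a timestep schedule. It is not the project's implementation: the function name `clip_schedule_fractional` is hypothetical, and it assumes timesteps run from 1.0 (pure noise) down to 0.0 (clean image). The schedule is cut at the exact fractional values rather than snapped to the nearest existing step.

```python
def clip_schedule_fractional(
    timesteps: list[float], denoising_start: float, denoising_end: float
) -> list[float]:
    """Clip a descending [1.0 .. 0.0] schedule to a fractional denoise range.

    Hypothetical helper for illustration only.
    """
    if not 0.0 <= denoising_start <= denoising_end <= 1.0:
        raise ValueError("expected 0 <= denoising_start <= denoising_end <= 1")

    # Assumption: timestep value t corresponds to denoising progress 1 - t.
    t_start = 1.0 - denoising_start
    t_end = 1.0 - denoising_end

    # Keep the original steps strictly inside the range, then pin the
    # endpoints to the exact fractional values.
    inner = [t for t in timesteps if t_end < t < t_start]
    return [t_start, *inner, t_end]
```

For example, with five evenly spaced steps and `denoising_start=0.375`, the first step starts at 0.625 instead of being rounded to 0.75 or 0.5, which is what gives finer control over denoise strength.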

## Example
Workflow
<img width="1016" alt="image"
src="https://github.com/user-attachments/assets/ee598d77-be80-4ca7-9355-c3cbefa2ef43">

Result

![image](https://github.com/user-attachments/assets/43953fa7-0e4e-42b5-84e8-85cfeeeee00b)

## QA Instructions

- [x] Regression test of text-to-image
- [x] Test image-to-image without mask
- [x] Test that adjusting denoising_start allows fine-grained control of
amount of change in image-to-image
- [x] Test inpainting with mask
- [x] Smoke test SD1, SDXL, FLUX image-to-image to make sure there was
no regression with the frontend changes.

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-14 09:33:51 -08:00
Ryan Dick
1cf7600f5b Merge branch 'main' into ryan/sd3-image-to-image 2024-11-14 09:25:23 -08:00
Ryan Dick
4f9d12b872 Fix FLUX diffusers LoRA models with no .proj_mlp layers (#7313)
## Summary

Add support for FLUX diffusers LoRA models without `.proj_mlp` layers.

## Related Issues / Discussions

Closes #7129 

## QA Instructions

- [x] FLUX diffusers LoRA **without .proj_mlp** layers
- [x] FLUX diffusers LoRA **with .proj_mlp** layers
- [x] FLUX diffusers LoRA **without .proj_mlp** layers, quantized base
model
- [x] FLUX diffusers LoRA **with .proj_mlp** layers, quantized base
model

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-14 09:09:10 -08:00
Ryan Dick
68c3b0649b Add unit tests for FLUX diffusers LoRA without .proj_mlp layers. 2024-11-14 16:53:49 +00:00
Ryan Dick
8ef8bd4261 Add state dict tensor shapes for existing LoRA unit tests. 2024-11-14 16:53:49 +00:00
Ryan Dick
50897ba066 Add flag to optionally allow missing layer keys in FLUX lora loader. 2024-11-14 16:53:49 +00:00
Ryan Dick
3510643870 Support FLUX LoRAs without .proj_mlp layers. 2024-11-14 16:53:49 +00:00
Ryan Dick
ca9cb1c9ef Flux Vae broke for float16, force bfloat16 or float32 where compatible (#7213)
## Summary

The Flux VAE, like many VAEs, is broken when run with float16 inputs:
it returns black images due to NaNs.
This fixes the issue by forcing the VAE to run in bfloat16 or float32
where compatible.
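
One way to read the fix, as a hedged sketch rather than the actual change: if float16 is requested for the VAE, substitute bfloat16 when the device supports it and float32 otherwise. The function name, the string dtype names and the `bf16_supported` flag are assumptions for illustration; real code would work with torch dtypes and a device capability check.

```python
def choose_vae_dtype(requested: str, bf16_supported: bool) -> str:
    """Pick a dtype for the VAE that avoids float16 (hypothetical helper)."""
    if requested != "float16":
        # bfloat16 and float32 are left untouched.
        return requested
    # float16 produces NaNs (black images) here, so fall back.
    return "bfloat16" if bf16_supported else "float32"
```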

## Related Issues / Discussions

Fix for issue https://github.com/invoke-ai/InvokeAI/issues/7208

## QA Instructions

Tested on macOS: the VAE works both with float16 set in invoke.yaml and
with the setting left at its default.
I also briefly forced it down the float32 route to check that too.
Needs testing on CUDA / ROCm.

## Merge Plan

It should be a straightforward merge.
2024-11-13 15:51:40 -08:00
Ryan Dick
b89caa02bd Merge branch 'main' into flux_vae_fp16_broke 2024-11-13 15:33:43 -08:00
Ryan Dick
eaf4e08c44 Use vae.parameters() for more efficient access of the first model parameter. 2024-11-13 23:32:40 +00:00
Darrell
fb19621361 Updated link to flux ip adapter model 2024-11-12 08:11:40 -05:00
Mary Hipp
9179619077 actually use optimized denoising 2024-11-08 20:46:08 -05:00
Mary Hipp
13cb5f0ba2 Merge remote-tracking branch 'origin/main' into ryan/sd3-image-to-image 2024-11-08 20:29:56 -05:00
Mary Hipp
7e52fc1c17 Merge branch 'ryan/sd3-image-to-image' of https://github.com/invoke-ai/InvokeAI into ryan/sd3-image-to-image 2024-11-08 20:14:24 -05:00
Mary Hipp
7f60a4a282 (ui): update more generation settings for SD3 linear UI 2024-11-08 20:14:13 -05:00
psychedelicious
3f880496f7 feat(ui): clarify denoising strength badge text 2024-11-09 08:38:41 +11:00
Ryan Dick
f05efd3270 Fix import for getInfill. 2024-11-08 20:42:44 +00:00
psychedelicious
79eb8172b6 feat(ui): update warnings on upscaling tab based on model arch
When an unsupported model architecture is selected, show that warning only, without the extra warnings (i.e. no "missing tile controlnet" warning)

Update Invoke tooltip warnings accordingly

Closes #7239
Closes #7177
2024-11-09 07:34:03 +11:00
Ryan Dick
7732b5d478 Fix bug related to i2l nodes during graph construction of image-to-image workflows. 2024-11-08 20:15:34 +00:00
Mary Hipp
a2a1934b66 Merge branch 'ryan/sd3-image-to-image' of https://github.com/invoke-ai/InvokeAI into ryan/sd3-image-to-image 2024-11-08 13:43:19 -05:00
Mary Hipp
dff6570078 (ui) SD3 support in linear UI 2024-11-08 13:42:57 -05:00
maryhipp
04e4fb63af add SD3 generation modes for metadata validation 2024-11-08 13:13:58 -05:00
Vargol
83609d5008 Merge branch 'invoke-ai:main' into flux_vae_fp16_broke 2024-11-08 10:37:31 +00:00
David Burnett
2618ed0ae7 ruff complained 2024-11-08 10:31:53 +00:00
David Burnett
bb3cedddd5 Rework change based on comments 2024-11-08 10:27:47 +00:00
psychedelicious
5b3e1593ca fix(ui): restore missing image paste handler
Missed migrating this logic over during dnd migration.
2024-11-08 16:42:39 +11:00
psychedelicious
2d08078a7d fix(ui): fit bbox to layers math 2024-11-08 16:40:24 +11:00
psychedelicious
75acece1f1 fix(ui): excessive toasts when generating on canvas
- Add `withToast` flag to `uploadImage` util
- Skip the toast if this is not set
- Use the flag to disable toasts when canvas does internal image-uploading stuff that should be invisible to user
2024-11-08 10:30:04 +11:00
psychedelicious
a9db2ffefd fix(ui): ensure clip vision model is set correctly for FLUX IP Adapters 2024-11-08 10:02:41 +11:00
psychedelicious
cdd148b4d1 feat(ui): add toast for graph building errors 2024-11-08 10:02:41 +11:00
psychedelicious
730fabe2de feat(ui): add util to extract message from a tsafe AssertionError 2024-11-08 10:02:41 +11:00
psychedelicious
6c59790a7f chore: bump version to v5.4.1rc2 2024-11-08 10:00:20 +11:00
Ryan Dick
0e6cb91863 Update SD3 InpaintExtension with gradient adjustment to match FLUX. 2024-11-07 22:55:30 +00:00
Ryan Dick
a0fefcd43f Switch to using a custom scheduler implementation for SD3 rather than the diffusers FlowMatchEulerDiscreteScheduler. It is easier to work with and enables us to re-use the clip_timestep_schedule_fractional() utility from FLUX. 2024-11-07 22:46:52 +00:00
psychedelicious
c37251d6f7 tweak(ui): workflow linear field styling 2024-11-08 07:39:09 +11:00
psychedelicious
2854210162 fix(ui): dnd autoscroll on elements w/ custom scrollbar
Have to do a bit of finagling to get it to work and get `pragmatic-drag-and-drop` to not complain.
2024-11-08 07:39:09 +11:00
psychedelicious
5545b980af fix(ui): workflow field sorting doesn't use unique identifier for fields 2024-11-08 07:39:09 +11:00
psychedelicious
0c9434c464 chore(ui): lint 2024-11-08 07:39:09 +11:00
psychedelicious
8771de917d feat(ui): migrate fullscreen drop zone to pdnd 2024-11-08 07:39:09 +11:00
psychedelicious
122946ef4c feat(ui): DndDropOverlay supports react node for label 2024-11-08 07:39:09 +11:00
psychedelicious
2d974f670c feat(ui): restore missing upload buttons 2024-11-08 07:39:09 +11:00
psychedelicious
75f0da9c35 fix(ui): use revised uploader for CL empty state 2024-11-08 07:39:09 +11:00
psychedelicious
5df3c00e28 feat(ui): remove SerializableObject, use type-fest's JsonObject 2024-11-08 07:39:09 +11:00
psychedelicious
b049880502 fix(ui): uploads initiated from canvas 2024-11-08 07:39:09 +11:00
psychedelicious
e5293fdd1a fix(ui): match new default controlnet behaviour 2024-11-08 07:39:09 +11:00
psychedelicious
8883775762 feat(ui): rework image uploads (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
cfadb313d2 fix(ui): ts issues 2024-11-08 07:39:09 +11:00
psychedelicious
b5cadd9a1a fix(ui): scroll issue w/ boards list 2024-11-08 07:39:09 +11:00
psychedelicious
5361b6e014 refactor(ui): image actions sep of concerns 2024-11-08 07:39:09 +11:00
psychedelicious
ff346172af feat(ui): use new image actions system for image menu 2024-11-08 07:39:09 +11:00
psychedelicious
92f660018b refactor(ui): dnd actions to image actions
We don't need a "dnd" image system. We need an "image action" system. We need to execute specific flows with images from various "origins":
- internal dnd e.g. from gallery
- external dnd e.g. user drags an image file into the browser
- direct file upload e.g. user clicks an upload button
- some other internal app button e.g. a context menu

The actions are now generalized to better support these various use-cases.
2024-11-08 07:39:09 +11:00
psychedelicious
1afc2cba4e feat(ui): support different labels for external drop targets (e.g. uploads) 2024-11-08 07:39:09 +11:00
psychedelicious
ee8359242c feat(ui): more dnd cleanup and tidy 2024-11-08 07:39:09 +11:00
psychedelicious
f0c80a8d7a tidy(ui): dnd stuff 2024-11-08 07:39:09 +11:00
psychedelicious
8da9e7c1f6 fix(ui): min height for workflow image field drop target 2024-11-08 07:39:09 +11:00
psychedelicious
6d7a486e5b feat(ui): restore dnd to workflow fields 2024-11-08 07:39:09 +11:00
psychedelicious
57122c6aa3 feat(ui): layer reordering styling 2024-11-08 07:39:09 +11:00
psychedelicious
54abd8d4d1 feat(ui): dnd layer reordering (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
06283cffed feat(ui): use custom drag previews for images 2024-11-08 07:39:09 +11:00
psychedelicious
27fa0e1140 tidy(ui): more efficient dnd overlay styling 2024-11-08 07:39:09 +11:00
psychedelicious
533d48abdb feat(ui): multi-image drag preview 2024-11-08 07:39:09 +11:00
psychedelicious
6845cae4c9 tidy(ui): move new dnd impl into features/dnd 2024-11-08 07:39:09 +11:00
psychedelicious
31c9acb1fa tidy(ui): clean up old dnd stuff 2024-11-08 07:39:09 +11:00
psychedelicious
fb5e462300 tidy(ui): document & clean up dnd 2024-11-08 07:39:09 +11:00
psychedelicious
2f3abc29b1 feat(ui): better types for getData 2024-11-08 07:39:09 +11:00
psychedelicious
c5c071f285 feat(ui): better type name 2024-11-08 07:39:09 +11:00
psychedelicious
93a3ed56e7 feat(ui): simpler dnd typing implementation 2024-11-08 07:39:09 +11:00
psychedelicious
406fc58889 feat(ui): migrate to pragmatic-drag-and-drop (wip 4) 2024-11-08 07:39:09 +11:00
psychedelicious
cf67d084fd feat(ui): migrate to pragmatic-drag-and-drop (wip 3) 2024-11-08 07:39:09 +11:00
psychedelicious
d4a95af14f perf(ui): more gallery perf improvements 2024-11-08 07:39:09 +11:00
psychedelicious
8c8e7102c2 perf(ui): improved gallery perf 2024-11-08 07:39:09 +11:00
psychedelicious
b6b9ea9d70 feat(ui): migrate to pragmatic-drag-and-drop (wip 2) 2024-11-08 07:39:09 +11:00
psychedelicious
63126950bc feat(ui): migrate to pragmatic-drag-and-drop (wip) 2024-11-08 07:39:09 +11:00
psychedelicious
29d63d5dea fix(app): silence pydantic protected namespace warning
Closes #7287
2024-11-08 07:36:50 +11:00
Ryan Dick
a5f8c23dee Add inpainting support for SD3. 2024-11-07 20:21:43 +00:00
Ryan Dick
7bb4ea57c6 Add SD3ImageToLatentsInvocation. 2024-11-07 16:07:57 +00:00
Ryan Dick
75dc961bcb Add image-to-image support for SD3 - WIP. 2024-11-07 15:48:35 +00:00
Vargol
a9a1f6ef21 Merge branch 'invoke-ai:main' into flux_vae_fp16_broke 2024-11-07 14:02:51 +00:00
Jonathan
aa40161f26 Update flux_denoise.py
Added a bool to allow the node user to add noise to the initial latents (default) or to leave them alone.
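
A small sketch of the behaviour this flag controls, using plain Python lists in place of tensors. The blend formula follows the `flux_denoise.py` change in this comparison (`x = t_0 * noise + (1.0 - t_0) * init_latents`); the function name `prepare_initial_latents` is hypothetical.

```python
def prepare_initial_latents(
    init_latents: list[float], noise: list[float], t_0: float, add_noise: bool = True
) -> list[float]:
    """Return the starting latents for image-to-image (illustrative only)."""
    if not add_noise:
        # Leave the provided latents alone.
        return list(init_latents)
    # Noise the latents by the amount appropriate for the first timestep.
    return [t_0 * n + (1.0 - t_0) * x for n, x in zip(noise, init_latents)]
```

With `t_0 = 0.5`, each output value is the midpoint of the noise and the input latent; with `add_noise=False` the input passes through unchanged.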
2024-11-07 14:02:20 +00:00
psychedelicious
6efa812874 chore(ui): bump version to v5.4.1rc1 2024-11-07 14:02:20 +00:00
psychedelicious
8a683f5a3c feat(ui): updated whats new handling and v5.4.1 items 2024-11-07 14:02:20 +00:00
Brandon Rising
f4b0b6a93d fix: Look in known subfolders for configs for clip variants 2024-11-07 14:02:20 +00:00
Brandon Rising
1337c33ad3 fix: Avoid downloading unsafe .bin files if a safetensors file is available 2024-11-07 14:02:20 +00:00
Jonathan
2f6b035138 Update flux_denoise.py
Added a bool to allow the node user to add noise to the initial latents (default) or to leave them alone.
2024-11-07 08:44:10 -05:00
psychedelicious
4f9ae44472 chore(ui): bump version to v5.4.1rc1 2024-11-07 12:19:28 +11:00
psychedelicious
c682330852 feat(ui): updated whats new handling and v5.4.1 items 2024-11-07 12:19:28 +11:00
Brandon Rising
c064257759 fix: Look in known subfolders for configs for clip variants 2024-11-07 12:01:02 +11:00
Brandon Rising
8a4c629576 fix: Avoid downloading unsafe .bin files if a safetensors file is available 2024-11-06 19:31:18 -05:00
David Burnett
496b02a3bc Same issue affects image2image, so do the same again 2024-11-06 17:47:22 -05:00
David Burnett
7b5efc2203 Flux Vae broke for float16, force bfloat16 or float32 where compatible 2024-11-06 17:47:22 -05:00
241 changed files with 8889 additions and 6810 deletions

View File

@@ -751,7 +751,7 @@ async def convert_model(
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
-converted_model = loader.load_model(model_config, queue_id="default")
+converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")

View File

@@ -31,7 +31,6 @@ from invokeai.app.services.events.events_common import (
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
-ModelLoadEventBase,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
@@ -54,13 +53,6 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
-class ModelLoadSubscriptionEvent(BaseModel):
-"""Event data for subscribing to the socket.io model loading room.
-This is a pydantic model to ensure the data is in the correct format."""
-queue_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationProgressEvent,
@@ -77,6 +69,8 @@ MODEL_EVENTS = {
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
+ModelLoadStartedEvent,
+ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
@@ -85,11 +79,6 @@ MODEL_EVENTS = {
ModelInstallErrorEvent,
}
-MODEL_LOAD_EVENTS = {
-ModelLoadStartedEvent,
-ModelLoadCompleteEvent,
-}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
@@ -112,7 +101,6 @@ class SocketIO:
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
-register_events(MODEL_LOAD_EVENTS, self._handle_model_load_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
@@ -127,18 +115,9 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
-async def _handle_sub_model_load(self, sid: str, data: Any) -> None:
-await self._sio.enter_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
-async def _handle_unsub_model_load(self, sid: str, data: Any) -> None:
-await self._sio.leave_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
-async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None:
-await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))

View File

@@ -63,12 +63,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
-tokenizer_info = context.models.load(self.clip.tokenizer, queue_id=context.util.get_queue_id())
-text_encoder_info = context.models.load(self.clip.text_encoder, queue_id=context.util.get_queue_id())
+tokenizer_info = context.models.load(self.clip.tokenizer)
+text_encoder_info = context.models.load(self.clip.text_encoder)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
-lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
+lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@@ -95,6 +95,7 @@ class CompelInvocation(BaseInvocation):
ti_manager,
),
):
+context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
@@ -137,8 +138,8 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-tokenizer_info = context.models.load(clip_field.tokenizer, queue_id=context.util.get_queue_id())
-text_encoder_info = context.models.load(clip_field.text_encoder, queue_id=context.util.get_queue_id())
+tokenizer_info = context.models.load(clip_field.tokenizer)
+text_encoder_info = context.models.load(clip_field.text_encoder)
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -163,7 +164,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
-lora_info = context.models.load(lora.lora, context.util.get_queue_id())
+lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@@ -191,6 +192,7 @@ class SDXLPromptInvocationBase:
ti_manager,
),
):
+context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)

View File

@@ -649,9 +649,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
-source=DEPTH_ANYTHING_MODELS[self.model_size],
-queue_id=self._context.util.get_queue_id(),
-loader=load_depth_anything,
+source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)

View File

@@ -60,11 +60,12 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
-vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
+vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
+context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)

View File

@@ -124,13 +124,14 @@ class CreateGradientMaskInvocation(BaseInvocation):
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
-vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
+vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
+context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)

View File

@@ -88,7 +88,7 @@ def get_scheduler(
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
+orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -435,9 +435,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
-control_model = exit_stack.enter_context(
-context.models.load(control_info.control_model, context.util.get_queue_id())
-)
+control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
@@ -494,9 +492,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
-model = exit_stack.enter_context(
-context.models.load(control_info.control_model, context.util.get_queue_id())
-)
+model = exit_stack.enter_context(context.models.load(control_info.control_model))
ext_manager.add_extension(
ControlNetExt(
model=model,
@@ -549,13 +545,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
-with context.models.load(
-single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
-) as ip_adapter_model:
+with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
-image_encoder_model_info = context.models.load(
-single_ip_adapter.image_encoder_model, context.util.get_queue_id()
-)
+image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -589,9 +581,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
-ip_adapter_model = exit_stack.enter_context(
-context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
-)
+ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -631,9 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-t2i_adapter_loaded_model = context.models.load(
-t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
-)
+t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -938,7 +926,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
+unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
@@ -1001,13 +989,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
-lora_info = context.models.load(lora.lora, context.util.get_queue_id())
+lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
-unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
+unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,

View File

@@ -35,9 +35,7 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithB
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
image = context.images.get_pil(self.image.image_name, "RGB")
-loaded_model = context.models.load_remote_model(
-model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
-)
+loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)

View File

@@ -29,10 +29,10 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
loaded_session_det = context.models.load_local_model(
-onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
+onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
-onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
+onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:

View File

@@ -56,7 +56,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
-version="3.2.0",
+version="3.2.1",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -81,6 +81,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
@@ -183,7 +184,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
-transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
+transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
@@ -207,9 +208,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"to be poor. Consider using a FLUX dev model instead."
)
-# Noise the orig_latents by the appropriate amount for the first timestep.
-t_0 = timesteps[0]
-x = t_0 * noise + (1.0 - t_0) * init_latents
+if self.add_noise:
+# Noise the orig_latents by the appropriate amount for the first timestep.
+t_0 = timesteps[0]
+x = t_0 * noise + (1.0 - t_0) * init_latents
+else:
+x = init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
@@ -468,9 +472,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
-controlnet_models = [
-context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
-]
+controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -481,7 +483,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.controlnet_vae.vae)
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
@@ -592,9 +594,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
) as image_encoder_model:
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -624,9 +624,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
)
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
@@ -655,7 +653,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
@@ -71,14 +71,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
@@ -111,6 +112,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
@@ -118,7 +120,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
img = vae.decode(latents)
img = img.clamp(-1, 1)
@@ -52,7 +53,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()
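
Review note: the change above casts the latents to the dtype of the loaded VAE's own parameters rather than to a globally chosen dtype. A minimal sketch of that pattern follows, assuming PyTorch is installed. `cast_to_model_dtype` is a hypothetical helper, and a small `torch.nn.Linear` stands in for the VAE purely for illustration.

```python
import torch


def cast_to_model_dtype(model: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    """Cast `tensor` to the dtype of the model's first parameter."""
    model_dtype = next(iter(model.parameters())).dtype
    return tensor.to(dtype=model_dtype)


# Stand-in model held in float64, with an input created in float32.
model = torch.nn.Linear(4, 2).to(dtype=torch.float64)
x = torch.randn(1, 4, dtype=torch.float32)

# Without the cast, calling model(x) would fail with a dtype mismatch.
x_cast = cast_to_model_dtype(model, x)
y = model(x_cast)
```

Reading the dtype from the parameters keeps the input consistent with however the model was actually loaded, which may differ from the default precision chosen for the device.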

@@ -44,9 +44,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@@ -54,12 +53,13 @@ class FluxVaeEncodeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

@@ -94,9 +94,7 @@ class GroundingDinoInvocation(BaseInvocation):
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=GroundingDinoInvocation._load_grounding_dino,
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

@@ -22,9 +22,7 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
)
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, ControlNetHED_Apache2)

@@ -111,12 +111,13 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)

@@ -36,7 +36,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
@abstractmethod
def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
def infill(self, image: Image.Image) -> Image.Image:
"""Infill the image with the specified method"""
pass
@@ -56,7 +56,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(context.images.get_dto(self.image.image_name))
# Perform Infill action
infilled_image = self.infill(input_image, context.util.get_queue_id())
infilled_image = self.infill(input_image)
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
@@ -74,7 +74,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
description="The color to use to infill",
)
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
@@ -93,7 +93,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
description="The seed to use for tile generation (omit for random)",
)
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@@ -107,7 +107,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width / self.downscale)
@@ -131,10 +131,9 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
queue_id=queue_id,
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
@@ -145,7 +144,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
return cv2_inpaint(image)
@@ -167,5 +166,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
description="The max threshold for color",
)
def infill(self, image: Image.Image, queue_id: str):
def infill(self, image: Image.Image):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())

@@ -57,9 +57,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

@@ -23,9 +23,7 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartEdgeDetector.get_model_url(self.coarse)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
)
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, Generator)

@@ -20,9 +20,7 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoar
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartAnimeEdgeDetector.get_model_url()
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
)
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
with loaded_model as model:
assert isinstance(model, UnetGenerator)

@@ -147,6 +147,10 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
]

@@ -28,9 +28,7 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
)
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
with loaded_model as model:
assert isinstance(model, MobileV2_MLSD_Large)

@@ -20,9 +20,7 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
)
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
with loaded_model as model:
assert isinstance(model, NNET)

@@ -22,9 +22,7 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
)
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
with loaded_model as model:
assert isinstance(model, PiDiNet)

@@ -1,16 +1,19 @@
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
SD3ConditioningField,
WithBoard,
WithMetadata,
@@ -19,7 +22,9 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -30,16 +35,24 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -61,6 +74,41 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
@@ -147,7 +195,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
transformer_info = context.models.load(self.transformer.transformer)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
@@ -170,14 +218,20 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the timestep schedule.
# We add an extra step to the end to account for the final timestep of 0.0.
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
total_steps = len(timesteps) - 1
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
cfg_scale = self._prepare_cfg_scale(total_steps)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
@@ -191,9 +245,34 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
latents = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return latents
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
step_callback = self._build_step_callback(context)
step_callback(
@@ -210,11 +289,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
@@ -232,21 +312,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
latents = latents.to(dtype=torch.float32)
latents = latents + (t_prev - t_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
timestep=int(t_curr),
latents=latents,
),
)
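
Review note: the rewritten SD3 denoising loop replaces `scheduler.step(...)` with an explicit Euler update over consecutive timestep pairs. The sketch below reproduces that update on a single scalar "latent" in pure Python so the step arithmetic can be checked by hand. `linear_schedule`, `euler_denoise`, and the toy `velocity` function are hypothetical names for illustration, and the list comprehension is a stand-in for `torch.linspace(1, 0, steps + 1).tolist()`. The clipping done by `clip_timestep_schedule_fractional` is not reproduced here.

```python
from typing import Callable


def linear_schedule(steps: int) -> list[float]:
    """Timesteps from 1.0 down to 0.0, with one extra entry for the final 0.0."""
    return [1.0 - i / steps for i in range(steps + 1)]


def euler_denoise(
    latent: float,
    timesteps: list[float],
    predict: Callable[[float, float], float],
) -> float:
    """Apply `latent + (t_prev - t_curr) * prediction` for each consecutive pair."""
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        latent = latent + (t_prev - t_curr) * predict(latent, t_curr)
    return latent


# Toy setup: x_t = t * noise + (1 - t) * x_0, so dx/dt = noise - x_0 is constant.
noise, x_0 = 3.0, -1.0


def velocity(latent: float, t: float) -> float:
    return noise - x_0


timesteps = linear_schedule(4)
total_steps = len(timesteps) - 1  # one fewer step than schedule entries
result = euler_denoise(noise, timesteps, velocity)
```

Because the toy prediction is constant, the Euler steps land exactly on `x_0`; a real model's prediction varies per step, so this only checks the update rule and the `len(timesteps) - 1` step count.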

@@ -0,0 +1,65 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"sd3_i2l",
title="SD3 Image to Latents",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""
image: ImageField = InputField(description="The image to encode")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
latents = vae.config.scaling_factor * latents
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

@@ -44,9 +44,10 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)

@@ -86,8 +86,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
@@ -95,6 +95,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
@@ -127,8 +128,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
@@ -137,6 +138,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
@@ -193,7 +195,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

@@ -125,9 +125,7 @@ class SegmentAnythingInvocation(BaseInvocation):
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=SegmentAnythingInvocation._load_sam_model,
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)

@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
spandrel_model_info = context.models.load(self.image_to_image_model)
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
spandrel_model_info = context.models.load(self.image_to_image_model)
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.

@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
unet_info = context.models.load(self.unet.unet)
with (
ExitStack() as exit_stack,

@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(msg)
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
source=ESRGAN_MODEL_URLS[self.model_name],
)
with loadnet as loadnet_model:

@@ -131,17 +131,15 @@ class EventServiceBase:
# region Model loading
def emit_model_load_started(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
# endregion

View File

@@ -383,14 +383,12 @@ class DownloadErrorEvent(DownloadEventBase):
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelLoadEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelLoadEventBase):
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
@@ -399,14 +397,12 @@ class ModelLoadStartedEvent(ModelLoadEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadStartedEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelLoadEventBase):
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
@@ -415,14 +411,8 @@ class ModelLoadCompleteEvent(ModelLoadEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadCompleteEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
class ModelEventBase(EventBase):
"""Base class for model events"""
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register

View File

@@ -14,9 +14,7 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -31,7 +29,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.

View File

@@ -49,9 +49,7 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -62,7 +60,7 @@ class ModelLoadService(ModelLoadServiceBase):
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, queue_id, submodel_type)
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -72,12 +70,12 @@ class ModelLoadService(ModelLoadServiceBase):
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, queue_id, submodel_type)
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model
def load_model_from_path(
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache

View File

@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def save(
self,
image: Image,
@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
The saved image DTO.
"""
self._util.signal_progress("Saving image")
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.
@@ -351,10 +361,7 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.store.exists(identifier.key)
def load(
self,
identifier: Union[str, "ModelIdentifierField"],
queue_id: str,
submodel_type: Optional[SubModelType] = None,
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Load a model.
@@ -371,19 +378,18 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, queue_id, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)
def load_by_attrs(
self,
name: str,
base: BaseModelType,
type: ModelType,
queue_id: str,
submodel_type: Optional[SubModelType] = None,
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Load a model by its attributes.
@@ -405,7 +411,11 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], queue_id, submodel_type)
message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config.
@@ -475,12 +485,12 @@ class ModelsInterface(InvocationContextInterface):
Returns:
Path to the downloaded model
"""
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
model_path: Path,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -498,14 +508,13 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
source: str | AnyHttpUrl,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -526,9 +535,9 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
class ConfigInterface(InvocationContextInterface):
@@ -549,14 +558,6 @@ class UtilInterface(InvocationContextInterface):
super().__init__(services, data)
self._is_canceled = is_canceled
def get_queue_id(self) -> str:
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._data.queue_item.queue_id
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
@@ -729,12 +730,12 @@ def build_invocation_context(
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)
ctx = InvocationContext(
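The hunks above move progress reporting into the context interfaces: `ModelsInterface` receives the `UtilInterface` at construction time and calls `signal_progress` before each load, so callers no longer pass a `queue_id`. A minimal sketch of that wiring follows. The `UtilStub` and `ModelsStub` classes and the `build_load_message` helper are hypothetical stand-ins written for illustration, not the InvokeAI classes; only the message format is taken from the diff.

```python
from enum import Enum
from typing import List, Optional


class SubModelType(str, Enum):
    """Hypothetical stand-in for the real SubModelType enum."""

    UNet = "unet"
    VAE = "vae"


def build_load_message(model_name: str, submodel_type: Optional[SubModelType] = None) -> str:
    """Build the progress message in the same format the diff uses."""
    message = f"Loading model {model_name}"
    if submodel_type:
        message += f" ({submodel_type.value})"
    return message


class UtilStub:
    """Hypothetical stand-in that records progress messages instead of emitting events."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def signal_progress(self, message: str) -> None:
        self.messages.append(message)


class ModelsStub:
    """Hypothetical stand-in: the util object is injected, so it must be built first."""

    def __init__(self, util: UtilStub) -> None:
        self._util = util

    def load(self, model_name: str, submodel_type: Optional[SubModelType] = None) -> str:
        # Report progress before the (omitted) load, as the diff does.
        self._util.signal_progress(build_load_message(model_name, submodel_type))
        return model_name


# The util object has to exist before the interfaces that depend on it.
util = UtilStub()
models = ModelsStub(util=util)
models.load("my-model", SubModelType.UNet)
```

This ordering constraint is presumably why `build_invocation_context` now constructs `util` before `models` and `images`.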

View File

@@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(name_or_key, queue_id=context.util.get_queue_id())
loaded_model = context.models.load(name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
@@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
name=name_or_key, base=base, type=ModelType.TextualInversion, queue_id=context.util.get_queue_id()
name=name_or_key, base=base, type=ModelType.TextualInversion
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)

View File

@@ -45,8 +45,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)
layers: dict[str, AnyLoRALayer] = {}
@@ -62,30 +63,43 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
def add_qkv_lora_layer_if_present(
src_keys: list[str],
src_weight_shapes: list[tuple[int, int]],
dst_qkv_key: str,
allow_missing_keys: bool = False,
) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
keys_present = [key in grouped_state_dict for key in src_keys]
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayer] = []
for src_layer_dict in src_layer_dicts:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -118,6 +132,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
@@ -126,6 +141,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.txt_attn.qkv",
)
@@ -175,7 +191,14 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
[
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(mlp_hidden_dim, hidden_size),
],
f"single_blocks.{i}.linear1",
allow_missing_keys=True,
)
# Output projections.
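The shape bookkeeping in this file's hunks can be checked with plain arithmetic. The sketch below uses only the constants and shape tuples that appear in the diff; the helper names `placeholder_shapes` and `concatenated_out_features` are hypothetical. It assumes the sub-layers are concatenated along the output dimension, which is what the removed `concat_axis=0` argument suggests.

```python
from typing import Dict, List, Tuple

# Constants for FLUX.1, as given in the diff.
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)

# (out_features, in_features) for to_q, to_k, to_v and proj_mlp of a single block.
src_weight_shapes: List[Tuple[int, int]] = [
    (hidden_size, hidden_size),
    (hidden_size, hidden_size),
    (hidden_size, hidden_size),
    (mlp_hidden_dim, hidden_size),
]


def placeholder_shapes(weight_shape: Tuple[int, int]) -> Dict[str, Tuple[int, int]]:
    """Shapes of the rank-1 zero tensors used when a source key is missing."""
    out_features, in_features = weight_shape
    return {
        "lora_up.weight": (out_features, 1),
        "lora_down.weight": (1, in_features),
    }


def concatenated_out_features(shapes: List[Tuple[int, int]]) -> int:
    """Output width of the fused layer, assuming concatenation along the output axis."""
    return sum(out_features for out_features, _ in shapes)


fused_width = concatenated_out_features(src_weight_shapes)
missing_mlp = placeholder_shapes(src_weight_shapes[3])
```

Because the placeholder up and down matrices are all zeros, their product contributes nothing to the patched weight, while still occupying the correct slice of the fused layer.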

View File

@@ -165,6 +165,8 @@ class SubmodelDefinition(BaseModel):
model_type: ModelType
variant: AnyVariant = None
model_config = ConfigDict(protected_namespaces=())
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")

View File

@@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
self._torch_device = TorchDevice.choose_torch_device()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
return model
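The dtype selection above follows a probe-and-fall-back pattern: keep the default dtype unless it is float16, otherwise try bfloat16 on the target device and drop to float32 if that raises `TypeError`. The sketch below isolates that decision without importing torch. Dtypes are represented as plain strings and the probe is an injected callable, both of which are assumptions made so the example runs anywhere; in the diff the probe is a `torch.tensor(...)` call on `self._torch_device`.

```python
from typing import Callable


def choose_vae_dtype(default_dtype: str, probe_bfloat16: Callable[[], None]) -> str:
    """Return the dtype name to use for the VAE.

    `probe_bfloat16` is a hypothetical hook that raises TypeError when the
    device cannot create bfloat16 tensors.
    """
    if default_dtype != "float16":
        # Any other default dtype is used unchanged.
        return default_dtype
    try:
        probe_bfloat16()
    except TypeError:
        # The device cannot use bfloat16, so fall back to full precision.
        return "float32"
    return "bfloat16"


def unsupported_device_probe() -> None:
    """Simulates a device without bfloat16 support."""
    raise TypeError("bfloat16 is not supported on this device")


def supported_device_probe() -> None:
    """Simulates a device where the bfloat16 probe succeeds."""
    return None
```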

View File

@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],

View File

@@ -172,6 +172,8 @@ def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
try:
path = Path(location)
config_path = path / "config.json"
if not config_path.exists():
config_path = path / "text_encoder" / "config.json"
if not config_path.exists():
return ClipVariantType.L
with open(config_path) as file:
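The added lines make the config lookup try a second location before giving up. A minimal sketch of that lookup order is shown here; `find_clip_config` is a hypothetical helper name, and returning `None` stands in for the point where the diff returns `ClipVariantType.L`.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_clip_config(location: str) -> Optional[Path]:
    """Return the first existing config.json, checking the model root and then text_encoder/."""
    path = Path(location)
    candidates = (path / "config.json", path / "text_encoder" / "config.json")
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


# Demonstrate the fallback with a temporary diffusers-style layout.
with tempfile.TemporaryDirectory() as tmp:
    nested = Path(tmp) / "text_encoder"
    nested.mkdir()
    (nested / "config.json").write_text("{}")
    found = find_clip_config(tmp)
    found_relative = found.relative_to(tmp).as_posix() if found else None

# An empty directory has neither file, so the lookup yields None.
with tempfile.TemporaryDirectory() as empty:
    missing = find_clip_config(empty)
```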

View File

@@ -85,6 +85,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result: set[Path] = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
safetensors_detected = False
for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
@@ -119,10 +120,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
if safetensors_detected and path.suffix == ".bin":
continue
parent = path.parent
score = 0
if path.suffix == ".safetensors":
safetensors_detected = True
if parent in subfolder_weights:
subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
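The intent of this hunk is that `.bin` weights are ignored once a `.safetensors` file has been seen. The sketch below is a simplified, order-independent version of that preference, not the logic in the diff: the diff works in a single pass and only purges earlier `.bin` candidates from the same subfolder, whereas this hypothetical helper drops every `.bin` file whenever any `.safetensors` file is present.

```python
from pathlib import Path
from typing import List


def drop_bin_when_safetensors_present(files: List[Path]) -> List[Path]:
    """Remove .bin files from the list if at least one .safetensors file exists."""
    has_safetensors = any(path.suffix == ".safetensors" for path in files)
    if not has_safetensors:
        # Nothing better is available, so keep the list as-is.
        return list(files)
    return [path for path in files if path.suffix != ".bin"]


mixed = [
    Path("unet/model.bin"),
    Path("unet/model.safetensors"),
    Path("vae/model.bin"),
]
only_bin = [Path("unet/model.bin"), Path("vae/model.bin")]

filtered_mixed = drop_bin_when_safetensors_present(mixed)
filtered_only_bin = drop_bin_when_safetensors_present(only_bin)
```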

View File

@@ -0,0 +1,58 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with SD3."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background.
noise (torch.Tensor): The noise tensor used to noise the init_latents.
"""
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
# threshold worked well.
# We use a small epsilon to avoid any potential issues with floating point precision.
eps = 1e-4
mask_gradient_t_cutoff = 0.5
if t_prev > mask_gradient_t_cutoff:
# Early in the denoising process, use the inpaint mask as-is.
return self._inpaint_mask
else:
# After the cut-off, promote all non-zero mask values to 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)
return mask
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, t_prev: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
mask = self._apply_mask_gradient_adjustment(t_prev)
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
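To make the merge step concrete, here is the same arithmetic for a single latent element, written with plain floats instead of tensors. The function names `effective_mask` and `merge_element` are hypothetical; the cutoff, epsilon and formulas are copied from the class above.

```python
def effective_mask(mask_value: float, t_prev: float, cutoff: float = 0.5, eps: float = 1e-4) -> float:
    """Mask value to use at this timestep for one element."""
    if t_prev > cutoff:
        # Early in denoising, use the mask as-is.
        return mask_value
    # After the cutoff, promote every non-zero mask value to 1.0.
    return mask_value if mask_value <= eps else 1.0


def merge_element(intermediate: float, init: float, noise: float, mask_value: float, t_prev: float) -> float:
    """Blend one intermediate latent element with the re-noised initial latent."""
    mask = effective_mask(mask_value, t_prev)
    # Noise the initial latent to the current timestep.
    noised_init = noise * t_prev + (1.0 - t_prev) * init
    return intermediate * mask + noised_init * (1.0 - mask)


# Unmasked element: the result follows the noised initial latent (0.75 here).
kept = merge_element(intermediate=2.0, init=1.0, noise=0.0, mask_value=0.0, t_prev=0.25)

# Gradient region before the cutoff: a 50/50 blend (1.125 here).
blended = merge_element(intermediate=2.0, init=1.0, noise=0.0, mask_value=0.5, t_prev=0.75)

# Gradient region after the cutoff: promoted to fully regenerated (2.0 here).
promoted = merge_element(intermediate=2.0, init=1.0, noise=0.0, mask_value=0.5, t_prev=0.25)
```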

View File

@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,

View File

@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
@callback(ExtensionCallbackType.SETUP)
def setup(self, ctx: DenoiseContext):
t2i_model: T2IAdapter
with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
with self._node_context.models.load(self._model_id) as t2i_model:
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
self._adapter_state = self._run_model(

View File

@@ -52,11 +52,11 @@
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.1.0",
"@invoke-ai/ui-library": "^0.0.43",
"@nanostores/react": "^0.7.3",

View File

@@ -5,21 +5,21 @@ settings:
excludeLinksFromLockfile: false
dependencies:
'@atlaskit/pragmatic-drag-and-drop':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-auto-scroll':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-hitbox':
specifier: ^1.0.3
version: 1.0.3
'@dagrejs/dagre':
specifier: ^1.1.4
version: 1.1.4
'@dagrejs/graphlib':
specifier: ^2.2.4
version: 2.2.4
'@dnd-kit/core':
specifier: ^6.1.0
version: 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/sortable':
specifier: ^8.0.0
version: 8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1)
'@dnd-kit/utilities':
specifier: ^3.2.2
version: 3.2.2(react@18.3.1)
'@fontsource-variable/inter':
specifier: ^5.1.0
version: 5.1.0
@@ -319,6 +319,28 @@ packages:
'@jridgewell/trace-mapping': 0.3.25
dev: true
/@atlaskit/pragmatic-drag-and-drop-auto-scroll@1.4.0:
resolution: {integrity: sha512-5GoikoTSW13UX76F9TDeWB8x3jbbGlp/Y+3aRkHe1MOBMkrWkwNpJ42MIVhhX/6NSeaZiPumP0KbGJVs2tOWSQ==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop-hitbox@1.0.3:
resolution: {integrity: sha512-/Sbu/HqN2VGLYBhnsG7SbRNg98XKkbF6L7XDdBi+izRybfaK1FeMfodPpm/xnBHPJzwYMdkE0qtLyv6afhgMUA==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop@1.4.0:
resolution: {integrity: sha512-qRY3PTJIcxfl/QB8Gwswz+BRvlmgAC5pB+J2hL6dkIxgqAgVwOhAamMUKsrOcFU/axG2Q7RbNs1xfoLKDuhoPg==}
dependencies:
'@babel/runtime': 7.25.7
bind-event-listener: 3.0.0
raf-schd: 4.0.3
dev: false
/@babel/code-frame@7.25.7:
resolution: {integrity: sha512-0xZJFNE5XMpENsgfHYTw8FbX4kv53mFLn2i3XPoq69LyhYSCBJtitaHx9QnsVTrsogI4Z3+HtEfZ2/GFPOtf5g==}
engines: {node: '>=6.9.0'}
@@ -980,49 +1002,6 @@ packages:
engines: {node: '>17.0.0'}
dev: false
/@dnd-kit/accessibility@3.1.0(react@18.3.1):
resolution: {integrity: sha512-ea7IkhKvlJUv9iSHJOnxinBcoOI3ppGnnL+VDJ75O45Nss6HtZd8IdN8touXPDtASfeI2T2LImb8VOZcL47wjQ==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/core@6.1.0(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-J3cQBClB4TVxwGo3KEjssGEXNJqGVWx17aRTZ1ob0FliR5IjYgTxl5YJbKTzA6IzrtelotH19v6y7uoIRUZPSg==}
peerDependencies:
react: '>=16.8.0'
react-dom: '>=16.8.0'
dependencies:
'@dnd-kit/accessibility': 3.1.0(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
tslib: 2.7.0
dev: false
/@dnd-kit/sortable@8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1):
resolution: {integrity: sha512-U3jk5ebVXe1Lr7c2wU7SBZjcWdQP+j7peHJfCspnA81enlu88Mgd7CC8Q+pub9ubP7eKVETzJW+IBAhsqbSu/g==}
peerDependencies:
'@dnd-kit/core': ^6.1.0
react: '>=16.8.0'
dependencies:
'@dnd-kit/core': 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/utilities@3.2.2(react@18.3.1):
resolution: {integrity: sha512-+MKAJEOfaBe5SmV6t34p80MMKhjvUz0vRrvVJbPT0WElzaOJ/1xs+D+KDv+tD/NE5ujfrChEcshd4fLn0wpiqg==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@emotion/babel-plugin@11.12.0:
resolution: {integrity: sha512-y2WQb+oP8Jqvvclh8Q55gLUyb7UFvgv7eJfsj7td5TToBrIUtPay2kMrZi4xjq9qw2vD0ZR5fSho0yqoFgX7Rw==}
dependencies:
@@ -4313,6 +4292,10 @@ packages:
open: 8.4.2
dev: true
/bind-event-listener@3.0.0:
resolution: {integrity: sha512-PJvH288AWQhKs2v9zyfYdPzlPqf5bXbGMmhmUIY9x4dAUGIWgomO771oBQNwJnMQSnUIXhKu6sgzpBRXTlvb8Q==}
dev: false
/bl@4.1.0:
resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
dependencies:
@@ -7557,6 +7540,10 @@ packages:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
dev: true
/raf-schd@4.0.3:
resolution: {integrity: sha512-tQkJl2GRWh83ui2DiPTJz9wEiMN20syf+5oKfB03yYP7ioZcJwsIK8FjrtLwH1m7C7e+Tt2yYBlrOpdT+dyeIQ==}
dev: false
/raf-throttle@2.0.6:
resolution: {integrity: sha512-C7W6hy78A+vMmk5a/B6C5szjBHrUzWJkVyakjKCK59Uy2CcA7KhO1JUvvH32IXYFIcyJ3FMKP3ZzCc2/71I6Vg==}
dev: false

View File

@@ -174,7 +174,8 @@
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"none": "None",
"new": "New"
"new": "New",
"generating": "Generating"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -704,6 +705,8 @@
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@@ -997,7 +1000,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1137,6 +1140,7 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"showDetailedInvocationProgress": "Show Progress Details",
"showProgressInViewer": "Show Progress Images in Viewer",
"ui": "User Interface",
"clearIntermediatesDisabled": "Queue must be empty to clear intermediates",
@@ -1671,7 +1675,7 @@
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Masked Regions",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1787,7 +1791,7 @@
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"full": "Full",
"full": "Style and Composition",
"style": "Style Only",
"composition": "Composition Only"
},
@@ -1999,7 +2003,9 @@
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
"missingTileControlNetModel": "No valid tile ControlNet models installed",
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
},
"stylePresets": {
"active": "Active",
@@ -2102,8 +2108,10 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"line1": "<StrongComponent>Layer Merging</StrongComponent>: New <StrongComponent>Merge Down</StrongComponent> and improved <StrongComponent>Merge Visible</StrongComponent> for all layers, with special handling for Regional Guidance and Control Layers.",
"line2": "<StrongComponent>HF Token Support</StrongComponent>: Upload models that require Hugging Face authentication.",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: Support for Text-to-Image in Workflows with SD 3.5 Medium and Large.",
"<StrongComponent>Canvas</StrongComponent>: Streamlined Control Layer processing and improved default Control settings."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -8,10 +8,8 @@ import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import {
@@ -19,6 +17,7 @@ import {
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
@@ -62,8 +61,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
@@ -92,19 +89,8 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box
id="invoke-app-wrapper"
w="100dvw"
h="100dvh"
position="relative"
overflow="hidden"
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
)}
</Box>
<DeleteImageModal />
<ChangeBoardModal />
@@ -121,6 +107,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
<FullscreenDropzone />
</ErrorBoundary>
);
};

@@ -1,4 +1,3 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
@@ -8,13 +7,11 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const GlobalImageHotkeys = memo(() => {
useAssertSingleton('GlobalImageHotkeys');
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken);
const imageDTO = useAppSelector(selectLastSelectedImage);
if (!imageDTO) {
return null;

@@ -19,7 +19,6 @@ import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
@@ -237,9 +236,7 @@ const InvokeAIUI = ({
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} studioInitAction={studioInitAction} />
</AppDndContext>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
</React.Suspense>
</Provider>

@@ -17,6 +17,7 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
'config',
'dnd',
'events',
'gallery',
'generation',

@@ -1,4 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
/** @knipignore */
export const EMPTY_OBJECT = {};

@@ -16,7 +16,6 @@ import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMi
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
import { addImageDroppedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
@@ -93,9 +92,6 @@ addGetOpenAPISchemaListener(startAppListening);
addWorkflowLoadRequestedListener(startAppListening);
addUpdateAllNodesRequestedListener(startAppListening);
// DND
addImageDroppedListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);

@@ -1,12 +1,12 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
log.debug({ enqueueResult } as JsonObject, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg } as JsonObject, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -17,7 +17,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
effect: (action) => {
const enqueueResult = action.payload;
const arg = action.meta.arg.originalArgs;
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
log.debug({ enqueueResult } as JsonObject, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
@@ -45,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig } as JsonObject, t('queue.batchFailedToQueue'));
return;
}
@@ -71,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig, error: serializeError(response) } as JsonObject, t('queue.batchFailedToQueue'));
},
});
};

@@ -1,19 +1,22 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
import { assert, AssertionError } from 'tsafe';
import type { JsonObject } from 'type-fest';
const log = logger('generation');
@@ -32,8 +35,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>;
},
Error
>;
@@ -49,6 +52,9 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `sd-3`:
buildGraphResult = await withResultAsync(() => buildSD3Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;
@@ -57,7 +63,17 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
}
if (buildGraphResult.isErr()) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status: 'error',
title: 'Failed to build graph',
description,
});
return;
}
@@ -88,7 +104,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
log.debug({ batchConfig: prepareBatchResult.value } as JsonObject, 'Enqueued batch');
},
});
};
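
The hunk above attaches a toast description only when the graph builder failed with an assertion. The branching can be sketched roughly as follows. `BuildAssertionError`, `describeBuildFailure` and `toastFor` are hypothetical stand-ins for `AssertionError` from `tsafe`, the project's `extractMessageFromAssertionError` helper and the `toast` call; their real implementations are not shown in this diff.

```typescript
// Hypothetical stand-in for tsafe's AssertionError; only the class identity matters here.
class BuildAssertionError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'BuildAssertionError';
    // Keep instanceof working when compiled to older targets.
    Object.setPrototypeOf(this, BuildAssertionError.prototype);
  }
}

// Returns a user-facing description for assertion failures and null for anything else.
// Assumption: the real helper may reformat the message; this sketch passes it through.
const describeBuildFailure = (error: unknown): string | null => {
  if (error instanceof BuildAssertionError) {
    return error.message;
  }
  return null;
};

// Builds the toast payload in the same shape as the hunk above: the title is fixed,
// and the description is only filled in for assertion failures.
const toastFor = (error: unknown) => ({
  status: 'error' as const,
  title: 'Failed to build graph',
  description: describeBuildFailure(error),
});
```

Under this sketch, an assertion failure surfaces its message to the user, while any other error produces the generic title with no description.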

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { JsonObject } from 'type-fest';
const log = logger('system');
@@ -16,12 +16,12 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
effect: (action, { getState }) => {
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) } as JsonObject, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates } as JsonObject, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},

@@ -1,333 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
entityRasterized,
entitySelected,
inpaintMaskAdded,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
activeData: TypesafeDraggableData;
}>('dnd/dndDropped');
const log = logger('system');
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: dndDropped,
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug({ activeData, overData }, 'Node field dropped');
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/**
* Image dropped on IP Adapter Layer
*/
if (
overData.actionType === 'SET_IPA_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
dispatch(
referenceImageIPAdapterImageChanged({
entityIdentifier: { id, type: 'reference_image' },
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on RG Layer IP Adapter
*/
if (
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id, referenceImageId } = overData.context;
dispatch(
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
referenceImageId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on Raster layer
*/
if (
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return;
}
/**
/**
* Image dropped on Inpaint Mask
*/
if (
overData.actionType === 'ADD_INPAINT_MASK_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasInpaintMaskState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
return;
}
/**
/**
* Image dropped on Regional Guidance
*/
if (
overData.actionType === 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRegionalGuidanceState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Raster layer
*/
if (
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
controlAdapter: deepClone(initialControlNet),
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasReferenceImageState> = {
ipAdapter,
};
dispatch(referenceImageAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Raster layer
*/
if (overData.actionType === 'REPLACE_LAYER_WITH_IMAGE' && activeData.payloadType === 'IMAGE_DTO') {
const state = getState();
const { entityIdentifier } = overData.context;
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}
/**
* Image dropped on node image field
*/
if (
overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image selected for compare
*/
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
return;
}
/**
* Image dropped on user board
*/
if (
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on upscale initial image
*/
if (
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(upscaleInitialImageChanged(imageDTO));
return;
}
/**
* Multiple images dropped on user board
*/
if (overData.actionType === 'ADD_TO_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImagesToBoard.initiate({
imageDTOs,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Multiple images dropped on 'none' board
*/
if (overData.actionType === 'REMOVE_FROM_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
dispatch(
imagesApi.endpoints.removeImagesFromBoard.initiate({
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},
});
};
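
The removed listener routed each drop by pairing `overData.actionType` with `activeData.payloadType`, reading the gallery selection from state for multi-image drags. A minimal sketch of that routing with discriminated unions is below. The types are trimmed-down, hypothetical versions of `TypesafeDroppableData` and `TypesafeDraggableData`, and `routeDrop` returns a description string in place of dispatching real actions.

```typescript
// Hypothetical, simplified payload types; the real ones carry many more variants.
type ImageRef = { image_name: string };
type DragData = { payloadType: 'IMAGE_DTO'; imageDTO: ImageRef } | { payloadType: 'GALLERY_SELECTION' };
type DropData = { actionType: 'ADD_TO_BOARD'; boardId: string } | { actionType: 'REMOVE_FROM_BOARD' };

// A single-image drag carries its own image; a selection drag defers to the gallery selection.
const namesFor = (active: DragData, selection: ImageRef[]): string[] =>
  active.payloadType === 'IMAGE_DTO' ? [active.imageDTO.image_name] : selection.map((i) => i.image_name);

// Describes the action that would be dispatched for a given drop target and drag payload.
const routeDrop = (over: DropData, active: DragData, selection: ImageRef[]): string => {
  const names = namesFor(active, selection).join(', ');
  if (over.actionType === 'ADD_TO_BOARD') {
    return `add [${names}] to board ${over.boardId}`;
  }
  return `remove [${names}] from board`;
};
```

Narrowing on the discriminant fields lets the compiler check that `boardId` is only read for `ADD_TO_BOARD` drops, which the original if-chain had to guard by hand in every branch.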

@@ -1,18 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import {
entityRasterized,
entitySelected,
referenceImageIPAdapterImageChanged,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
@@ -51,12 +41,14 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
log.debug({ imageDTO }, 'Image uploaded');
const { postUploadAction } = action.meta.arg.originalArgs;
if (!postUploadAction) {
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}
const boardId = imageDTO.board_id ?? 'none';
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
@@ -64,80 +56,34 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
} as const;
// default action - just upload and alert user
if (postUploadAction.type === 'TOAST') {
const boardId = imageDTO.board_id ?? 'none';
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: postUploadAction.title || DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
return;
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
if (postUploadAction.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
...DEFAULT_UPLOADED_TOAST,
description: 'set as upscale initial image',
});
return;
}
if (postUploadAction.type === 'SET_IPA_IMAGE') {
const { id } = postUploadAction;
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, referenceImageId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
return;
}
if (postUploadAction.type === 'REPLACE_LAYER_WITH_IMAGE') {
const { entityIdentifier } = postUploadAction;
const state = getState();
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
},
});
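
After this change the upload listener makes two decisions: whether to do anything at all (silent or intermediate uploads are skipped), and whether to switch the gallery board and view (only for the first upload of a batch). A minimal sketch of those decisions as a pure function follows. `UploadArgs`, `UploadedImage` and `planPostUpload` are hypothetical names for illustration; the real listener reads the same fields from the RTK Query action and dispatches directly.

```typescript
// Hypothetical, simplified shapes of the upload arguments and the resulting image.
type UploadArgs = { silent?: boolean; isFirstUploadOfBatch?: boolean };
type UploadedImage = { image_name: string; is_intermediate: boolean; board_id?: string | null };

type PostUploadPlan = { showToast: boolean; switchToBoard: string | null };

const planPostUpload = (args: UploadArgs, image: UploadedImage): PostUploadPlan => {
  // Silent uploads and intermediate images skip all post-upload side effects.
  if (args.silent || image.is_intermediate) {
    return { showToast: false, switchToBoard: null };
  }
  // Images without a board belong to the 'none' board.
  const boardId = image.board_id ?? 'none';
  // Defaults to true so callers that upload one image need not set the flag.
  const isFirstUploadOfBatch = args.isFirstUploadOfBatch ?? true;
  return { showToast: true, switchToBoard: isFirstUploadOfBatch ? boardId : null };
};
```

Later uploads in a batch still produce a toast but leave the board selection alone, which avoids pulling the user back to the original board after they have navigated elsewhere.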

@@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import {
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
@@ -41,6 +40,7 @@ import {
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
} from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('models');
@@ -85,7 +85,7 @@ type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<SerializableObject>
log: Logger<JsonObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {

@@ -3,7 +3,6 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -37,6 +36,7 @@ import undoable from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
@@ -139,7 +139,7 @@ const unserialize: UnserializeFunction = (data, key) => {
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as SerializableObject, // this is always serializable
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);

@@ -25,7 +25,8 @@ export type AppFeature =
| 'invocationCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';
| 'hfToken'
| 'invocationProgressAlert';
/**
* A disable-able Stable Diffusion feature

@@ -1,251 +0,0 @@
import type { ChakraProps, FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const baseStyles: SystemStyleObject = {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
};
const sx: SystemStyleObject = {
...baseStyles,
'.gallery-image-container::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
};
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
withMetadataOverlay?: boolean;
isDragDisabled?: boolean;
isDropDisabled?: boolean;
isUploadDisabled?: boolean;
minSize?: number;
postUploadAction?: PostUploadAction;
imageSx?: ChakraProps['sx'];
fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
dataTestId?: string;
};
const IAIDndImage = (props: IAIDndImageProps) => {
const {
imageDTO,
onError,
onClick,
withMetadataOverlay = false,
isDropDisabled = false,
isDragDisabled = false,
isUploadDisabled = false,
minSize = 24,
postUploadAction,
imageSx,
fitContainer = false,
droppableData,
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
useThumbailFallback,
withHoverOverlay = false,
children,
dataTestId,
...rest
} = props;
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
return;
}
if (e.button !== 1) {
return;
}
window.open(imageDTO.image_url, '_blank');
},
[imageDTO]
);
const ref = useRef<HTMLDivElement>(null);
useImageContextMenu(imageDTO, ref);
return (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
);
};
export default memo(IAIDndImage);
const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);
UploadButton.displayName = 'UploadButton';


@@ -1,38 +0,0 @@
import type { BoxProps } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
{...attributes}
{...listeners}
{...rest}
/>
);
};
export default memo(IAIDraggable);


@@ -1,64 +0,0 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
type Props = {
isOver: boolean;
label?: string;
withBackdrop?: boolean;
};
const IAIDropOverlay = (props: Props) => {
const { isOver, label, withBackdrop = true } = props;
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
<Flex
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
bg={withBackdrop ? 'base.900' : 'transparent'}
opacity={0.7}
borderRadius="base"
alignItems="center"
justifyContent="center"
transitionProperty="common"
transitionDuration="0.1s"
/>
<Flex
position="absolute"
top={0.5}
right={0.5}
bottom={0.5}
left={0.5}
opacity={1}
borderWidth={1.5}
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
>
{label && (
<Text
fontSize="lg"
fontWeight="semibold"
color={isOver ? 'invokeYellow.300' : 'base.500'}
transitionProperty="common"
transitionDuration="0.1s"
textAlign="center"
>
{label}
</Text>
)}
</Flex>
</Flex>
);
};
export default memo(IAIDropOverlay);


@@ -1,46 +0,0 @@
import { Box } from '@invoke-ai/ui-library';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
type IAIDroppableProps = {
dropLabel?: string;
disabled?: boolean;
data?: TypesafeDroppableData;
};
const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);
};
export default memo(IAIDroppable);


@@ -1,24 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Skeleton } from '@invoke-ai/ui-library';
import { memo } from 'react';
const skeletonStyles: SystemStyleObject = {
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
};
const IAIFillSkeleton = () => {
return (
<Skeleton sx={skeletonStyles}>
<Box position="absolute" top={0} insetInlineStart={0} height="full" width="full" />
</Skeleton>
);
};
export default memo(IAIFillSkeleton);


@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
type Props = { image: ImageDTO | undefined };
export const IAILoadingImageFallback = memo((props: Props) => {
const IAILoadingImageFallback = memo((props: Props) => {
if (props.image) {
return (
<Skeleton


@@ -1,28 +0,0 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
type ImageMetadataOverlayProps = {
imageDTO: ImageDTO;
};
const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
return (
<Flex
pointerEvents="none"
flexDirection="column"
position="absolute"
top={0}
insetInlineStart={0}
p={2}
alignItems="flex-start"
gap={2}
>
<Badge variant="solid" colorScheme="base">
{imageDTO.width} × {imageDTO.height}
</Badge>
</Flex>
);
};
export default memo(ImageMetadataOverlay);


@@ -1,89 +0,0 @@
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { memo } from 'react';
import type { DropzoneState } from 'react-dropzone';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useBoardName } from 'services/api/hooks/useBoardName';
type ImageUploadOverlayProps = {
dropzone: DropzoneState;
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
};
const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const { dropzone, setIsHandlingUpload } = props;
useHotkeys(
'esc',
() => {
setIsHandlingUpload(false);
},
[setIsHandlingUpload]
);
return (
<Box position="absolute" top={0} right={0} bottom={0} left={0} zIndex={999} backdropFilter="blur(20px)">
<Flex position="absolute" top={0} right={0} bottom={0} left={0} bg="base.900" opacity={0.7} />
<Flex
position="absolute"
flexDir="column"
gap={4}
top={2}
right={2}
bottom={2}
left={2}
opacity={1}
borderWidth={2}
borderColor={dropzone.isDragAccept ? 'invokeYellow.300' : 'error.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
color={dropzone.isDragReject ? 'error.300' : undefined}
>
{dropzone.isDragAccept && <DragAcceptMessage />}
{!dropzone.isDragAccept && <DragRejectMessage />}
</Flex>
</Box>
);
};
export default memo(ImageUploadOverlay);
const DragAcceptMessage = () => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector(selectSelectedBoardId);
const boardName = useBoardName(selectedBoardId);
return (
<>
<Heading size="lg">{t('gallery.dropToUpload')}</Heading>
<Heading size="md">{t('toast.imagesWillBeAddedTo', { boardName })}</Heading>
</>
);
};
const DragRejectMessage = () => {
const { t } = useTranslation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
if (maxImageUploadCount === undefined) {
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc')}</Heading>
</>
);
}
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount })}</Heading>
</>
);
};


@@ -1,9 +1,13 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { autoScrollForElements } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/element';
import { autoScrollForExternal } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/external';
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import type { OverlayScrollbarsComponentRef } from 'overlayscrollbars-react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useMemo } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
type Props = PropsWithChildren & {
maxHeight?: ChakraProps['maxHeight'];
@@ -11,17 +15,38 @@ type Props = PropsWithChildren & {
overflowY?: 'hidden' | 'scroll';
};
const styles: CSSProperties = { height: '100%', width: '100%' };
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const overlayscrollbarsOptions = useMemo(
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
[overflowX, overflowY]
);
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);
useEffect(() => {
const osInstance = os?.osInstance();
if (!osInstance) {
return;
}
const element = osInstance.elements().viewport;
// `pragmatic-drag-and-drop-auto-scroll` requires the element to have `overflow-y: scroll` or `overflow-y: auto`
// else it logs an ugly warning. In our case, using a custom scrollbar library, it will be 'hidden' by default.
// To prevent the erroneous warning, we temporarily set the overflow-y to 'scroll' and then revert it back.
const overflowY = element.style.overflowY; // starts 'hidden'
element.style.setProperty('overflow-y', 'scroll', 'important');
const cleanup = combine(autoScrollForElements({ element }), autoScrollForExternal({ element }));
element.style.setProperty('overflow-y', overflowY);
return cleanup;
}, [os]);
return (
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
<OverlayScrollbarsComponent defer style={styles} options={overlayscrollbarsOptions}>
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
{children}
</OverlayScrollbarsComponent>
</Box>


@@ -1,124 +0,0 @@
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
const log = logger('gallery');
const accept: Accept = {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg', '.png'],
};
export const useFullscreenDropzone = () => {
useAssertSingleton('useFullscreenDropzone');
const { t } = useTranslation();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const [uploadImage] = useUploadImageMutation();
const activeTabName = useAppSelector(selectActiveTab);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const getPostUploadAction = useCallback((): PostUploadAction => {
if (activeTabName === 'upscaling') {
return { type: 'SET_UPSCALE_INITIAL_IMAGE' };
} else {
return { type: 'TOAST' };
}
}, [activeTabName]);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 0) {
const errors = fileRejections.map((rejection) => ({
errors: rejection.errors.map(({ message }) => message),
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
setIsHandlingUpload(false);
return;
}
for (const [i, file] of acceptedFiles.entries()) {
uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: getPostUploadAction(),
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// The `imageUploaded` listener does some extra logic, like switching to the asset view on upload on the
// first upload of a "batch".
isFirstUploadOfBatch: i === 0,
});
}
setIsHandlingUpload(false);
},
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId]
);
const onDragOver = useCallback(() => {
setIsHandlingUpload(true);
}, []);
const onDragLeave = useCallback(() => {
setIsHandlingUpload(false);
}, []);
const dropzone = useDropzone({
accept,
noClick: true,
onDrop,
onDragOver,
onDragLeave,
noKeyboard: true,
multiple: maxImageUploadCount === undefined || maxImageUploadCount > 1,
maxFiles: maxImageUploadCount,
});
useEffect(() => {
// This is a hack to allow pasting images into the uploader
const handlePaste = (e: ClipboardEvent) => {
if (!dropzone.inputRef.current) {
return;
}
if (e.clipboardData?.files) {
// Set the files on the dropzone.inputRef
dropzone.inputRef.current.files = e.clipboardData.files;
// Dispatch the change event, dropzone catches this and we get to use its own validation
dropzone.inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
}
};
// Add the paste event listener
document.addEventListener('paste', handlePaste);
return () => {
document.removeEventListener('paste', handlePaste);
};
}, [dropzone.inputRef]);
return { dropzone, isHandlingUpload, setIsHandlingUpload };
};


@@ -1,3 +1,5 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
@@ -7,14 +9,23 @@ import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
import { PiUploadBold } from 'react-icons/pi';
import { uploadImages, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
type UseImageUploadButtonArgs = {
postUploadAction?: PostUploadAction;
isDisabled?: boolean;
allowMultiple?: boolean;
};
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
};
const log = logger('gallery');
@@ -37,30 +48,53 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({
postUploadAction,
isDisabled,
allowMultiple = false,
}: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage] = useUploadImageMutation();
const [uploadImage, request] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const { t } = useTranslation();
const onDropAccepted = useCallback(
(files: File[]) => {
for (const [i, file] of files.entries()) {
uploadImage({
async (files: File[]) => {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
});
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
const imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {
onUpload(imageDTOs);
}
}
},
[autoAddBoardId, postUploadAction, uploadImage]
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
);
const onDropRejected = useCallback(
@@ -103,5 +137,42 @@ export const useImageUploadButton = ({
maxFiles: maxImageUploadCount,
});
return { getUploadButtonProps, getUploadInputProps, openUploader };
return { getUploadButtonProps, getUploadInputProps, openUploader, request };
};
const sx = {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 0,
borderRadius: 'base',
'&[data-error=true]': {
borderWidth: 1,
},
} satisfies SystemStyleObject;
export const UploadImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTO: ImageDTO) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="ghost"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
};
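
The reworked `UseImageUploadButtonArgs` type above is a discriminated union keyed on `allowMultiple`, so the `onUpload` callback receives either one image or an array. The sketch below shows that pattern in isolation. `ImageLike`, `UploadArgs`, and `deliverUpload` are hypothetical stand-ins for illustration only and are not part of the codebase; the real hook also performs the upload itself.

```typescript
// Hypothetical stand-in for ImageDTO; only the name is needed for this sketch.
type ImageLike = { image_name: string };

// Same shape as the union in the diff: the literal type of `allowMultiple`
// decides which `onUpload` signature applies.
type UploadArgs =
  | { allowMultiple: false; onUpload?: (image: ImageLike) => void }
  | { allowMultiple: true; onUpload?: (images: ImageLike[]) => void };

// Hands already-uploaded images to the callback. Returns false when the
// single-image variant receives anything other than exactly one image.
function deliverUpload(args: UploadArgs, uploaded: ImageLike[]): boolean {
  if (args.allowMultiple === false) {
    const first = uploaded[0];
    if (uploaded.length !== 1 || first === undefined) {
      return false;
    }
    args.onUpload?.(first);
    return true;
  }
  // Narrowed to the `allowMultiple: true` member here.
  args.onUpload?.(uploaded);
  return true;
}

const singleNames: string[] = [];
const batchSizes: number[] = [];

const single: UploadArgs = {
  allowMultiple: false,
  onUpload: (image: ImageLike) => {
    singleNames.push(image.image_name);
  },
};
const batch: UploadArgs = {
  allowMultiple: true,
  onUpload: (images: ImageLike[]) => {
    batchSizes.push(images.length);
  },
};

const deliveredSingle = deliverUpload(single, [{ image_name: 'a.png' }]);
const rejectedSingle = deliverUpload(single, [{ image_name: 'a.png' }, { image_name: 'b.png' }]);
const deliveredBatch = deliverUpload(batch, [{ image_name: 'a.png' }, { image_name: 'b.png' }]);
```

Keying the union on a boolean literal lets the compiler reject a caller that passes `allowMultiple: true` together with a single-image callback, which the previous optional `allowMultiple?: boolean` shape could not express.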


@@ -119,11 +119,20 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
}
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
// When we are using an unsupported model, do not add the other warnings
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
} else {
// Using a compatible model, add all warnings
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
}
} else {
if (canvasIsFiltering) {
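
The hunk above changes how the upscaling tab collects its "cannot invoke" reasons: an incompatible base model now suppresses the other warnings. A minimal sketch of that branching follows, assuming plain strings in place of the translated `i18n.t(...)` messages. `Model`, `UpscaleState`, and `collectUpscaleReasons` are hypothetical names for illustration; the real selector works on Redux state and checks more conditions.

```typescript
// Hypothetical, simplified state shapes for this sketch.
type Model = { base: string } | null;
type UpscaleState = {
  upscaleModel: string | null;
  tileControlnetModel: string | null;
};

function collectUpscaleReasons(model: Model, upscale: UpscaleState): string[] {
  const reasons: string[] = [];
  if (model && !['sd-1', 'sdxl'].includes(model.base)) {
    // Unsupported base model: report only this, skip the other warnings.
    reasons.push('incompatibleBaseModel');
  } else {
    // Compatible (or missing) model: report every missing piece.
    if (!model) {
      reasons.push('noModelSelected');
    }
    if (!upscale.upscaleModel) {
      reasons.push('missingUpscaleModel');
    }
    if (!upscale.tileControlnetModel) {
      reasons.push('missingTileControlNetModel');
    }
  }
  return reasons;
}

const emptyUpscale: UpscaleState = { upscaleModel: null, tileControlnetModel: null };
const fluxReasons = collectUpscaleReasons({ base: 'flux' }, emptyUpscale);
const sdxlReasons = collectUpscaleReasons({ base: 'sdxl' }, emptyUpscale);
const noModelReasons = collectUpscaleReasons(null, emptyUpscale);
```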


@@ -1,14 +0,0 @@
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
export const useNanoid = (prefix?: string) => {
const id = useMemo(() => {
if (prefix) {
return getPrefixedId(prefix);
} else {
return nanoid();
}
}, [prefix]);
return id;
};


@@ -1,12 +0,0 @@
type SerializableValue =
| string
| number
| boolean
| null
| undefined
| SerializableValue[]
| readonly SerializableValue[]
| SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};


@@ -0,0 +1,6 @@
import type { AssertionError } from 'tsafe';
export function extractMessageFromAssertionError(error: AssertionError): string | null {
const match = error.message.match(/Wrong assertion encountered: "(.*)"/);
return match ? (match[1] ?? null) : null;
}
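
The helper above pulls the user-supplied message out of an assertion error. Its regex implies a message of the form `Wrong assertion encountered: "<msg>"`. The sketch below exercises the same regex against plain `Error` objects so that it runs without `tsafe`. `ErrorLike` and `extractMessage` are hypothetical stand-ins, and the sample messages are made up for illustration.

```typescript
// Stand-in for tsafe's AssertionError: only the `message` field matters here.
type ErrorLike = { message: string };

function extractMessage(error: ErrorLike): string | null {
  // Same pattern as in the diff: capture whatever sits between the quotes.
  const match = error.message.match(/Wrong assertion encountered: "(.*)"/);
  return match ? (match[1] ?? null) : null;
}

// An assertion that carried a message, and one that did not.
const withMessage = new Error('Wrong assertion encountered: "Image must be selected"');
const withoutMessage = new Error('Wrong assertion encountered');

const extracted = extractMessage(withMessage);
const missing = extractMessage(withoutMessage);
```

Returning `null` rather than throwing lets a caller fall back to a generic error message when the assertion carried no text.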


@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -23,6 +23,7 @@ export const CanvasAddEntityButtons = memo(() => {
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -36,6 +37,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isSD3}
>
{t('controlLayers.globalReferenceImage')}
</Button>
@@ -61,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>
@@ -73,7 +75,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
@@ -88,6 +90,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isSD3}
>
{t('controlLayers.controlLayer')}
</Button>


@@ -0,0 +1,45 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';
const CanvasAlertsInvocationProgressContent = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);
if (!invocationProgressMessage) {
return null;
}
return (
<Alert status="loading" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('common.generating')}</AlertTitle>
<AlertDescription>{invocationProgressMessage}</AlertDescription>
</Alert>
);
});
CanvasAlertsInvocationProgressContent.displayName = 'CanvasAlertsInvocationProgressContent';
export const CanvasAlertsInvocationProgress = memo(() => {
const isProgressMessageAlertEnabled = useFeatureStatus('invocationProgressAlert');
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
// The alert is disabled at the system level
if (!isProgressMessageAlertEnabled) {
return null;
}
// The alert is disabled at the user level
if (!shouldShowInvocationProgressDetail) {
return null;
}
return <CanvasAlertsInvocationProgressContent />;
});
CanvasAlertsInvocationProgress.displayName = 'CanvasAlertsInvocationProgress';


@@ -1,38 +1,26 @@
import { Grid, GridItem } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type {
AddControlLayerFromImageDropData,
AddGlobalReferenceImageFromImageDropData,
AddRasterLayerFromImageDropData,
AddRegionalReferenceImageFromImageDropData,
} from 'features/dnd/types';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { newCanvasEntityFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
id: 'add-raster-layer-from-image-drop-data',
actionType: 'ADD_RASTER_LAYER_FROM_IMAGE',
};
const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
};
const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addRasterLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({ type: 'raster_layer' });
const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
});
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'reference_image',
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
if (imageViewer.isOpen) {
return null;
@@ -51,28 +39,36 @@ export const CanvasDropArea = memo(() => {
pointerEvents="none"
>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRasterLayer')}
data={addRasterLayerFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRasterLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRasterLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newControlLayer')}
data={addControlLayerFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
data={addRegionalReferenceImageFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRegionalGuidanceReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
data={addGlobalReferenceImageFromImageDropData}
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>


@@ -0,0 +1,59 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasEntityListDnd } from 'features/controlLayers/components/CanvasEntityList/useCanvasEntityListDnd';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
import { DndListDropIndicator } from 'features/dnd/DndListDropIndicator';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useRef } from 'react';
const sx = {
position: 'relative',
flexDir: 'column',
w: 'full',
bg: 'base.850',
borderRadius: 'base',
'&[data-selected=true]': {
bg: 'base.800',
},
'&[data-is-dragging=true]': {
opacity: 0.3,
},
transitionProperty: 'common',
} satisfies SystemStyleObject;
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const isSelected = useEntityIsSelected(entityIdentifier);
const onClick = useCallback(() => {
if (isSelected) {
return;
}
dispatch(entitySelected({ entityIdentifier }));
}, [dispatch, entityIdentifier, isSelected]);
const ref = useRef<HTMLDivElement>(null);
const [dndListState, isDragging] = useCanvasEntityListDnd(ref, entityIdentifier);
return (
<Box position="relative">
<Flex
// This is used to trigger the post-move flash animation
data-entity-id={entityIdentifier.id}
data-selected={isSelected}
data-is-dragging={isDragging}
ref={ref}
onClick={onClick}
sx={sx}
>
{props.children}
</Flex>
<DndListDropIndicator dndState={dndListState} />
</Box>
);
});
CanvasEntityContainer.displayName = 'CanvasEntityContainer';


@@ -0,0 +1,181 @@
import { monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import { reorderWithEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/util/reorder-with-edge';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { triggerPostMoveFlash } from 'features/dnd/util';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { flushSync } from 'react-dom';
import { PiCaretDownBold } from 'react-icons/pi';
type Props = PropsWithChildren<{
isSelected: boolean;
type: CanvasEntityIdentifier['type'];
entityIdentifiers: CanvasEntityIdentifier[];
}>;
export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityIdentifiers }: Props) => {
const title = useEntityTypeTitle(type);
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
const collapse = useBoolean(true);
const dispatch = useAppDispatch();
useEffect(() => {
return monitorForElements({
canMonitor({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== type) {
return false;
}
return true;
},
onDrop({ location, source }) {
const target = location.current.dropTargets[0];
if (!target) {
return;
}
const sourceData = source.data;
const targetData = target.data;
if (!singleCanvasEntityDndSource.typeGuard(sourceData) || !singleCanvasEntityDndSource.typeGuard(targetData)) {
return;
}
const indexOfSource = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === sourceData.payload.entityIdentifier.id
);
const indexOfTarget = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === targetData.payload.entityIdentifier.id
);
if (indexOfTarget < 0 || indexOfSource < 0) {
return;
}
// Don't move if the source and target are the same index, meaning same position in the list
if (indexOfSource === indexOfTarget) {
return;
}
const closestEdgeOfTarget = extractClosestEdge(targetData);
// It's possible that the indices are different, but refer to the same position. For example, if the source is
// at 2 and the target is at 3, but the target edge is 'top', then the entity is already in the correct position.
// We should bail if this is the case.
let edgeIndexDelta = 0;
if (closestEdgeOfTarget === 'bottom') {
edgeIndexDelta = 1;
} else if (closestEdgeOfTarget === 'top') {
edgeIndexDelta = -1;
}
// If the source is already in the correct position, we don't need to move it.
if (indexOfSource === indexOfTarget + edgeIndexDelta) {
return;
}
// Using `flushSync` so we can query the DOM straight after this line
flushSync(() => {
dispatch(
entitiesReordered({
type,
entityIdentifiers: reorderWithEdge({
list: entityIdentifiers,
startIndex: indexOfSource,
indexOfTarget,
closestEdgeOfTarget,
axis: 'vertical',
}),
})
);
});
// Flash the element that was moved
const element = document.querySelector(`[data-entity-id="${sourceData.payload.entityIdentifier.id}"]`);
if (element instanceof HTMLElement) {
triggerPostMoveFlash(element, colorTokenToCssVar('base.700'));
}
},
});
}, [dispatch, entityIdentifiers, type]);
return (
<Flex flexDir="column" w="full">
<Flex w="full">
<Flex
flexGrow={1}
as={Button}
onClick={collapse.toggle}
justifyContent="space-between"
alignItems="center"
gap={3}
variant="unstyled"
p={0}
h={8}
>
<Icon
boxSize={4}
as={PiCaretDownBold}
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
fill={isSelected ? 'base.200' : 'base.500'}
transitionProperty="common"
transitionDuration="fast"
/>
{informationalPopoverFeature ? (
<InformationalPopover feature={informationalPopoverFeature}>
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
</InformationalPopover>
) : (
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
)}
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
<Flex flexDir="column" gap={2} pt={2}>
{children}
</Flex>
</Collapse>
</Flex>
);
});
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';
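
The "already in the correct position" check in `onDrop` above can be isolated as a pure function. The following is a minimal sketch, not code from this changeset: `shouldSkipReorder` and the local `Edge` type are hypothetical names introduced for illustration, and the logic only mirrors the index comparisons shown in the hunk.

```typescript
// Hypothetical helper (not part of the diff): returns true when the drop would
// leave the list unchanged, so no reorder needs to be dispatched.
type Edge = 'top' | 'bottom' | 'left' | 'right' | null;

function shouldSkipReorder(indexOfSource: number, indexOfTarget: number, closestEdgeOfTarget: Edge): boolean {
  // Either entity was not found in the list
  if (indexOfSource < 0 || indexOfTarget < 0) {
    return true;
  }
  // Dropped onto itself
  if (indexOfSource === indexOfTarget) {
    return true;
  }
  // Different indices can still describe the same position: dropping on the
  // 'top' edge of the next item, or the 'bottom' edge of the previous item.
  let edgeIndexDelta = 0;
  if (closestEdgeOfTarget === 'bottom') {
    edgeIndexDelta = 1;
  } else if (closestEdgeOfTarget === 'top') {
    edgeIndexDelta = -1;
  }
  return indexOfSource === indexOfTarget + edgeIndexDelta;
}
```

For example, a source at index 2 dropped on the top edge of index 3 is skipped, while a source at index 0 dropped on the bottom edge of index 3 is a real move.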

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -24,6 +24,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Menu>
@@ -40,7 +41,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
@@ -48,15 +49,15 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isSD3}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -0,0 +1,85 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { type DndListTargetState, idle } from 'features/dnd/types';
import { firefoxDndFix } from 'features/dnd/util';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';
export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdentifier: CanvasEntityIdentifier) => {
const [dndListState, setDndListState] = useState<DndListTargetState>(idle);
const [isDragging, setIsDragging] = useState(false);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData() {
return singleCanvasEntityDndSource.getData({ entityIdentifier });
},
onDragStart() {
setDndListState({ type: 'is-dragging' });
setIsDragging(true);
},
onDrop() {
setDndListState(idle);
setIsDragging(false);
},
}),
dropTargetForElements({
element,
canDrop({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== entityIdentifier.type) {
return false;
}
return true;
},
getData({ input }) {
const data = singleCanvasEntityDndSource.getData({ entityIdentifier });
return attachClosestEdge(data, {
element,
input,
allowedEdges: ['top', 'bottom'],
});
},
getIsSticky() {
return true;
},
onDragEnter({ self }) {
const closestEdge = extractClosestEdge(self.data);
setDndListState({ type: 'is-dragging-over', closestEdge });
},
onDrag({ self }) {
const closestEdge = extractClosestEdge(self.data);
// Only update React state if something has changed.
// Returning the current state object prevents a re-render.
setDndListState((current) => {
if (current.type === 'is-dragging-over' && current.closestEdge === closestEdge) {
return current;
}
return { type: 'is-dragging-over', closestEdge };
});
},
onDragLeave() {
setDndListState(idle);
},
onDrop() {
setDndListState(idle);
},
})
);
}, [entityIdentifier, ref]);
return [dndListState, isDragging] as const;
};
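
The `onDrag` handler above relies on React skipping a re-render when a state updater returns the same value it was given. The following is a minimal sketch of that updater as a standalone function. The `ClosestEdge` and `ListState` types are assumptions made for illustration, since the real `DndListTargetState` from `features/dnd/types` is not shown in this diff.

```typescript
// Assumed shapes for illustration only; the real types live in features/dnd/types.
type ClosestEdge = 'top' | 'bottom' | 'left' | 'right' | null;
type ListState =
  | { type: 'idle' }
  | { type: 'is-dragging' }
  | { type: 'is-dragging-over'; closestEdge: ClosestEdge };

// Hypothetical pure version of the updater passed to setDndListState in onDrag.
function nextListState(current: ListState, closestEdge: ClosestEdge): ListState {
  // Same edge as before: hand back the same object so the state is unchanged
  if (current.type === 'is-dragging-over' && current.closestEdge === closestEdge) {
    return current;
  }
  // Edge changed (or we were not dragging over): produce a new state object
  return { type: 'is-dragging-over', closestEdge };
}
```

Because `onDrag` fires continuously while the pointer moves, keeping reference identity when the closest edge has not changed avoids a render per pointer event.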

View File

@@ -21,6 +21,8 @@ import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageV
import { memo, useCallback, useRef } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
const MenuContent = () => {
return (
<CanvasManagerProviderGate>
@@ -84,6 +86,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
@@ -109,7 +112,9 @@ export const CanvasMainPanelContent = memo(() => {
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasDropArea />
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
<GatedImageViewer />
</Flex>
);

View File

@@ -1,16 +1,20 @@
import { useDndContext } from '@dnd-kit/core';
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { dropTargetForExternal, monitorForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
import { Box, Button, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { CanvasLayersPanelContent } from 'features/controlLayers/components/CanvasLayersPanelContent';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectEntityCountActive } from 'features/controlLayers/store/selectors';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasRightPanel = memo(() => {
@@ -79,37 +83,13 @@ CanvasRightPanel.displayName = 'CanvasRightPanel';
const PanelTabs = memo(() => {
const { t } = useTranslation();
const activeTab = useAppSelector(selectActiveTabCanvasRightPanel);
const store = useAppStore();
const activeEntityCount = useAppSelector(selectEntityCountActive);
const tabTimeout = useRef<number | null>(null);
const dndCtx = useDndContext();
const dispatch = useAppDispatch();
const [mouseOverTab, setMouseOverTab] = useState<'layers' | 'gallery' | null>(null);
const onOnMouseOverLayersTab = useCallback(() => {
setMouseOverTab('layers');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('layers'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onOnMouseOverGalleryTab = useCallback(() => {
setMouseOverTab('gallery');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onMouseOut = useCallback(() => {
setMouseOverTab(null);
if (tabTimeout.current) {
clearTimeout(tabTimeout.current);
}
}, []);
const [layersTabDndState, setLayersTabDndState] = useState<DndTargetState>('idle');
const [galleryTabDndState, setGalleryTabDndState] = useState<DndTargetState>('idle');
const layersTabRef = useRef<HTMLDivElement>(null);
const galleryTabRef = useRef<HTMLDivElement>(null);
const timeoutRef = useRef<number | null>(null);
const layersTabLabel = useMemo(() => {
if (activeEntityCount === 0) {
@@ -118,23 +98,172 @@ const PanelTabs = memo(() => {
return `${t('controlLayers.layer_other')} (${activeEntityCount})`;
}, [activeEntityCount, t]);
useEffect(() => {
if (!layersTabRef.current) {
return;
}
const getIsOnLayersTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'layers';
const onDragEnter = () => {
// If we are already on the layers tab, do nothing
if (getIsOnLayersTab()) {
return;
}
// Else set the state to active and switch to the layers tab after a timeout
setLayersTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('layers'));
// When we switch tabs, the other tab should become the potential target
setLayersTabDndState('idle');
setGalleryTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or potential depending on the current tab
if (getIsOnLayersTab()) {
setLayersTabDndState('idle');
} else {
setLayersTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to potential when a drag starts
setLayersTabDndState('potential');
};
return combine(
dropTargetForElements({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the layers tab
return !getIsOnLayersTab();
},
onDragStart,
}),
dropTargetForExternal({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnLayersTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
if (!galleryTabRef.current) {
return;
}
const getIsOnGalleryTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'gallery';
const onDragEnter = () => {
// If we are already on the gallery tab, do nothing
if (getIsOnGalleryTab()) {
return;
}
// Else set the state to active and switch to the gallery tab after a timeout
setGalleryTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
// When we switch tabs, the other tab should become the potential target
setGalleryTabDndState('idle');
setLayersTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or potential depending on the current tab
if (getIsOnGalleryTab()) {
setGalleryTabDndState('idle');
} else {
setGalleryTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to potential when a drag starts
setGalleryTabDndState('potential');
};
return combine(
dropTargetForElements({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the gallery tab
return !getIsOnGalleryTab();
},
onDragStart,
}),
dropTargetForExternal({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnGalleryTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
const onDrop = () => {
// Reset the dnd state when a drop happens
setGalleryTabDndState('idle');
setLayersTabDndState('idle');
};
const cleanup = combine(monitorForElements({ onDrop }), monitorForExternal({ onDrop }));
return () => {
cleanup();
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
}, []);
return (
<>
<Tab position="relative" onMouseOver={onOnMouseOverLayersTab} onMouseOut={onMouseOut} w={32}>
<Tab ref={layersTabRef} position="relative" w={32}>
<Box as="span" w="full">
{layersTabLabel}
</Box>
{dndCtx.active && activeTab !== 'layers' && (
<IAIDropOverlay isOver={mouseOverTab === 'layers'} withBackdrop={false} />
)}
<DndDropOverlay dndState={layersTabDndState} withBackdrop={false} />
</Tab>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut} w={32}>
<Tab ref={galleryTabRef} position="relative" w={32}>
<Box as="span" w="full">
{t('gallery.gallery')}
</Box>
{dndCtx.active && activeTab !== 'gallery' && (
<IAIDropOverlay isOver={mouseOverTab === 'gallery'} withBackdrop={false} />
)}
<DndDropOverlay dndState={galleryTabDndState} withBackdrop={false} />
</Tab>
</>
);

View File

@@ -1,6 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
@@ -10,8 +9,11 @@ import { ControlLayerBadges } from 'features/controlLayers/components/ControlLay
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -21,14 +23,16 @@ type Props = {
export const ControlLayer = memo(({ id }: Props) => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useMemo<CanvasEntityIdentifier<'control_layer'>>(
() => ({ id, type: 'control_layer' }),
[id]
);
const dropData = useMemo<ReplaceLayerImageDropData>(
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
[id, entityIdentifier]
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
[entityIdentifier]
);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<ControlLayerAdapterGate>
@@ -43,7 +47,12 @@ export const ControlLayer = memo(({ id }: Props) => {
<CanvasEntitySettingsWrapper>
<ControlLayerSettings />
</CanvasEntitySettingsWrapper>
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
<DndDropTarget
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.replaceLayer')}
isDisabled={isBusy}
/>
</CanvasEntityContainer>
</ControlLayerAdapterGate>
</EntityIdentifierContext.Provider>

View File

@@ -1,6 +1,7 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
@@ -21,10 +22,11 @@ import { getFilterForModel } from 'features/controlLayers/store/filters';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
@@ -41,7 +43,7 @@ const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<
export const ControlLayerControlAdapter = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
@@ -113,11 +115,17 @@ export const ControlLayerControlAdapter = memo(() => {
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
const uploadOptions = useMemo(
() =>
({
onUpload: (imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ entityIdentifier, imageDTO, dispatch, getState });
},
allowMultiple: false,
}) as const,
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const uploadApi = useImageUploadButton(uploadOptions);
return (
<Flex flexDir="column" gap={3} position="relative" w="full">

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const ControlLayerEntityList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const layerIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (layerIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (layerIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="control_layer" isSelected={isSelected}>
{layerIds.map((id) => (
<ControlLayer key={id} id={id} />
<CanvasEntityGroupList type="control_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<ControlLayer key={entityIdentifier.id} id={entityIdentifier.id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -1,22 +1,25 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { Trans } from 'react-i18next';
import type { PostUploadAction } from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
export const ControlLayerSettingsEmptyState = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ imageDTO, entityIdentifier, dispatch, getState });
},
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);

View File

@@ -1,5 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';

View File

@@ -1,82 +1,80 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { useNanoid } from 'common/hooks/useNanoid';
import { UploadImageButton } from 'common/hooks/useImageUploadButton';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props = {
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
dndTarget: T;
dndTargetData: ReturnType<T['getData']>;
};
export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const isConnected = useStore($isConnected);
const dndId = useNanoid('ip_adapter_image_preview');
export const IPAdapterImagePreview = memo(
<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget>({
image,
onChangeImage,
dndTarget,
dndTargetData,
}: Props<T>) => {
const { t } = useTranslation();
const isConnected = useStore($isConnected);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: dndId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, dndId]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
},
[onChangeImage]
);
useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex
position="relative"
w="full"
h="full"
alignItems="center"
borderColor="error.500"
borderStyle="solid"
borderWidth={controlImage ? 0 : 1}
borderRadius="base"
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>
{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
<UploadImageButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
onUpload={onUpload}
fontSize={36}
/>
</Flex>
)}
</Flex>
);
});
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';

@@ -2,14 +2,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(mapId).reverse();
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const IPAdapterList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const ipaIds = useAppSelector(selectEntityIds);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (ipaIds.length === 0) {
if (entityIdentifiers.length === 0) {
return null;
}
if (ipaIds.length > 0) {
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
{ipaIds.map((id) => (
<IPAdapter key={id} id={id} />
<CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifiers) => (
<IPAdapter key={entityIdentifiers.id} id={entityIdentifiers.id} />
))}
</CanvasEntityGroupList>
);

@@ -19,11 +19,12 @@ import {
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { IPAImageDropData } from 'features/dnd/types';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig, IPALayerImagePostUploadAction } from 'services/api/types';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
@@ -80,13 +81,9 @@ export const IPAdapterSettings = memo(() => {
[dispatch, entityIdentifier]
);
const droppableData = useMemo<IPAImageDropData>(
() => ({ actionType: 'SET_IPA_IMAGE', context: { id: entityIdentifier.id }, id: entityIdentifier.id }),
[entityIdentifier.id]
);
const postUploadAction = useMemo<IPALayerImagePostUploadAction>(
() => ({ type: 'SET_IPA_IMAGE', id: entityIdentifier.id }),
[entityIdentifier.id]
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }, ipAdapter.image?.image_name),
[entityIdentifier, ipAdapter.image?.image_name]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
@@ -122,10 +119,10 @@ export const IPAdapterSettings = memo(() => {
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
image={ipAdapter.image}
onChangeImage={onChangeImage}
droppableData={droppableData}
postUploadAction={postUploadAction}
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
/>
</Flex>
</Flex>

Some files were not shown because too many files have changed in this diff.