Compare commits


4 Commits

241 changed files with 6807 additions and 8886 deletions


@@ -751,7 +751,7 @@ async def convert_model(
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
converted_model = loader.load_model(model_config, queue_id="default")
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")


@@ -31,6 +31,7 @@ from invokeai.app.services.events.events_common import (
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadEventBase,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
@@ -53,6 +54,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class ModelLoadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io model loading room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationProgressEvent,
@@ -69,8 +77,6 @@ MODEL_EVENTS = {
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
@@ -79,6 +85,11 @@ MODEL_EVENTS = {
ModelInstallErrorEvent,
}
MODEL_LOAD_EVENTS = {
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
@@ -101,6 +112,7 @@ class SocketIO:
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(MODEL_LOAD_EVENTS, self._handle_model_load_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
@@ -115,9 +127,18 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_sub_model_load(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
async def _handle_unsub_model_load(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
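
The hunks above split model load events out of the general model room and scope them to a per-queue socket.io room. A minimal client-side sketch, assuming a python-socketio client and InvokeAI's default local address and socket path; the "subscribe_model_load" event name is an assumption (the handler registration is not shown in this diff), while "model_load_started" is confirmed by the events file later in this diff:

import socketio

sio = socketio.Client()

# Fires for loads scoped to the room we join below.
@sio.on("model_load_started")
def on_model_load_started(data):
    print("model load started in queue:", data["queue_id"])

# Address and socketio_path are assumptions based on InvokeAI defaults.
sio.connect("http://localhost:9090", socketio_path="/ws/socket.io")
# Assumed subscribe event name; the server joins us to the queue_id room via enter_room.
sio.emit("subscribe_model_load", {"queue_id": "default"})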


@@ -63,12 +63,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
tokenizer_info = context.models.load(self.clip.tokenizer, queue_id=context.util.get_queue_id())
text_encoder_info = context.models.load(self.clip.text_encoder, queue_id=context.util.get_queue_id())
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@@ -95,7 +95,6 @@ class CompelInvocation(BaseInvocation):
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
@@ -138,8 +137,8 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
tokenizer_info = context.models.load(clip_field.tokenizer, queue_id=context.util.get_queue_id())
text_encoder_info = context.models.load(clip_field.text_encoder, queue_id=context.util.get_queue_id())
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -164,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@@ -192,7 +191,6 @@ class SDXLPromptInvocationBase:
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)


@@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
source=DEPTH_ANYTHING_MODELS[self.model_size],
queue_id=self._context.util.get_queue_id(),
loader=load_depth_anything,
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)


@@ -60,12 +60,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)


@@ -124,14 +124,13 @@ class CreateGradientMaskInvocation(BaseInvocation):
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)


@@ -88,7 +88,7 @@ def get_scheduler(
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
control_model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
@@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
model = exit_stack.enter_context(context.models.load(control_info.control_model))
model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
ext_manager.add_extension(
ControlNetExt(
model=model,
@@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
with context.models.load(
single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
image_encoder_model_info = context.models.load(
single_ip_adapter.image_encoder_model, context.util.get_queue_id()
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
ip_adapter_model = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
)
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
t2i_adapter_loaded_model = context.models.load(
t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -926,7 +938,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
@@ -989,13 +1001,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,


@@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithB
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)


@@ -29,10 +29,10 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
loaded_session_det = context.models.load_local_model(
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:


@@ -56,7 +56,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.2.1",
version="3.2.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -81,7 +81,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
@@ -184,7 +183,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
transformer_info = context.models.load(self.transformer.transformer)
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
@@ -208,12 +207,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"to be poor. Consider using a FLUX dev model instead."
)
if self.add_noise:
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
x = init_latents
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
@@ -472,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
controlnet_models = [
context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -483,7 +481,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
@@ -594,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
with context.models.load(
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -624,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
ip_adapter_model = exit_stack.enter_context(
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
)
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
@@ -653,7 +655,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
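
One behavioral note on the add_noise removal above: image-to-image runs now always blend the encoded latents with noise at the first timestep. A toy sketch of that rectified-flow interpolation (tensor shapes and the t_0 value are illustrative):

import torch

noise = torch.randn(1, 16, 64, 64)            # pure Gaussian noise
init_latents = torch.randn(1, 16, 64, 64)     # stand-in for VAE-encoded image latents
t_0 = 0.75                                    # first timestep fraction, in [0, 1]
x = t_0 * noise + (1.0 - t_0) * init_latents  # starting latents for denoising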


@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -71,15 +71,14 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -112,7 +111,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
@@ -120,7 +118,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -41,8 +41,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
@@ -53,8 +52,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()


@@ -44,8 +44,9 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@@ -53,13 +54,12 @@ class FluxVaeEncodeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")


@@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
source=GROUNDING_DINO_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=GroundingDinoInvocation._load_grounding_dino,
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)


@@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, ControlNetHED_Apache2)


@@ -111,13 +111,12 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)


@@ -36,7 +36,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
@abstractmethod
def infill(self, image: Image.Image) -> Image.Image:
def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
"""Infill the image with the specified method"""
pass
@@ -56,7 +56,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(context.images.get_dto(self.image.image_name))
# Perform Infill action
infilled_image = self.infill(input_image)
infilled_image = self.infill(input_image, context.util.get_queue_id())
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
@@ -74,7 +74,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
description="The color to use to infill",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
@@ -93,7 +93,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
description="The seed to use for tile generation (omit for random)",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@@ -107,7 +107,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width / self.downscale)
@@ -131,9 +131,10 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
queue_id=queue_id,
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
@@ -144,7 +145,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
return cv2_inpaint(image)
@@ -166,5 +167,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
description="The max threshold for color",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
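
Every infill() override above gains a queue_id parameter, whether or not the method loads a model. A hypothetical subclass sketch of the widened contract (the @invocation decorator and input fields are omitted for brevity):

from PIL import Image

class NoOpInfillInvocation(InfillImageProcessorInvocation):
    """Hypothetical example of the new infill() signature."""

    def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
        # queue_id lets model-backed infills (e.g. LaMa above) scope their
        # loads to the active session queue; this variant returns the image unchanged.
        return image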


@@ -57,10 +57,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:


@@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartEdgeDetector.get_model_url(self.coarse)
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, Generator)


@@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartAnimeEdgeDetector.get_model_url()
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, UnetGenerator)


@@ -147,10 +147,6 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
]


@@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
loaded_model = context.models.load_remote_model(
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
)
with loaded_model as model:
assert isinstance(model, MobileV2_MLSD_Large)


@@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
loaded_model = context.models.load_remote_model(
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
)
with loaded_model as model:
assert isinstance(model, NNET)


@@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
loaded_model = context.models.load_remote_model(
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
)
with loaded_model as model:
assert isinstance(model, PiDiNet)


@@ -1,19 +1,16 @@
from typing import Callable, Optional, Tuple
from typing import Callable, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
SD3ConditioningField,
WithBoard,
WithMetadata,
@@ -22,9 +19,7 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -35,24 +30,16 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.1.0",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -74,41 +61,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
@@ -195,7 +147,7 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
@@ -218,20 +170,14 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the timestep schedule.
# We add an extra step to the end to account for the final timestep of 0.0.
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
total_steps = len(timesteps) - 1
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(total_steps)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
cfg_scale = self._prepare_cfg_scale(len(timesteps))
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
@@ -245,34 +191,9 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
latents = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return latents
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
@@ -289,12 +210,11 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
@@ -312,19 +232,21 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = latents.to(dtype=torch.float32)
latents = latents + (t_prev - t_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
timestep=int(t),
latents=latents,
),
)
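
The SD3 loop above replaces a hand-rolled timestep schedule and Euler update with diffusers' FlowMatchEulerDiscreteScheduler. A self-contained sketch of the scheduler-driven pattern the new code follows (the transformer call is stubbed; shapes are illustrative):

import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=4, device="cpu")
latents = torch.randn(1, 16, 32, 32)
for t in scheduler.timesteps:
    noise_pred = torch.zeros_like(latents)  # stand-in for the SD3 transformer output
    latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]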


@@ -1,65 +0,0 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"sd3_i2l",
title="SD3 Image to Latents",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""
image: ImageField = InputField(description="The image to encode")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
latents = vae.config.scaling_factor * latents
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)


@@ -44,10 +44,9 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)


@@ -86,8 +86,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -95,7 +95,6 @@ class Sd3TextEncoderInvocation(BaseInvocation):
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
@@ -128,8 +127,8 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -138,7 +137,6 @@ class Sd3TextEncoderInvocation(BaseInvocation):
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
@@ -195,7 +193,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info


@@ -125,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=SegmentAnythingInvocation._load_sam_model,
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)


@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.


@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
with (
ExitStack() as exit_stack,


@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(msg)
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
)
with loadnet as loadnet_model:


@@ -131,15 +131,17 @@ class EventServiceBase:
# region Model loading
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
def emit_model_load_started(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
# endregion


@@ -383,12 +383,14 @@ class DownloadErrorEvent(DownloadEventBase):
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
class ModelLoadEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
class ModelLoadStartedEvent(ModelLoadEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
@@ -397,12 +399,14 @@ class ModelLoadStartedEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadStartedEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
class ModelLoadCompleteEvent(ModelLoadEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
@@ -411,8 +415,14 @@ class ModelLoadCompleteEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadCompleteEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
class ModelEventBase(EventBase):
"""Base class for model events"""
@payload_schema.register
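
With the rebased hierarchy above, a model load event carries its queue ID in the payload, which the SocketIO changes earlier in this diff use as the room name. A construction sketch, where model_config stands in for a real AnyModelConfig:

event = ModelLoadStartedEvent.build(config=model_config, queue_id="default")
payload = event.model_dump(mode="json")
assert payload["queue_id"] == "default"  # the key used for room-targeted emits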


@@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -29,7 +31,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.


@@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -60,7 +62,7 @@ class ModelLoadService(ModelLoadServiceBase):
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
self._invoker.services.events.emit_model_load_started(model_config, queue_id, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -70,12 +72,12 @@ class ModelLoadService(ModelLoadServiceBase):
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
self._invoker.services.events.emit_model_load_complete(model_config, queue_id, submodel_type)
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
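
A hypothetical call against the updated load_model_from_path signature; the service instance, checkpoint path, and loader below are illustrative rather than taken from the diff:

from pathlib import Path
import torch

loaded = model_load_service.load_model_from_path(  # model_load_service: a ModelLoadService instance
    model_path=Path("models/big-lama.pt"),
    queue_id="default",
    loader=lambda p: torch.jit.load(p, map_location="cpu"),
)
with loaded as model:
    ...  # use the raw model here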


@@ -160,10 +160,6 @@ class LoggerInterface(InvocationContextInterface):
class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def save(
self,
image: Image,
@@ -190,8 +186,6 @@ class ImagesInterface(InvocationContextInterface):
The saved image DTO.
"""
self._util.signal_progress("Saving image")
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
@@ -342,10 +336,6 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.
@@ -361,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
self,
identifier: Union[str, "ModelIdentifierField"],
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model.
@@ -378,18 +371,19 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, queue_id, submodel_type)
else:
submodel_type = submodel_type or identifier.submodel_type
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
self,
name: str,
base: BaseModelType,
type: ModelType,
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model by its attributes.
@@ -411,11 +405,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
return self._services.model_manager.load.load_model(configs[0], queue_id, submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config.
@@ -485,12 +475,12 @@ class ModelsInterface(InvocationContextInterface):
Returns:
Path to the downloaded model
"""
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
model_path: Path,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -508,13 +498,14 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
def load_remote_model(
self,
source: str | AnyHttpUrl,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -535,9 +526,9 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
class ConfigInterface(InvocationContextInterface):
@@ -558,6 +549,14 @@ class UtilInterface(InvocationContextInterface):
super().__init__(services, data)
self._is_canceled = is_canceled
def get_queue_id(self) -> str:
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._data.queue_item.queue_id
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
@@ -730,12 +729,12 @@ def build_invocation_context(
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)
ctx = InvocationContext(
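Taken together, these signature changes thread the session's queue ID through every model-load path, so load events can be attributed to the queue item that triggered them. A minimal sketch of the resulting call pattern inside a node body, assuming the interfaces above and the usual imports (the field and model names here are illustrative, not from this diff):

# Hypothetical invocation body; `context` is the InvocationContext assembled above.
queue_id = context.util.get_queue_id()  # queue item executing this node

# Load by identifier field (e.g. a model picked in the workflow)...
loaded_model = context.models.load(self.model, queue_id=queue_id)

# ...or by attributes; both paths now require the queue ID.
ti_model = context.models.load_by_attrs(
    name="easynegative",  # illustrative name, not from this diff
    base=BaseModelType.StableDiffusion1,
    type=ModelType.TextualInversion,
    queue_id=queue_id,
)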

View File

@@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(name_or_key)
loaded_model = context.models.load(name_or_key, queue_id=context.util.get_queue_id())
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
@@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
name=name_or_key, base=base, type=ModelType.TextualInversion
name=name_or_key, base=base, type=ModelType.TextualInversion, queue_id=context.util.get_queue_id()
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)

View File

@@ -45,9 +45,8 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)
# inner_dim = 3072
# mlp_ratio = 4.0
layers: dict[str, AnyLoRALayer] = {}
@@ -63,43 +62,30 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(
src_keys: list[str],
src_weight_shapes: list[tuple[int, int]],
dst_qkv_key: str,
allow_missing_keys: bool = False,
) -> None:
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# If none of the keys are present, return early.
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayer] = []
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
for src_layer_dict in src_layer_dicts:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -132,7 +118,6 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
@@ -141,7 +126,6 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.txt_attn.qkv",
)
@@ -191,14 +175,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
[
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(mlp_hidden_dim, hidden_size),
],
f"single_blocks.{i}.linear1",
allow_missing_keys=True,
)
# Output projections.
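The concatenation trick the docstring above describes can be checked numerically: stacking the per-matrix LoRA deltas for Q, K, and V along the output axis (concat_axis=0) gives exactly the delta of the fused QKV weight. A standalone sketch with toy shapes, not the real ConcatenatedLoRALayer implementation:

import torch

hidden, rank = 8, 4
a_q, b_q = torch.randn(rank, hidden), torch.randn(hidden, rank)  # lora_down / lora_up for Q
a_k, b_k = torch.randn(rank, hidden), torch.randn(hidden, rank)
a_v, b_v = torch.randn(rank, hidden), torch.randn(hidden, rank)

# Per-matrix deltas (up @ down), then fuse along the output axis.
delta_q, delta_k, delta_v = b_q @ a_q, b_k @ a_k, b_v @ a_v
fused_delta = torch.cat([delta_q, delta_k, delta_v], dim=0)  # concat_axis=0

x = torch.randn(hidden)
separate = torch.cat([delta_q @ x, delta_k @ x, delta_v @ x], dim=0)
assert torch.allclose(fused_delta @ x, separate)  # identical block rows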

View File

@@ -165,8 +165,6 @@ class SubmodelDefinition(BaseModel):
model_type: ModelType
variant: AnyVariant = None
model_config = ConfigDict(protected_namespaces=())
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")

View File

@@ -35,7 +35,6 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
self._torch_device = TorchDevice.choose_torch_device()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@@ -84,15 +84,7 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
model.to(dtype=self._torch_dtype)
return model

View File

@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],

View File

@@ -172,8 +172,6 @@ def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
try:
path = Path(location)
config_path = path / "config.json"
if not config_path.exists():
config_path = path / "text_encoder" / "config.json"
if not config_path.exists():
return ClipVariantType.L
with open(config_path) as file:

View File

@@ -85,7 +85,6 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result: set[Path] = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
safetensors_detected = False
for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
@@ -120,16 +119,10 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
if safetensors_detected and path.suffix == ".bin":
continue
parent = path.parent
score = 0
if path.suffix == ".safetensors":
safetensors_detected = True
if parent in subfolder_weights:
subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None

View File

@@ -1,58 +0,0 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with SD3."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background.
noise (torch.Tensor): The noise tensor used to noise the init_latents.
"""
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
# threshold worked well.
# We use a small epsilon to avoid any potential issues with floating point precision.
eps = 1e-4
mask_gradient_t_cutoff = 0.5
if t_prev > mask_gradient_t_cutoff:
# Early in the denoising process, use the inpaint mask as-is.
return self._inpaint_mask
else:
# After the cut-off, promote all non-zero mask values to 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)
return mask
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, t_prev: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
mask = self._apply_mask_gradient_adjustment(t_prev)
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
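The merge step keeps unmasked regions exactly on the flow-matching noise trajectory: wherever the cutoff-adjusted mask is 0, the output equals the re-noised init latents for the current timestep. A small standalone check of that property with toy shapes:

import torch

t_prev = 0.7
init = torch.randn(1, 4, 8, 8)
noise = torch.randn(1, 4, 8, 8)
intermediate = torch.randn(1, 4, 8, 8)
mask = torch.zeros(1, 1, 8, 8)  # 0 everywhere: keep the background untouched

noised_init = noise * t_prev + (1.0 - t_prev) * init
merged = intermediate * mask + noised_init * (1.0 - mask)
assert torch.allclose(merged, noised_init)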

View File

@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,

View File

@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
@callback(ExtensionCallbackType.SETUP)
def setup(self, ctx: DenoiseContext):
t2i_model: T2IAdapter
with self._node_context.models.load(self._model_id) as t2i_model:
with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
self._adapter_state = self._run_model(

View File

@@ -52,11 +52,11 @@
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.1.0",
"@invoke-ai/ui-library": "^0.0.43",
"@nanostores/react": "^0.7.3",

View File

@@ -5,21 +5,21 @@ settings:
excludeLinksFromLockfile: false
dependencies:
'@atlaskit/pragmatic-drag-and-drop':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-auto-scroll':
specifier: ^1.4.0
version: 1.4.0
'@atlaskit/pragmatic-drag-and-drop-hitbox':
specifier: ^1.0.3
version: 1.0.3
'@dagrejs/dagre':
specifier: ^1.1.4
version: 1.1.4
'@dagrejs/graphlib':
specifier: ^2.2.4
version: 2.2.4
'@dnd-kit/core':
specifier: ^6.1.0
version: 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/sortable':
specifier: ^8.0.0
version: 8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1)
'@dnd-kit/utilities':
specifier: ^3.2.2
version: 3.2.2(react@18.3.1)
'@fontsource-variable/inter':
specifier: ^5.1.0
version: 5.1.0
@@ -319,28 +319,6 @@ packages:
'@jridgewell/trace-mapping': 0.3.25
dev: true
/@atlaskit/pragmatic-drag-and-drop-auto-scroll@1.4.0:
resolution: {integrity: sha512-5GoikoTSW13UX76F9TDeWB8x3jbbGlp/Y+3aRkHe1MOBMkrWkwNpJ42MIVhhX/6NSeaZiPumP0KbGJVs2tOWSQ==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop-hitbox@1.0.3:
resolution: {integrity: sha512-/Sbu/HqN2VGLYBhnsG7SbRNg98XKkbF6L7XDdBi+izRybfaK1FeMfodPpm/xnBHPJzwYMdkE0qtLyv6afhgMUA==}
dependencies:
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
'@babel/runtime': 7.25.7
dev: false
/@atlaskit/pragmatic-drag-and-drop@1.4.0:
resolution: {integrity: sha512-qRY3PTJIcxfl/QB8Gwswz+BRvlmgAC5pB+J2hL6dkIxgqAgVwOhAamMUKsrOcFU/axG2Q7RbNs1xfoLKDuhoPg==}
dependencies:
'@babel/runtime': 7.25.7
bind-event-listener: 3.0.0
raf-schd: 4.0.3
dev: false
/@babel/code-frame@7.25.7:
resolution: {integrity: sha512-0xZJFNE5XMpENsgfHYTw8FbX4kv53mFLn2i3XPoq69LyhYSCBJtitaHx9QnsVTrsogI4Z3+HtEfZ2/GFPOtf5g==}
engines: {node: '>=6.9.0'}
@@ -1002,6 +980,49 @@ packages:
engines: {node: '>17.0.0'}
dev: false
/@dnd-kit/accessibility@3.1.0(react@18.3.1):
resolution: {integrity: sha512-ea7IkhKvlJUv9iSHJOnxinBcoOI3ppGnnL+VDJ75O45Nss6HtZd8IdN8touXPDtASfeI2T2LImb8VOZcL47wjQ==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/core@6.1.0(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-J3cQBClB4TVxwGo3KEjssGEXNJqGVWx17aRTZ1ob0FliR5IjYgTxl5YJbKTzA6IzrtelotH19v6y7uoIRUZPSg==}
peerDependencies:
react: '>=16.8.0'
react-dom: '>=16.8.0'
dependencies:
'@dnd-kit/accessibility': 3.1.0(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
tslib: 2.7.0
dev: false
/@dnd-kit/sortable@8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1):
resolution: {integrity: sha512-U3jk5ebVXe1Lr7c2wU7SBZjcWdQP+j7peHJfCspnA81enlu88Mgd7CC8Q+pub9ubP7eKVETzJW+IBAhsqbSu/g==}
peerDependencies:
'@dnd-kit/core': ^6.1.0
react: '>=16.8.0'
dependencies:
'@dnd-kit/core': 6.1.0(react-dom@18.3.1)(react@18.3.1)
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
react: 18.3.1
tslib: 2.7.0
dev: false
/@dnd-kit/utilities@3.2.2(react@18.3.1):
resolution: {integrity: sha512-+MKAJEOfaBe5SmV6t34p80MMKhjvUz0vRrvVJbPT0WElzaOJ/1xs+D+KDv+tD/NE5ujfrChEcshd4fLn0wpiqg==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.3.1
tslib: 2.7.0
dev: false
/@emotion/babel-plugin@11.12.0:
resolution: {integrity: sha512-y2WQb+oP8Jqvvclh8Q55gLUyb7UFvgv7eJfsj7td5TToBrIUtPay2kMrZi4xjq9qw2vD0ZR5fSho0yqoFgX7Rw==}
dependencies:
@@ -4292,10 +4313,6 @@ packages:
open: 8.4.2
dev: true
/bind-event-listener@3.0.0:
resolution: {integrity: sha512-PJvH288AWQhKs2v9zyfYdPzlPqf5bXbGMmhmUIY9x4dAUGIWgomO771oBQNwJnMQSnUIXhKu6sgzpBRXTlvb8Q==}
dev: false
/bl@4.1.0:
resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
dependencies:
@@ -7540,10 +7557,6 @@ packages:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
dev: true
/raf-schd@4.0.3:
resolution: {integrity: sha512-tQkJl2GRWh83ui2DiPTJz9wEiMN20syf+5oKfB03yYP7ioZcJwsIK8FjrtLwH1m7C7e+Tt2yYBlrOpdT+dyeIQ==}
dev: false
/raf-throttle@2.0.6:
resolution: {integrity: sha512-C7W6hy78A+vMmk5a/B6C5szjBHrUzWJkVyakjKCK59Uy2CcA7KhO1JUvvH32IXYFIcyJ3FMKP3ZzCc2/71I6Vg==}
dev: false

View File

@@ -174,8 +174,7 @@
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"none": "None",
"new": "New",
"generating": "Generating"
"new": "New"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -705,8 +704,6 @@
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@@ -1000,7 +997,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"noRasterLayers": "No Raster Layers",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1140,7 +1137,6 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"showDetailedInvocationProgress": "Show Progress Details",
"showProgressInViewer": "Show Progress Images in Viewer",
"ui": "User Interface",
"clearIntermediatesDisabled": "Queue must be empty to clear intermediates",
@@ -1675,7 +1671,7 @@
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"outputOnlyMaskedRegions": "Output Only Masked Regions",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1791,7 +1787,7 @@
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"full": "Style and Composition",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only"
},
@@ -2003,9 +1999,7 @@
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed",
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
"missingTileControlNetModel": "No valid tile ControlNet models installed"
},
"stylePresets": {
"active": "Active",
@@ -2108,10 +2102,8 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: Support for Text-to-Image in Workflows with SD 3.5 Medium and Large.",
"<StrongComponent>Canvas</StrongComponent>: Streamlined Control Layer processing and improved default Control settings."
],
"line1": "<StrongComponent>Layer Merging</StrongComponent>: New <StrongComponent>Merge Down</StrongComponent> and improved <StrongComponent>Merge Visible</StrongComponent> for all layers, with special handling for Regional Guidance and Control Layers.",
"line2": "<StrongComponent>HF Token Support</StrongComponent>: Upload models that require Hugging Face authentication.",
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -8,8 +8,10 @@ import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import {
@@ -17,7 +19,6 @@ import {
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
@@ -61,6 +62,8 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
@@ -89,8 +92,19 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<Box
id="invoke-app-wrapper"
w="100dvw"
h="100dvh"
position="relative"
overflow="hidden"
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<AppContent />
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
)}
</Box>
<DeleteImageModal />
<ChangeBoardModal />
@@ -107,7 +121,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
<FullscreenDropzone />
</ErrorBoundary>
);
};

View File

@@ -1,3 +1,4 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
@@ -7,11 +8,13 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const GlobalImageHotkeys = memo(() => {
useAssertSingleton('GlobalImageHotkeys');
const imageDTO = useAppSelector(selectLastSelectedImage);
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken);
if (!imageDTO) {
return null;

View File

@@ -19,6 +19,7 @@ import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
import { createStore } from 'app/store/store';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
@@ -236,7 +237,9 @@ const InvokeAIUI = ({
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
<AppDndContext>
<App config={config} studioInitAction={studioInitAction} />
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>
</Provider>

View File

@@ -17,7 +17,6 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
'config',
'dnd',
'events',
'gallery',
'generation',

View File

@@ -1,3 +1,4 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
/** @knipignore */
export const EMPTY_OBJECT = {};

View File

@@ -16,6 +16,7 @@ import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMi
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
import { addImageDroppedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
@@ -92,6 +93,9 @@ addGetOpenAPISchemaListener(startAppListening);
addWorkflowLoadRequestedListener(startAppListening);
addUpdateAllNodesRequestedListener(startAppListening);
// DND
addImageDroppedListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);

View File

@@ -1,12 +1,12 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult } as JsonObject, t('queue.graphQueued'));
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg } as JsonObject, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

View File

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -17,7 +17,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
effect: (action) => {
const enqueueResult = action.payload;
const arg = action.meta.arg.originalArgs;
log.debug({ enqueueResult } as JsonObject, 'Batch enqueued');
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
@@ -45,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
log.error({ batchConfig } as JsonObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
return;
}
@@ -71,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
log.error({ batchConfig, error: serializeError(response) } as JsonObject, t('queue.batchFailedToQueue'));
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
},
});
};

View File

@@ -1,22 +1,19 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert, AssertionError } from 'tsafe';
import type { JsonObject } from 'type-fest';
import { assert } from 'tsafe';
const log = logger('generation');
@@ -35,8 +32,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>;
noise: Invocation<'noise' | 'flux_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
},
Error
>;
@@ -52,9 +49,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `sd-3`:
buildGraphResult = await withResultAsync(() => buildSD3Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;
@@ -63,17 +57,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
}
if (buildGraphResult.isErr()) {
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status: 'error',
title: 'Failed to build graph',
description,
});
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
return;
}
@@ -104,7 +88,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as JsonObject, 'Enqueued batch');
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
},
});
};

View File

@@ -1,12 +1,12 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { JsonObject } from 'type-fest';
const log = logger('system');
@@ -16,12 +16,12 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
effect: (action, { getState }) => {
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) } as JsonObject, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates } as JsonObject, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},

View File

@@ -0,0 +1,333 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
entityRasterized,
entitySelected,
inpaintMaskAdded,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
activeData: TypesafeDraggableData;
}>('dnd/dndDropped');
const log = logger('system');
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: dndDropped,
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug({ activeData, overData }, 'Node field dropped');
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/**
* Image dropped on IP Adapter Layer
*/
if (
overData.actionType === 'SET_IPA_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
dispatch(
referenceImageIPAdapterImageChanged({
entityIdentifier: { id, type: 'reference_image' },
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on RG Layer IP Adapter
*/
if (
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id, referenceImageId } = overData.context;
dispatch(
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
referenceImageId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on Raster layer
*/
if (
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Inpaint Mask
*/
if (
overData.actionType === 'ADD_INPAINT_MASK_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasInpaintMaskState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Regional Guidance
*/
if (
overData.actionType === 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRegionalGuidanceState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Control Layer
*/
if (
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
controlAdapter: deepClone(initialControlNet),
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasReferenceImageState> = {
ipAdapter,
};
dispatch(referenceImageAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on layer to replace its content
*/
if (overData.actionType === 'REPLACE_LAYER_WITH_IMAGE' && activeData.payloadType === 'IMAGE_DTO') {
const state = getState();
const { entityIdentifier } = overData.context;
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}
/**
* Image dropped on node image field
*/
if (
overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image selected for compare
*/
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
return;
}
/**
* Image dropped on user board
*/
if (
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Image dropped on upscale initial image
*/
if (
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(upscaleInitialImageChanged(imageDTO));
return;
}
/**
* Multiple images dropped on user board
*/
if (overData.actionType === 'ADD_TO_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImagesToBoard.initiate({
imageDTOs,
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
/**
* Multiple images dropped on 'none' board
*/
if (overData.actionType === 'REMOVE_FROM_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
const imageDTOs = getState().gallery.selection;
dispatch(
imagesApi.endpoints.removeImagesFromBoard.initiate({
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},
});
};

View File

@@ -1,8 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import {
entityRasterized,
entitySelected,
referenceImageIPAdapterImageChanged,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
@@ -41,14 +51,12 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
log.debug({ imageDTO }, 'Image uploaded');
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
const { postUploadAction } = action.meta.arg.originalArgs;
if (!postUploadAction) {
return;
}
const boardId = imageDTO.board_id ?? 'none';
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
@@ -56,34 +64,80 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
} as const;
// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
if (postUploadAction.type === 'TOAST') {
const boardId = imageDTO.board_id ?? 'none';
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: postUploadAction.title || DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
return;
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
if (postUploadAction.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
...DEFAULT_UPLOADED_TOAST,
description: 'set as upscale initial image',
});
return;
}
if (postUploadAction.type === 'SET_IPA_IMAGE') {
const { id } = postUploadAction;
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, referenceImageId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
return;
}
if (postUploadAction.type === 'REPLACE_LAYER_WITH_IMAGE') {
const { entityIdentifier } = postUploadAction;
const state = getState();
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}
},
});

View File

@@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import {
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
@@ -40,7 +41,6 @@ import {
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
} from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('models');
@@ -85,7 +85,7 @@ type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JsonObject>
log: Logger<SerializableObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {

View File

@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -36,7 +37,6 @@ import undoable from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
@@ -139,7 +139,7 @@ const unserialize: UnserializeFunction = (data, key) => {
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
diff: diff(parsed, transformed) as SerializableObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);

View File

@@ -25,8 +25,7 @@ export type AppFeature =
| 'invocationCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken'
| 'invocationProgressAlert';
| 'hfToken';
/**
* A disable-able Stable Diffusion feature

View File

@@ -0,0 +1,251 @@
import type { ChakraProps, FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const baseStyles: SystemStyleObject = {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
};
const sx: SystemStyleObject = {
...baseStyles,
'.gallery-image-container::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
};
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
withMetadataOverlay?: boolean;
isDragDisabled?: boolean;
isDropDisabled?: boolean;
isUploadDisabled?: boolean;
minSize?: number;
postUploadAction?: PostUploadAction;
imageSx?: ChakraProps['sx'];
fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
dataTestId?: string;
};
const IAIDndImage = (props: IAIDndImageProps) => {
const {
imageDTO,
onError,
onClick,
withMetadataOverlay = false,
isDropDisabled = false,
isDragDisabled = false,
isUploadDisabled = false,
minSize = 24,
postUploadAction,
imageSx,
fitContainer = false,
droppableData,
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
useThumbailFallback,
withHoverOverlay = false,
children,
dataTestId,
...rest
} = props;
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
return;
}
if (e.button !== 1) {
return;
}
window.open(imageDTO.image_url, '_blank');
},
[imageDTO]
);
const ref = useRef<HTMLDivElement>(null);
useImageContextMenu(imageDTO, ref);
return (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
);
};
export default memo(IAIDndImage);
const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);
UploadButton.displayName = 'UploadButton';

View File

@@ -21,7 +21,7 @@ type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
tooltip: string;
};
export const DndImageIcon = memo((props: Props) => {
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, ...rest } = props;
return (
@@ -35,6 +35,6 @@ export const DndImageIcon = memo((props: Props) => {
{...rest}
/>
);
});
};
DndImageIcon.displayName = 'DndImageIcon';
export default memo(IAIDndImageIcon);

View File

@@ -0,0 +1,38 @@
import type { BoxProps } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
{...attributes}
{...listeners}
{...rest}
/>
);
};
export default memo(IAIDraggable);

View File

@@ -0,0 +1,64 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
type Props = {
isOver: boolean;
label?: string;
withBackdrop?: boolean;
};
const IAIDropOverlay = (props: Props) => {
const { isOver, label, withBackdrop = true } = props;
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
<Flex
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
bg={withBackdrop ? 'base.900' : 'transparent'}
opacity={0.7}
borderRadius="base"
alignItems="center"
justifyContent="center"
transitionProperty="common"
transitionDuration="0.1s"
/>
<Flex
position="absolute"
top={0.5}
right={0.5}
bottom={0.5}
left={0.5}
opacity={1}
borderWidth={1.5}
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
>
{label && (
<Text
fontSize="lg"
fontWeight="semibold"
color={isOver ? 'invokeYellow.300' : 'base.500'}
transitionProperty="common"
transitionDuration="0.1s"
textAlign="center"
>
{label}
</Text>
)}
</Flex>
</Flex>
);
};
export default memo(IAIDropOverlay);

View File

@@ -0,0 +1,46 @@
import { Box } from '@invoke-ai/ui-library';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import type { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
type IAIDroppableProps = {
dropLabel?: string;
disabled?: boolean;
data?: TypesafeDroppableData;
};
const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);
};
export default memo(IAIDroppable);
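As a quick orientation, a hedged usage sketch for the new IAIDroppable: it renders absolutely positioned to fill its parent, so the host must establish a positioning context. The drop data type and action appear later in this diff; the id is illustrative:

import { Flex } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type { AddRasterLayerFromImageDropData } from 'features/dnd/types';

const dropData: AddRasterLayerFromImageDropData = {
  id: 'example-raster-drop',
  actionType: 'ADD_RASTER_LAYER_FROM_IMAGE',
};

// The droppable fills its parent absolutely, so the host must be position="relative".
const ExampleDropTarget = () => (
  <Flex position="relative" w={64} h={64}>
    <IAIDroppable data={dropData} dropLabel="Drop image here" />
  </Flex>
);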

View File

@@ -0,0 +1,24 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Skeleton } from '@invoke-ai/ui-library';
import { memo } from 'react';
const skeletonStyles: SystemStyleObject = {
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
};
const IAIFillSkeleton = () => {
return (
<Skeleton sx={skeletonStyles}>
<Box position="absolute" top={0} insetInlineStart={0} height="full" width="full" />
</Skeleton>
);
};
export default memo(IAIFillSkeleton);

View File

@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
type Props = { image: ImageDTO | undefined };
const IAILoadingImageFallback = memo((props: Props) => {
export const IAILoadingImageFallback = memo((props: Props) => {
if (props.image) {
return (
<Skeleton

View File

@@ -0,0 +1,28 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
type ImageMetadataOverlayProps = {
imageDTO: ImageDTO;
};
const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
return (
<Flex
pointerEvents="none"
flexDirection="column"
position="absolute"
top={0}
insetInlineStart={0}
p={2}
alignItems="flex-start"
gap={2}
>
<Badge variant="solid" colorScheme="base">
{imageDTO.width} × {imageDTO.height}
</Badge>
</Flex>
);
};
export default memo(ImageMetadataOverlay);

View File

@@ -0,0 +1,89 @@
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { memo } from 'react';
import type { DropzoneState } from 'react-dropzone';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useBoardName } from 'services/api/hooks/useBoardName';
type ImageUploadOverlayProps = {
dropzone: DropzoneState;
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
};
const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const { dropzone, setIsHandlingUpload } = props;
useHotkeys(
'esc',
() => {
setIsHandlingUpload(false);
},
[setIsHandlingUpload]
);
return (
<Box position="absolute" top={0} right={0} bottom={0} left={0} zIndex={999} backdropFilter="blur(20px)">
<Flex position="absolute" top={0} right={0} bottom={0} left={0} bg="base.900" opacity={0.7} />
<Flex
position="absolute"
flexDir="column"
gap={4}
top={2}
right={2}
bottom={2}
left={2}
opacity={1}
borderWidth={2}
borderColor={dropzone.isDragAccept ? 'invokeYellow.300' : 'error.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
color={dropzone.isDragReject ? 'error.300' : undefined}
>
{dropzone.isDragAccept && <DragAcceptMessage />}
{!dropzone.isDragAccept && <DragRejectMessage />}
</Flex>
</Box>
);
};
export default memo(ImageUploadOverlay);
const DragAcceptMessage = () => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector(selectSelectedBoardId);
const boardName = useBoardName(selectedBoardId);
return (
<>
<Heading size="lg">{t('gallery.dropToUpload')}</Heading>
<Heading size="md">{t('toast.imagesWillBeAddedTo', { boardName })}</Heading>
</>
);
};
const DragRejectMessage = () => {
const { t } = useTranslation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
if (maxImageUploadCount === undefined) {
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc')}</Heading>
</>
);
}
return (
<>
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount })}</Heading>
</>
);
};

View File

@@ -1,13 +1,9 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { autoScrollForElements } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/element';
import { autoScrollForExternal } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/external';
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import type { OverlayScrollbarsComponentRef } from 'overlayscrollbars-react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { memo, useMemo } from 'react';
type Props = PropsWithChildren & {
maxHeight?: ChakraProps['maxHeight'];
@@ -15,38 +11,17 @@ type Props = PropsWithChildren & {
overflowY?: 'hidden' | 'scroll';
};
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
const styles: CSSProperties = { height: '100%', width: '100%' };
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const overlayscrollbarsOptions = useMemo(
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
[overflowX, overflowY]
);
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);
useEffect(() => {
const osInstance = os?.osInstance();
if (!osInstance) {
return;
}
const element = osInstance.elements().viewport;
// `pragmatic-drag-and-drop-auto-scroll` requires the element to have `overflow-y: scroll` or `overflow-y: auto`,
// else it logs an ugly warning. Because we use a custom scrollbar library, it is 'hidden' by default.
// To prevent the erroneous warning, we temporarily set overflow-y to 'scroll' and then revert it.
const overflowY = element.style.overflowY; // starts 'hidden'
element.style.setProperty('overflow-y', 'scroll', 'important');
const cleanup = combine(autoScrollForElements({ element }), autoScrollForExternal({ element }));
element.style.setProperty('overflow-y', overflowY);
return cleanup;
}, [os]);
return (
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
<OverlayScrollbarsComponent defer style={styles} options={overlayscrollbarsOptions}>
{children}
</OverlayScrollbarsComponent>
</Box>

View File

@@ -0,0 +1,124 @@
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
const log = logger('gallery');
const accept: Accept = {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg'],
};
export const useFullscreenDropzone = () => {
useAssertSingleton('useFullscreenDropzone');
const { t } = useTranslation();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const [uploadImage] = useUploadImageMutation();
const activeTabName = useAppSelector(selectActiveTab);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const getPostUploadAction = useCallback((): PostUploadAction => {
if (activeTabName === 'upscaling') {
return { type: 'SET_UPSCALE_INITIAL_IMAGE' };
} else {
return { type: 'TOAST' };
}
}, [activeTabName]);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 0) {
const errors = fileRejections.map((rejection) => ({
errors: rejection.errors.map(({ message }) => message),
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
setIsHandlingUpload(false);
return;
}
for (const [i, file] of acceptedFiles.entries()) {
uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: getPostUploadAction(),
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// The `imageUploaded` listener does some extra logic, like switching to the asset view on upload on the
// first upload of a "batch".
isFirstUploadOfBatch: i === 0,
});
}
setIsHandlingUpload(false);
},
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId]
);
const onDragOver = useCallback(() => {
setIsHandlingUpload(true);
}, []);
const onDragLeave = useCallback(() => {
setIsHandlingUpload(false);
}, []);
const dropzone = useDropzone({
accept,
noClick: true,
onDrop,
onDragOver,
onDragLeave,
noKeyboard: true,
multiple: maxImageUploadCount === undefined || maxImageUploadCount > 1,
maxFiles: maxImageUploadCount,
});
useEffect(() => {
// This is a hack to allow pasting images into the uploader
const handlePaste = (e: ClipboardEvent) => {
if (!dropzone.inputRef.current) {
return;
}
if (e.clipboardData?.files) {
// Set the files on the dropzone.inputRef
dropzone.inputRef.current.files = e.clipboardData.files;
// Dispatch the change event, dropzone catches this and we get to use its own validation
dropzone.inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
}
};
// Add the paste event listener
document.addEventListener('paste', handlePaste);
return () => {
document.removeEventListener('paste', handlePaste);
};
}, [dropzone.inputRef]);
return { dropzone, isHandlingUpload, setIsHandlingUpload };
};
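The paste handler above leans on a useful quirk: a FileList from the clipboard can be assigned to a file input, and a synthetic bubbling change event then routes the files through react-dropzone's normal validation. The same pattern, distilled into a standalone sketch (the function name is mine):

const forwardPasteToInput = (e: ClipboardEvent, input: HTMLInputElement | null) => {
  if (!input || !e.clipboardData?.files.length) {
    return;
  }
  // A FileList is directly assignable to HTMLInputElement.files.
  input.files = e.clipboardData.files;
  // The bubbling change event lets the dropzone apply its accept/maxFiles validation.
  input.dispatchEvent(new Event('change', { bubbles: true }));
};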

View File

@@ -1,5 +1,3 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
@@ -9,23 +7,14 @@ import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadBold } from 'react-icons/pi';
import { uploadImages, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
};
type UseImageUploadButtonArgs = {
postUploadAction?: PostUploadAction;
isDisabled?: boolean;
allowMultiple?: boolean;
};
const log = logger('gallery');
@@ -48,53 +37,30 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({
postUploadAction,
isDisabled,
allowMultiple = false,
}: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage, request] = useUploadImageMutation();
const [uploadImage] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const { t } = useTranslation();
const onDropAccepted = useCallback(
async (files: File[]) => {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
(files: File[]) => {
for (const [i, file] of files.entries()) {
uploadImage({
file,
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
const imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {
onUpload(imageDTOs);
}
isFirstUploadOfBatch: i === 0,
});
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
[autoAddBoardId, postUploadAction, uploadImage]
);
const onDropRejected = useCallback(
@@ -137,42 +103,5 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
maxFiles: maxImageUploadCount,
});
return { getUploadButtonProps, getUploadInputProps, openUploader, request };
};
const sx = {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 0,
borderRadius: 'base',
'&[data-error=true]': {
borderWidth: 1,
},
} satisfies SystemStyleObject;
export const UploadImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTO: ImageDTO) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="ghost"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
return { getUploadButtonProps, getUploadInputProps, openUploader };
};
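Tying the reworked hook together, a minimal caller under the post-change API. The postUploadAction value mirrors the default applied above; Button is assumed to come from @invoke-ai/ui-library, as elsewhere in the codebase:

import { Button } from '@invoke-ai/ui-library';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';

const ExampleUploadButton = () => {
  const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
    postUploadAction: { type: 'TOAST' },
    allowMultiple: true,
  });
  return (
    <>
      <Button {...getUploadButtonProps()}>Upload images</Button>
      <input {...getUploadInputProps()} />
    </>
  );
};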

View File

@@ -119,20 +119,11 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
}
}
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
// When we are using an unsupported model, do not add the other warnings
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
} else {
// Using a compatible model, add all warnings
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
} else {
if (canvasIsFiltering) {

View File

@@ -0,0 +1,14 @@
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
export const useNanoid = (prefix?: string) => {
const id = useMemo(() => {
if (prefix) {
return getPrefixedId(prefix);
} else {
return nanoid();
}
}, [prefix]);
return id;
};
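A short usage sketch for useNanoid: the id is memoized on the prefix, so it stays stable across re-renders of one component instance. The exact prefixed format is produced by getPrefixedId and is assumed here:

import { Box } from '@invoke-ai/ui-library';
import { useNanoid } from 'common/hooks/useNanoid';

const ExampleComponent = () => {
  // Stable for the lifetime of this instance; regenerated only if the prefix changes.
  const dndId = useNanoid('ip_adapter_image_preview');
  return <Box data-dnd-id={dndId} />;
};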

View File

@@ -0,0 +1,12 @@
type SerializableValue =
| string
| number
| boolean
| null
| undefined
| SerializableValue[]
| readonly SerializableValue[]
| SerializableObject;
export type SerializableObject = {
[k: string | number]: SerializableValue;
};
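To make the shape concrete, a small sketch of what the type admits and rejects (the literals are illustrative):

// OK: nested plain data, arrays, and null/undefined all satisfy SerializableValue.
const ok: SerializableObject = {
  name: 'layer-1',
  opacity: 0.75,
  tags: ['raster', 'visible'],
  parent: null,
};

// Compile error: functions are not SerializableValue.
// const bad: SerializableObject = { onClick: () => {} };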

View File

@@ -1,6 +0,0 @@
import type { AssertionError } from 'tsafe';
export function extractMessageFromAssertionError(error: AssertionError): string | null {
const match = error.message.match(/Wrong assertion encountered: "(.*)"/);
return match ? (match[1] ?? null) : null;
}

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -23,7 +23,6 @@ export const CanvasAddEntityButtons = memo(() => {
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -37,7 +36,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isSD3}
>
{t('controlLayers.globalReferenceImage')}
</Button>
@@ -63,7 +61,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX || isSD3}
isDisabled={isFLUX}
>
{t('controlLayers.regionalGuidance')}
</Button>
@@ -75,7 +73,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX || isSD3}
isDisabled={isFLUX}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
@@ -90,7 +88,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isSD3}
>
{t('controlLayers.controlLayer')}
</Button>

View File

@@ -1,45 +0,0 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';
const CanvasAlertsInvocationProgressContent = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);
if (!invocationProgressMessage) {
return null;
}
return (
<Alert status="loading" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('common.generating')}</AlertTitle>
<AlertDescription>{invocationProgressMessage}</AlertDescription>
</Alert>
);
});
CanvasAlertsInvocationProgressContent.displayName = 'CanvasAlertsInvocationProgressContent';
export const CanvasAlertsInvocationProgress = memo(() => {
const isProgressMessageAlertEnabled = useFeatureStatus('invocationProgressAlert');
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
// The alert is disabled at the system level
if (!isProgressMessageAlertEnabled) {
return null;
}
// The alert is disabled at the user level
if (!shouldShowInvocationProgressDetail) {
return null;
}
return <CanvasAlertsInvocationProgressContent />;
});
CanvasAlertsInvocationProgress.displayName = 'CanvasAlertsInvocationProgress';

View File

@@ -1,26 +1,38 @@
import { Grid, GridItem } from '@invoke-ai/ui-library';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { newCanvasEntityFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import IAIDroppable from 'common/components/IAIDroppable';
import type {
AddControlLayerFromImageDropData,
AddGlobalReferenceImageFromImageDropData,
AddRasterLayerFromImageDropData,
AddRegionalReferenceImageFromImageDropData,
} from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const addRasterLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({ type: 'raster_layer' });
const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
});
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'reference_image',
});
const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
id: 'add-raster-layer-from-image-drop-data',
actionType: 'ADD_RASTER_LAYER_FROM_IMAGE',
};
const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
};
const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
id: 'add-regional-reference-image-from-image-drop-data',
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
id: 'add-global-reference-image-from-image-drop-data',
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
};
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
if (imageViewer.isOpen) {
return null;
@@ -39,36 +51,28 @@ export const CanvasDropArea = memo(() => {
pointerEvents="none"
>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRasterLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRasterLayer')}
isDisabled={isBusy}
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRasterLayer')}
data={addRasterLayerFromImageDropData}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newControlLayer')}
isDisabled={isBusy}
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newControlLayer')}
data={addControlLayerFromImageDropData}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRegionalGuidanceReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
isDisabled={isBusy}
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
data={addRegionalReferenceImageFromImageDropData}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
isDisabled={isBusy}
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
data={addGlobalReferenceImageFromImageDropData}
/>
</GridItem>
</Grid>

View File

@@ -1,59 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasEntityListDnd } from 'features/controlLayers/components/CanvasEntityList/useCanvasEntityListDnd';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
import { DndListDropIndicator } from 'features/dnd/DndListDropIndicator';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useRef } from 'react';
const sx = {
position: 'relative',
flexDir: 'column',
w: 'full',
bg: 'base.850',
borderRadius: 'base',
'&[data-selected=true]': {
bg: 'base.800',
},
'&[data-is-dragging=true]': {
opacity: 0.3,
},
transitionProperty: 'common',
} satisfies SystemStyleObject;
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const isSelected = useEntityIsSelected(entityIdentifier);
const onClick = useCallback(() => {
if (isSelected) {
return;
}
dispatch(entitySelected({ entityIdentifier }));
}, [dispatch, entityIdentifier, isSelected]);
const ref = useRef<HTMLDivElement>(null);
const [dndListState, isDragging] = useCanvasEntityListDnd(ref, entityIdentifier);
return (
<Box position="relative">
<Flex
// This is used to trigger the post-move flash animation
data-entity-id={entityIdentifier.id}
data-selected={isSelected}
data-is-dragging={isDragging}
ref={ref}
onClick={onClick}
sx={sx}
>
{props.children}
</Flex>
<DndListDropIndicator dndState={dndListState} />
</Box>
);
});
CanvasEntityContainer.displayName = 'CanvasEntityContainer';

View File

@@ -1,181 +0,0 @@
import { monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import { reorderWithEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/util/reorder-with-edge';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { triggerPostMoveFlash } from 'features/dnd/util';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { flushSync } from 'react-dom';
import { PiCaretDownBold } from 'react-icons/pi';
type Props = PropsWithChildren<{
isSelected: boolean;
type: CanvasEntityIdentifier['type'];
entityIdentifiers: CanvasEntityIdentifier[];
}>;
export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityIdentifiers }: Props) => {
const title = useEntityTypeTitle(type);
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
const collapse = useBoolean(true);
const dispatch = useAppDispatch();
useEffect(() => {
return monitorForElements({
canMonitor({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== type) {
return false;
}
return true;
},
onDrop({ location, source }) {
const target = location.current.dropTargets[0];
if (!target) {
return;
}
const sourceData = source.data;
const targetData = target.data;
if (!singleCanvasEntityDndSource.typeGuard(sourceData) || !singleCanvasEntityDndSource.typeGuard(targetData)) {
return;
}
const indexOfSource = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === sourceData.payload.entityIdentifier.id
);
const indexOfTarget = entityIdentifiers.findIndex(
(entityIdentifier) => entityIdentifier.id === targetData.payload.entityIdentifier.id
);
if (indexOfTarget < 0 || indexOfSource < 0) {
return;
}
// Don't move if the source and target are the same index, meaning same position in the list
if (indexOfSource === indexOfTarget) {
return;
}
const closestEdgeOfTarget = extractClosestEdge(targetData);
// It's possible that the indices are different, but refer to the same position. For example, if the source is
// at 2 and the target is at 3, but the target edge is 'top', then the entity is already in the correct position.
// We should bail if this is the case.
let edgeIndexDelta = 0;
if (closestEdgeOfTarget === 'bottom') {
edgeIndexDelta = 1;
} else if (closestEdgeOfTarget === 'top') {
edgeIndexDelta = -1;
}
// If the source is already in the correct position, we don't need to move it.
if (indexOfSource === indexOfTarget + edgeIndexDelta) {
return;
}
// Using `flushSync` so we can query the DOM straight after this line
flushSync(() => {
dispatch(
entitiesReordered({
type,
entityIdentifiers: reorderWithEdge({
list: entityIdentifiers,
startIndex: indexOfSource,
indexOfTarget,
closestEdgeOfTarget,
axis: 'vertical',
}),
})
);
});
// Flash the element that was moved
const element = document.querySelector(`[data-entity-id="${sourceData.payload.entityIdentifier.id}"]`);
if (element instanceof HTMLElement) {
triggerPostMoveFlash(element, colorTokenToCssVar('base.700'));
}
},
});
}, [dispatch, entityIdentifiers, type]);
return (
<Flex flexDir="column" w="full">
<Flex w="full">
<Flex
flexGrow={1}
as={Button}
onClick={collapse.toggle}
justifyContent="space-between"
alignItems="center"
gap={3}
variant="unstyled"
p={0}
h={8}
>
<Icon
boxSize={4}
as={PiCaretDownBold}
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
fill={isSelected ? 'base.200' : 'base.500'}
transitionProperty="common"
transitionDuration="fast"
/>
{informationalPopoverFeature ? (
<InformationalPopover feature={informationalPopoverFeature}>
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
</InformationalPopover>
) : (
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
)}
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
<Flex flexDir="column" gap={2} pt={2}>
{children}
</Flex>
</Collapse>
</Flex>
);
});
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';
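The bail-out arithmetic in onDrop above is subtle, so here it is distilled into a pure function (the name is mine, not from this diff): dropping on the top edge of the item just below the source, or the bottom edge of the item just above it, leaves the list unchanged.

type Edge = 'top' | 'bottom';

// Returns true when the drop would leave the list in the same order.
const isNoOpMove = (indexOfSource: number, indexOfTarget: number, closestEdgeOfTarget: Edge | null): boolean => {
  if (indexOfSource === indexOfTarget) {
    return true; // dropped onto itself
  }
  const edgeIndexDelta = closestEdgeOfTarget === 'bottom' ? 1 : closestEdgeOfTarget === 'top' ? -1 : 0;
  // e.g. source at 2, target at 3 with edge 'top': the source already occupies that slot.
  return indexOfSource === indexOfTarget + edgeIndexDelta;
};

// isNoOpMove(2, 3, 'top') === true; isNoOpMove(2, 3, 'bottom') === false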

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -24,7 +24,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Menu>
@@ -41,7 +40,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
@@ -49,15 +48,15 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -1,85 +0,0 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { type DndListTargetState, idle } from 'features/dnd/types';
import { firefoxDndFix } from 'features/dnd/util';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';
export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdentifier: CanvasEntityIdentifier) => {
const [dndListState, setDndListState] = useState<DndListTargetState>(idle);
const [isDragging, setIsDragging] = useState(false);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData() {
return singleCanvasEntityDndSource.getData({ entityIdentifier });
},
onDragStart() {
setDndListState({ type: 'is-dragging' });
setIsDragging(true);
},
onDrop() {
setDndListState(idle);
setIsDragging(false);
},
}),
dropTargetForElements({
element,
canDrop({ source }) {
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
return false;
}
if (source.data.payload.entityIdentifier.type !== entityIdentifier.type) {
return false;
}
return true;
},
getData({ input }) {
const data = singleCanvasEntityDndSource.getData({ entityIdentifier });
return attachClosestEdge(data, {
element,
input,
allowedEdges: ['top', 'bottom'],
});
},
getIsSticky() {
return true;
},
onDragEnter({ self }) {
const closestEdge = extractClosestEdge(self.data);
setDndListState({ type: 'is-dragging-over', closestEdge });
},
onDrag({ self }) {
const closestEdge = extractClosestEdge(self.data);
// Only update react state if something has changed.
// Prevents unnecessary re-renders.
setDndListState((current) => {
if (current.type === 'is-dragging-over' && current.closestEdge === closestEdge) {
return current;
}
return { type: 'is-dragging-over', closestEdge };
});
},
onDragLeave() {
setDndListState(idle);
},
onDrop() {
setDndListState(idle);
},
})
);
}, [entityIdentifier, ref]);
return [dndListState, isDragging] as const;
};

View File

@@ -21,8 +21,6 @@ import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageV
import { memo, useCallback, useRef } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
const MenuContent = () => {
return (
<CanvasManagerProviderGate>
@@ -86,7 +84,6 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
@@ -112,9 +109,7 @@ export const CanvasMainPanelContent = memo(() => {
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
<CanvasDropArea />
<GatedImageViewer />
</Flex>
);

View File

@@ -1,20 +1,16 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { dropTargetForExternal, monitorForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
import { useDndContext } from '@dnd-kit/core';
import { Box, Button, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { CanvasLayersPanelContent } from 'features/controlLayers/components/CanvasLayersPanelContent';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectEntityCountActive } from 'features/controlLayers/store/selectors';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasRightPanel = memo(() => {
@@ -83,13 +79,37 @@ CanvasRightPanel.displayName = 'CanvasRightPanel';
const PanelTabs = memo(() => {
const { t } = useTranslation();
const store = useAppStore();
const activeTab = useAppSelector(selectActiveTabCanvasRightPanel);
const activeEntityCount = useAppSelector(selectEntityCountActive);
const [layersTabDndState, setLayersTabDndState] = useState<DndTargetState>('idle');
const [galleryTabDndState, setGalleryTabDndState] = useState<DndTargetState>('idle');
const layersTabRef = useRef<HTMLDivElement>(null);
const galleryTabRef = useRef<HTMLDivElement>(null);
const timeoutRef = useRef<number | null>(null);
const tabTimeout = useRef<number | null>(null);
const dndCtx = useDndContext();
const dispatch = useAppDispatch();
const [mouseOverTab, setMouseOverTab] = useState<'layers' | 'gallery' | null>(null);
const onMouseOverLayersTab = useCallback(() => {
setMouseOverTab('layers');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('layers'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onMouseOverGalleryTab = useCallback(() => {
setMouseOverTab('gallery');
tabTimeout.current = window.setTimeout(() => {
if (dndCtx.active) {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}
}, 300);
}, [dndCtx.active, dispatch]);
const onMouseOut = useCallback(() => {
setMouseOverTab(null);
if (tabTimeout.current) {
clearTimeout(tabTimeout.current);
}
}, []);
const layersTabLabel = useMemo(() => {
if (activeEntityCount === 0) {
@@ -98,172 +118,23 @@ const PanelTabs = memo(() => {
return `${t('controlLayers.layer_other')} (${activeEntityCount})`;
}, [activeEntityCount, t]);
useEffect(() => {
if (!layersTabRef.current) {
return;
}
const getIsOnLayersTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'layers';
const onDragEnter = () => {
// If we are already on the layers tab, do nothing
if (getIsOnLayersTab()) {
return;
}
// Else set the state to active and switch to the layers tab after a timeout
setLayersTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('layers'));
// When we switch tabs, the other tab should be pending
setLayersTabDndState('idle');
setGalleryTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or pending depending on the current tab
if (getIsOnLayersTab()) {
setLayersTabDndState('idle');
} else {
setLayersTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to pending when a drag starts
setLayersTabDndState('potential');
};
return combine(
dropTargetForElements({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the layers tab
return !getIsOnLayersTab();
},
onDragStart,
}),
dropTargetForExternal({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnLayersTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
if (!galleryTabRef.current) {
return;
}
const getIsOnGalleryTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'gallery';
const onDragEnter = () => {
// If we are already on the gallery tab, do nothing
if (getIsOnGalleryTab()) {
return;
}
// Else set the state to active and switch to the gallery tab after a timeout
setGalleryTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
// When we switch tabs, the other tab should be pending
setGalleryTabDndState('idle');
setLayersTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or pending depending on the current tab
if (getIsOnGalleryTab()) {
setGalleryTabDndState('idle');
} else {
setGalleryTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to pending when a drag starts
setGalleryTabDndState('potential');
};
return combine(
dropTargetForElements({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the gallery tab
return !getIsOnGalleryTab();
},
onDragStart,
}),
dropTargetForExternal({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnGalleryTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
const onDrop = () => {
// Reset the dnd state when a drop happens
setGalleryTabDndState('idle');
setLayersTabDndState('idle');
};
const cleanup = combine(monitorForElements({ onDrop }), monitorForExternal({ onDrop }));
return () => {
cleanup();
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
}, []);
return (
<>
<Tab ref={layersTabRef} position="relative" w={32}>
<Tab position="relative" onMouseOver={onOnMouseOverLayersTab} onMouseOut={onMouseOut} w={32}>
<Box as="span" w="full">
{layersTabLabel}
</Box>
<DndDropOverlay dndState={layersTabDndState} withBackdrop={false} />
{dndCtx.active && activeTab !== 'layers' && (
<IAIDropOverlay isOver={mouseOverTab === 'layers'} withBackdrop={false} />
)}
</Tab>
<Tab ref={galleryTabRef} position="relative" w={32}>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut} w={32}>
<Box as="span" w="full">
{t('gallery.gallery')}
</Box>
<DndDropOverlay dndState={galleryTabDndState} withBackdrop={false} />
{dndCtx.active && activeTab !== 'gallery' && (
<IAIDropOverlay isOver={mouseOverTab === 'gallery'} withBackdrop={false} />
)}
</Tab>
</>
);
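The two handlers above repeat the same delayed-switch pattern: while a drag is active, hovering a tab for 300 ms switches to it, and mousing out cancels the pending switch. A generic sketch of that pattern (the hook name and shape are mine):

import { useCallback, useRef } from 'react';

const useDelayedSwitchOnHover = (isDragActive: boolean, onSwitch: () => void, delayMs = 300) => {
  const timeout = useRef<number | null>(null);
  const onMouseOver = useCallback(() => {
    timeout.current = window.setTimeout(() => {
      // Only switch if a drag is still in flight when the delay elapses.
      if (isDragActive) {
        onSwitch();
      }
    }, delayMs);
  }, [isDragActive, onSwitch, delayMs]);
  const onMouseOut = useCallback(() => {
    // Cancel a pending switch when the pointer leaves the tab.
    if (timeout.current !== null) {
      clearTimeout(timeout.current);
      timeout.current = null;
    }
  }, []);
  return { onMouseOver, onMouseOut };
};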

View File

@@ -1,5 +1,6 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import IAIDroppable from 'common/components/IAIDroppable';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
@@ -9,11 +10,8 @@ import { ControlLayerBadges } from 'features/controlLayers/components/ControlLay
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -23,16 +21,14 @@ type Props = {
export const ControlLayer = memo(({ id }: Props) => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useMemo<CanvasEntityIdentifier<'control_layer'>>(
() => ({ id, type: 'control_layer' }),
[id]
);
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
[entityIdentifier]
const dropData = useMemo<ReplaceLayerImageDropData>(
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
[id, entityIdentifier]
);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<ControlLayerAdapterGate>
@@ -47,12 +43,7 @@ export const ControlLayer = memo(({ id }: Props) => {
<CanvasEntitySettingsWrapper>
<ControlLayerSettings />
</CanvasEntitySettingsWrapper>
<DndDropTarget
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.replaceLayer')}
isDisabled={isBusy}
/>
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
</CanvasEntityContainer>
</ControlLayerAdapterGate>
</EntityIdentifierContext.Provider>

View File

@@ -1,7 +1,6 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
@@ -22,11 +21,10 @@ import { getFilterForModel } from 'features/controlLayers/store/filters';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
@@ -43,7 +41,7 @@ const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<
export const ControlLayerControlAdapter = memo(() => {
const { t } = useTranslation();
const { dispatch, getState } = useAppStore();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
@@ -115,17 +113,11 @@ export const ControlLayerControlAdapter = memo(() => {
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const isBusy = useCanvasIsBusy();
const uploadOptions = useMemo(
() =>
({
onUpload: (imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ entityIdentifier, imageDTO, dispatch, getState });
},
allowMultiple: false,
}) as const,
[dispatch, entityIdentifier, getState]
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
);
const uploadApi = useImageUploadButton(uploadOptions);
const uploadApi = useImageUploadButton({ postUploadAction });
return (
<Flex flexDir="column" gap={3} position="relative" w="full">

View File

@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(getEntityIdentifier).toReversed();
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.entities.map(mapId).reverse();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
export const ControlLayerEntityList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
const layerIds = useAppSelector(selectEntityIds);
if (entityIdentifiers.length === 0) {
if (layerIds.length === 0) {
return null;
}
if (entityIdentifiers.length > 0) {
if (layerIds.length > 0) {
return (
<CanvasEntityGroupList type="control_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifier) => (
<ControlLayer key={entityIdentifier.id} id={entityIdentifier.id} />
<CanvasEntityGroupList type="control_layer" isSelected={isSelected}>
{layerIds.map((id) => (
<ControlLayer key={id} id={id} />
))}
</CanvasEntityGroupList>
);

View File

@@ -1,25 +1,22 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
import type { PostUploadAction } from 'services/api/types';
export const ControlLayerSettingsEmptyState = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const { dispatch, getState } = useAppStore();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
replaceCanvasEntityObjectsWithImage({ imageDTO, entityIdentifier, dispatch, getState });
},
[dispatch, entityIdentifier, getState]
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const uploadApi = useImageUploadButton({ postUploadAction });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);

View File

@@ -1,5 +1,5 @@
import { Spacer } from '@invoke-ai/ui-library';
-import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
+import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';

View File

@@ -1,80 +1,82 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
-import { UploadImageButton } from 'common/hooks/useImageUploadButton';
+import IAIDndImage from 'common/components/IAIDndImage';
+import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
+import { useNanoid } from 'common/hooks/useNanoid';
import type { ImageWithDims } from 'features/controlLayers/store/types';
-import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
-import { DndDropTarget } from 'features/dnd/DndDropTarget';
-import { DndImage } from 'features/dnd/DndImage';
-import { DndImageIcon } from 'features/dnd/DndImageIcon';
-import { memo, useCallback, useEffect } from 'react';
+import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
+import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
-import type { ImageDTO } from 'services/api/types';
+import type { ImageDTO, PostUploadAction } from 'services/api/types';
import { $isConnected } from 'services/events/stores';

-type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
+type Props = {
  image: ImageWithDims | null;
  onChangeImage: (imageDTO: ImageDTO | null) => void;
-  dndTarget: T;
-  dndTargetData: ReturnType<T['getData']>;
+  droppableData: TypesafeDroppableData;
+  postUploadAction: PostUploadAction;
};

-export const IPAdapterImagePreview = memo(
-  <T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget>({
-    image,
-    onChangeImage,
-    dndTarget,
-    dndTargetData,
-  }: Props<T>) => {
-    const { t } = useTranslation();
-    const isConnected = useStore($isConnected);
-    const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
-    const handleResetControlImage = useCallback(() => {
-      onChangeImage(null);
-    }, [onChangeImage]);
+export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
+  const { t } = useTranslation();
+  const isConnected = useStore($isConnected);
+  const dndId = useNanoid('ip_adapter_image_preview');

-    useEffect(() => {
-      if (isConnected && isError) {
-        handleResetControlImage();
-      }
-    }, [handleResetControlImage, isError, isConnected]);
+  const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
+    image?.image_name ?? skipToken
+  );
+  const handleResetControlImage = useCallback(() => {
+    onChangeImage(null);
+  }, [onChangeImage]);

-    const onUpload = useCallback(
-      (imageDTO: ImageDTO) => {
-        onChangeImage(imageDTO);
-      },
-      [onChangeImage]
-    );
+  const draggableData = useMemo<ImageDraggableData | undefined>(() => {
+    if (controlImage) {
+      return {
+        id: dndId,
+        payloadType: 'IMAGE_DTO',
+        payload: { imageDTO: controlImage },
+      };
+    }
+  }, [controlImage, dndId]);

-    return (
-      <Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
-        {!imageDTO && (
-          <UploadImageButton
-            w="full"
-            h="full"
-            isError={!imageDTO && !image?.image_name}
-            onUpload={onUpload}
-            fontSize={36}
-          />
-        )}
-        {imageDTO && (
-          <>
-            <DndImage imageDTO={imageDTO} />
-            <Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
-              <DndImageIcon
-                onClick={handleResetControlImage}
-                icon={<PiArrowCounterClockwiseBold size={16} />}
-                tooltip={t('common.reset')}
-              />
-            </Flex>
-          </>
-        )}
-        <DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
-      </Flex>
-    );
-  }
-);
+  useEffect(() => {
+    if (isConnected && isErrorControlImage) {
+      handleResetControlImage();
+    }
+  }, [handleResetControlImage, isConnected, isErrorControlImage]);

+  return (
+    <Flex
+      position="relative"
+      w="full"
+      h="full"
+      alignItems="center"
+      borderColor="error.500"
+      borderStyle="solid"
+      borderWidth={controlImage ? 0 : 1}
+      borderRadius="base"
+    >
+      <IAIDndImage
+        draggableData={draggableData}
+        droppableData={droppableData}
+        imageDTO={controlImage}
+        postUploadAction={postUploadAction}
+      />
+      {controlImage && (
+        <Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
+          <IAIDndImageIcon
+            onClick={handleResetControlImage}
+            icon={<PiArrowCounterClockwiseBold size={16} />}
+            tooltip={t('common.reset')}
+          />
+        </Flex>
+      )}
+    </Flex>
+  );
+});

IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
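
One pattern worth noting in the rewritten preview above, sketched with assumed types (useImageDraggableData is a hypothetical helper, not part of the diff): the drag payload is memoized and stays undefined while no image is set, so the DnD wrapper renders as a pure drop/upload target until there is something to drag.

import { useMemo } from 'react';

// Assumed shapes mirroring the names in the hunk above.
type ImageDTO = { image_name: string };
type ImageDraggableData = { id: string; payloadType: 'IMAGE_DTO'; payload: { imageDTO: ImageDTO } };

// Hypothetical helper extracted from the component for illustration.
const useImageDraggableData = (
  dndId: string,
  imageDTO: ImageDTO | null | undefined
): ImageDraggableData | undefined =>
  useMemo(() => {
    if (!imageDTO) {
      return undefined; // nothing to drag yet; the target stays drop-only
    }
    return { id: dndId, payloadType: 'IMAGE_DTO', payload: { imageDTO } };
  }, [dndId, imageDTO]);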

View File

@@ -2,14 +2,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
-import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
+import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
+import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
-import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';

-const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
-  return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
+const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
+  return canvas.referenceImages.entities.map(mapId).reverse();
});

const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
  return selectedEntityIdentifier?.type === 'reference_image';
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte

export const IPAdapterList = memo(() => {
  const isSelected = useAppSelector(selectIsSelected);
-  const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
+  const ipaIds = useAppSelector(selectEntityIds);

-  if (entityIdentifiers.length === 0) {
+  if (ipaIds.length === 0) {
    return null;
  }

-  if (entityIdentifiers.length > 0) {
+  if (ipaIds.length > 0) {
    return (
-      <CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
-        {entityIdentifiers.map((entityIdentifier) => (
-          <IPAdapter key={entityIdentifier.id} id={entityIdentifier.id} />
+      <CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
+        {ipaIds.map((id) => (
+          <IPAdapter key={id} id={id} />
        ))}
      </CanvasEntityGroupList>
    );
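
Editorial note on both list components in this diff: once the empty case has returned null, the following length > 0 check is always true, so a single guard would express the same control flow. A minimal sketch of the equivalent structure (renderList stands in for the JSX):

const renderEntities = <T,>(ids: T[], renderList: (ids: T[]) => unknown): unknown => {
  if (ids.length === 0) {
    return null;
  }
  // ids is guaranteed non-empty past the guard; no second check is needed.
  return renderList(ids);
};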

Some files were not shown because too many files have changed in this diff.