Compare commits

22 Commits

Author SHA1 Message Date
Brandon Rising
b75b6cf55f Run ruff, fix bug in hf downloading code which failed to download parts of a model 2024-10-31 13:05:09 -04:00
Brandon Rising
8ca5047d93 Create new latent factors for sd35 2024-10-30 14:28:18 -04:00
Brandon Rising
1d8f8051e9 Rather than .fp16., some repos start the suffix with .fp16... for weights spread across multiple files 2024-10-30 13:03:05 -04:00
Ryan Dick
25cf24ee57 Make T5 encoder optional in SD3 workflows. 2024-10-25 18:49:28 +00:00
Ryan Dick
99e1c80abc Make 3.5 the default CFG scale for SD3. 2024-10-25 15:05:16 +00:00
Ryan Dick
f8743ef333 Add progress images to SD3 and make denoising cancellable. 2024-10-25 15:02:30 +00:00
Brandon Rising
596dc6b4e3 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-10-25 00:20:28 -04:00
Brandon Rising
b4f93a0ff5 Initial wave of frontend updates for sd-3 node inputs 2024-10-24 22:22:32 -04:00
Brandon Rising
f4be52abb4 define submodels on sd3 models during probe 2024-10-24 15:18:42 -04:00
Ryan Dick
4e029331ba Add tqdm progress bar for SD3. 2024-10-24 16:04:37 +00:00
Ryan Dick
40b4de5f77 Bug fixes to get SD3 text-to-image workflow running. 2024-10-24 15:55:17 +00:00
Ryan Dick
c930807881 Temporary hack for testing SD3 model loader. 2024-10-24 15:34:12 +00:00
Ryan Dick
bec8b27429 Fix Sd3TextEncoderInvocation output type. 2024-10-24 15:14:46 +00:00
Ryan Dick
ef4f466ccf Initial draft of SD3DenoiseInvocation. 2024-10-24 14:43:48 +00:00
Ryan Dick
3c869ee5ab Add first draft of Sd3TextEncoderInvocation. 2024-10-24 01:19:40 +00:00
Ryan Dick
16dc30fd5b Add Sd3ModelLoaderInvocation. 2024-10-24 00:17:19 +00:00
Ryan Dick
0c14192819 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-10-24 00:03:35 +00:00
Ryan Dick
36dadba45b Get diffusers SD3 model probing working. 2024-10-23 19:55:26 +00:00
Ryan Dick
f2a9c01d0e (minor) Remove unused dict. 2024-10-23 19:03:33 +00:00
Ryan Dick
1ca57ade4d Fix huggingface_hub.errors imports after version bump. 2024-10-23 18:29:24 +00:00
Ryan Dick
85c0e0db1e Fix changed import for FromOriginalControlNetMixin after diffusers bump. 2024-10-23 18:25:12 +00:00
Ryan Dick
59a2388585 Bump diffusers, accelerate, and huggingface-hub. 2024-10-23 18:09:35 +00:00
154 changed files with 2478 additions and 3535 deletions

View File

@@ -17,49 +17,46 @@ If you just want to use Invoke, you should use the [installer][installer link].
## Setup
1. Run through the [requirements][requirements link].
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
3. Create a directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
4. Create a python virtual environment inside the directory you just created:
1. [Fork and clone][forking link] the [InvokeAI repo][repo link].
1. Create a directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
1. Create a python virtual environment inside the directory you just created:
```sh
python3 -m venv .venv --prompt InvokeAI-Dev
```
```sh
python3 -m venv .venv --prompt InvokeAI-Dev
```
5. Activate the venv (you'll need to do this every time you want to run the app):
1. Activate the venv (you'll need to do this every time you want to run the app):
```sh
source .venv/bin/activate
```
```sh
source .venv/bin/activate
```
6. Install the repo as an [editable install][editable install link]:
1. Install the repo as an [editable install][editable install link]:
```sh
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
```sh
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
Refer to the [manual installation][manual install link] instructions for help determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
Refer to the [manual installation][manual install link] instructions for help determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
7. Install the frontend dev toolchain:
1. Install the frontend dev toolchain:
- [`nodejs`](https://nodejs.org/) (recommend v20 LTS)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
- [`pnpm`](https://pnpm.io/installation#installing-a-specific-version) (must be v8 - not v9!)
8. Do a production build of the frontend:
1. Do a production build of the frontend:
```sh
cd PATH_TO_INVOKEAI_REPO/invokeai/frontend/web
pnpm i
pnpm build
```
```sh
pnpm build
```
9. Start the application:
1. Start the application:
```sh
cd PATH_TO_INVOKEAI_REPO
python scripts/invokeai-web.py
```
```sh
python scripts/invokeai-web.py
```
10. Access the UI at `localhost:9090`.
1. Access the UI at `localhost:9090`.
## Updating the UI

View File

@@ -808,11 +808,7 @@ def get_is_installed(
for model in installed_models:
if model.source == starter_model.source:
return True
if (
(model.name == starter_model.name or model.name in starter_model.previous_names)
and model.base == starter_model.base
and model.type == starter_model.type
):
if model.name == starter_model.name and model.base == starter_model.base and model.type == starter_model.type:
return True
return False

View File

@@ -13,7 +13,6 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
@@ -511,7 +510,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
ext_manager: ExtensionsManager,
bgr_mode: bool = False,
) -> None:
if t2i_adapters is None:
return
@@ -521,10 +519,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapters = [t2i_adapters]
for t2i_adapter_field in t2i_adapters:
image = context.images.get_pil(t2i_adapter_field.image.image_name)
if bgr_mode: # SDXL t2i trained on cv2's BGR outputs, but PIL won't convert straight to BGR
r, g, b = image.split()
image = Image.merge("RGB", (b, g, r))
ext_manager.add_extension(
T2IAdapterExt(
node_context=context,
@@ -629,10 +623,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
max_unet_downscale = 8
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
# SDXL adapters are trained on cv2's BGR outputs
r, g, b = image.split()
image = Image.merge("RGB", (b, g, r))
else:
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
@@ -910,8 +900,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
bgr_mode = self.unet.unet.base == BaseModelType.StableDiffusionXL
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager, bgr_mode)
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

View File

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -133,6 +134,7 @@ class FieldDescriptions:
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -140,6 +142,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +249,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -0,0 +1,89 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -165,7 +165,6 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
mask: TensorField = InputField(description="The mask tensor to apply.")
image: ImageField = InputField(description="The image to apply the mask to.")
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
@@ -180,9 +179,6 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)
if self.invert:
mask_np = 255 - mask_np
# Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
# image's transparency - else the transparent regions would end up as opaque black.
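
As a rough illustration of the comment above, here is a minimal numpy sketch of masking only the alpha channel (assuming `image` is the RGBA PIL image and `mask_np` the (H, W) uint8 mask from the code above; `apply_alpha_mask` is a hypothetical helper, not the repo's implementation):

```python
# Minimal sketch: scale the existing alpha by the mask so fully transparent
# pixels stay transparent rather than turning into opaque black.
import numpy as np
from PIL import Image

def apply_alpha_mask(image: Image.Image, mask_np: np.ndarray) -> Image.Image:
    image_np = np.asarray(image)  # (H, W, 4) uint8 RGBA
    alpha = image_np[..., 3].astype(np.uint16)
    new_alpha = (alpha * mask_np // 255).astype(np.uint8)
    return Image.fromarray(np.dstack([image_np[..., :3], new_alpha]))
```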

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,260 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
joint_attention_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
t5_embeds = torch.zeros(
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
device=device,
dtype=dtype,
)
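# Pad the concatenated CLIP embeddings out to the T5 feature width, then
# prepend them to the T5 embeddings along the sequence dimension, mirroring
# the joint text-embedding layout used by diffusers' SD3 pipeline.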
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype, then cast, so that a given seed produces identical noise across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: The CFG scale value to apply at each timestep.
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_text_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_text_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.num_steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
return step_callback
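
For intuition, the `scheduler.step(...)` call in the loop above amounts to a single Euler update for this flow-match scheduler. A simplified sketch, not the diffusers implementation:

```python
# Simplified sketch of a flow-match Euler step (assumed form; the real
# FlowMatchEulerDiscreteScheduler manages its sigma schedule internally).
import torch

def euler_flow_step(
    latents: torch.Tensor, noise_pred: torch.Tensor, sigma: float, sigma_next: float
) -> torch.Tensor:
    # The transformer predicts a velocity; step linearly from sigma to sigma_next.
    return latents + (sigma_next - sigma) * noise_pred
```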

View File

@@ -0,0 +1,97 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
)
# TODO(brandon): Setup UI updates to support selecting a clip l model.
# clip_l_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_l_model,
# ui_type=UIType.CLIPEmbedModel,
# input=Input.Direct,
# title="CLIP L Encoder",
# )
# TODO(brandon): Setup UI updates to support selecting a clip g model.
# clip_g_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.clip_g_model,
# ui_type=UIType.CLIPGModel,
# input=Input.Direct,
# title="CLIP G Encoder",
# )
# TODO(brandon): Setup UI updates to support selecting an SD3 vae model.
# vae_model: ModelIdentifierField = InputField(
# description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE", default=None
# )
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
mmditx=TransformerField(transformer=mmditx, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,199 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 max sequence length, set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
default=None,
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding models are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the text encoder has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
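# Use the penultimate hidden state rather than the final layer output, as
# SD-family pipelines typically do for CLIP conditioning.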
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -1,4 +1,3 @@
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -222,7 +221,7 @@ class ImagesInterface(InvocationContextInterface):
)
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
"""Gets an image as a PIL Image object. This method returns a copy of the image.
"""Gets an image as a PIL Image object.
Args:
image_name: The name of the image to get.
@@ -234,15 +233,11 @@ class ImagesInterface(InvocationContextInterface):
image = self._services.images.get_pil_image(image_name)
if mode and mode != image.mode:
try:
# convert makes a copy!
image = image.convert(mode)
except ValueError:
self._services.logger.warning(
f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
)
else:
# copy the image to prevent the user from modifying the original
image = image.copy()
return image
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
@@ -295,15 +290,15 @@ class TensorsInterface(InvocationContextInterface):
return name
def load(self, name: str) -> Tensor:
"""Loads a tensor by name. This method returns a copy of the tensor.
"""Loads a tensor by name.
Args:
name: The name of the tensor to load.
Returns:
The tensor.
The loaded tensor.
"""
return self._services.tensors.load(name).clone()
return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface):
@@ -321,16 +316,16 @@ class ConditioningInterface(InvocationContextInterface):
return name
def load(self, name: str) -> ConditioningFieldData:
"""Loads conditioning data by name. This method returns a copy of the conditioning data.
"""Loads conditioning data by name.
Args:
name: The name of the conditioning data to load.
Returns:
The conditioning data.
The loaded conditioning data.
"""
return deepcopy(self._services.conditioning.load(name))
return self._services.conditioning.load(name)
class ModelsInterface(InvocationContextInterface):

View File

@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
[-0.1307, -0.1874, -0.7445], # L4
]
SD3_5_LATENT_RGB_FACTORS = [
[-0.05240681, 0.03251581, 0.0749016],
[-0.0580572, 0.00759826, 0.05729818],
[0.16144888, 0.01270368, -0.03768577],
[0.14418615, 0.08460266, 0.15941818],
[0.04894035, 0.0056485, -0.06686988],
[0.05187166, 0.19222395, 0.06261094],
[0.1539433, 0.04818359, 0.07103094],
[-0.08601796, 0.09013458, 0.10893912],
[-0.12398469, -0.06766567, 0.0033688],
[-0.0439737, 0.07825329, 0.02258823],
[0.03101129, 0.06382551, 0.07753657],
[-0.01315361, 0.08554491, -0.08772475],
[0.06464487, 0.05914605, 0.13262741],
[-0.07863674, -0.02261737, -0.12761454],
[-0.09923835, -0.08010759, -0.06264447],
[-0.03392309, -0.0804029, -0.06078822],
]
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
elif base_model == BaseModelType.StableDiffusion3:
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
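
The factor tables above give each latent channel an RGB contribution; conceptually, `sample_to_lowres_estimated_image` reduces to a per-pixel matrix product. A sketch with an assumed min/max normalization (the real function can also apply a smoothing matrix, as the SDXL branch shows):

```python
# Sketch: project a (C, H, W) latent sample to a rough (H, W, 3) RGB preview
# using per-channel factors such as SD3_5_LATENT_RGB_FACTORS (C=16 for SD3.5).
import torch

def latents_to_rgb_preview(sample: torch.Tensor, factors: list[list[float]]) -> torch.Tensor:
    weights = torch.tensor(factors, dtype=sample.dtype, device=sample.device)  # (C, 3)
    rgb = torch.einsum("chw,cr->hwr", sample, weights)
    # Assumed normalization to 0..255 for display.
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8) * 255.0
    return rgb.clamp(0, 255).byte()
```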

View File

@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -147,6 +150,11 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +201,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
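# SD3 requests tokenizer_3/text_encoder_3 submodels, but standalone T5 encoder
# checkpoints keep their files under tokenizer_2/text_encoder_2, so both
# submodel types resolve to the same on-disk paths.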
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,
@@ -117,6 +111,8 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
load_class = load_classes[config.base][config.variant]
except KeyError as e:
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
# Without SilenceWarnings we get log messages like this:
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
@@ -127,7 +123,13 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
# ['text_model.embeddings.position_ids']
with SilenceWarnings():
pipeline = load_class.from_single_file(config.path, torch_dtype=self._torch_dtype)
pipeline = load_class.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
if not submodel_type:
return pipeline

View File

@@ -19,7 +19,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -33,7 +33,10 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -112,6 +115,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -122,6 +126,8 @@ class ModelProbe(object):
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
@classmethod
@@ -178,7 +184,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["hash"] = "placeholder" # fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@@ -217,6 +223,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -462,9 +472,8 @@ MODEL_NAME_TO_PREPROCESSOR = {
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart anime": "lineart_anime_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
@@ -747,18 +756,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e. SD 1.x, SD 2.x, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +794,21 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
)
return submodels
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
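
For context on `get_submodels` above: a diffusers `model_index.json` maps component folders to `[library, class]` pairs, which the probe translates through `CLASS2TYPE`. A hypothetical excerpt, written here as a Python dict:

```python
# Hypothetical shape of an SD3 pipeline's model_index.json (names assumed).
# Underscore-prefixed keys are metadata; the rest become SubmodelDefinitions.
model_index = {
    "_class_name": "StableDiffusion3Pipeline",
    "transformer": ["diffusers", "SD3Transformer2DModel"],            # -> ModelType.Main
    "text_encoder": ["transformers", "CLIPTextModelWithProjection"],  # -> ModelType.CLIPEmbed
    "text_encoder_3": ["transformers", "T5EncoderModel"],             # -> ModelType.T5Encoder
    "vae": ["diffusers", "AutoencoderKL"],                            # -> ModelType.VAE
}
```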

View File

@@ -13,9 +13,6 @@ class StarterModelWithoutDependencies(BaseModel):
type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False
# allows us to track what models a user has installed across name changes within starter models
# if you update a starter model name, please add the old one to this list for that starter model
previous_names: list[str] = []
class StarterModel(StarterModelWithoutDependencies):
@@ -246,49 +243,44 @@ easy_neg_sd1 = StarterModel(
# endregion
# region IP Adapter
ip_adapter_sd1 = StarterModel(
name="Standard Reference (IP Adapter)",
name="IP Adapter",
base=BaseModelType.StableDiffusion1,
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
description="References images with a more generalized/looser degree of precision.",
description="IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
previous_names=["IP Adapter"],
)
ip_adapter_plus_sd1 = StarterModel(
name="Precise Reference (IP Adapter Plus)",
name="IP Adapter Plus",
base=BaseModelType.StableDiffusion1,
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
description="References images with a higher degree of precision.",
description="Refined IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
previous_names=["IP Adapter Plus"],
)
ip_adapter_plus_face_sd1 = StarterModel(
name="Face Reference (IP Adapter Plus Face)",
name="IP Adapter Plus Face",
base=BaseModelType.StableDiffusion1,
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
description="References images with a higher degree of precision, adapted for faces",
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
previous_names=["IP Adapter Plus Face"],
)
ip_adapter_sdxl = StarterModel(
name="Standard Reference (IP Adapter ViT-H)",
name="IP Adapter SDXL",
base=BaseModelType.StableDiffusionXL,
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
description="References images with a higher degree of precision.",
description="IP-Adapter for SDXL models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter SDXL"],
)
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
name="XLabs FLUX IP-Adapter",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
description="FLUX IP-Adapter",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],
previous_names=["XLabs FLUX IP-Adapter"],
)
# endregion
# region ControlNet
@@ -307,162 +299,157 @@ qr_code_cnet_sdxl = StarterModel(
type=ModelType.ControlNet,
)
canny_sd1 = StarterModel(
name="Hard Edge Detection (canny)",
name="canny",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_canny",
description="Uses detected edges in the image to control composition.",
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
type=ModelType.ControlNet,
previous_names=["canny"],
)
inpaint_cnet_sd1 = StarterModel(
name="Inpainting",
name="inpaint",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_inpaint",
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
type=ModelType.ControlNet,
previous_names=["inpaint"],
)
mlsd_sd1 = StarterModel(
name="Line Drawing (mlsd)",
name="mlsd",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_mlsd",
description="Uses straight line detection for controlling the generation.",
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
type=ModelType.ControlNet,
previous_names=["mlsd"],
)
depth_sd1 = StarterModel(
name="Depth Map",
name="depth",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11f1p_sd15_depth",
description="Uses depth information in the image to control the depth in the generation.",
description="ControlNet weights trained on sd-1.5 with depth conditioning",
type=ModelType.ControlNet,
previous_names=["depth"],
)
normal_bae_sd1 = StarterModel(
name="Lighting Detection (Normals)",
name="normal_bae",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_normalbae",
description="Uses detected lighting information to guide the lighting of the composition.",
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
type=ModelType.ControlNet,
previous_names=["normal_bae"],
)
seg_sd1 = StarterModel(
name="Segmentation Map",
name="seg",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_seg",
description="Uses segmentation maps to guide the structure of the composition.",
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
type=ModelType.ControlNet,
previous_names=["seg"],
)
lineart_sd1 = StarterModel(
name="Lineart",
name="lineart",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_lineart",
description="Uses lineart detection to guide the lighting of the composition.",
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
type=ModelType.ControlNet,
previous_names=["lineart"],
)
lineart_anime_sd1 = StarterModel(
name="Lineart Anime",
name="lineart_anime",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
description="Uses anime lineart detection to guide the lighting of the composition.",
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
type=ModelType.ControlNet,
previous_names=["lineart_anime"],
)
openpose_sd1 = StarterModel(
name="Pose Detection (openpose)",
name="openpose",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_openpose",
description="Uses pose information to control the pose of human characters in the generation.",
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
type=ModelType.ControlNet,
previous_names=["openpose"],
)
scribble_sd1 = StarterModel(
name="Contour Detection (scribble)",
name="scribble",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_scribble",
description="Uses edges, contours, or line art in the image to control composition.",
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
type=ModelType.ControlNet,
previous_names=["scribble"],
)
softedge_sd1 = StarterModel(
name="Soft Edge Detection (softedge)",
name="softedge",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_softedge",
description="Uses a soft edge detection map to control composition.",
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
type=ModelType.ControlNet,
previous_names=["softedge"],
)
shuffle_sd1 = StarterModel(
name="Remix (shuffle)",
name="shuffle",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11e_sd15_shuffle",
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
type=ModelType.ControlNet,
previous_names=["shuffle"],
)
tile_sd1 = StarterModel(
name="Tile",
name="tile",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11f1e_sd15_tile",
description="Uses image data to replicate exact colors/structure in the resulting generation.",
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
type=ModelType.ControlNet,
)
ip2p_sd1 = StarterModel(
name="ip2p",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11e_sd15_ip2p",
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
type=ModelType.ControlNet,
previous_names=["tile"],
)
canny_sdxl = StarterModel(
name="Hard Edge Detection (canny)",
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-canny-sdxl-1.0",
description="Uses detected edges in the image to control composition.",
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
type=ModelType.ControlNet,
previous_names=["canny-sdxl"],
)
depth_sdxl = StarterModel(
name="Depth Map",
name="depth-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlNet-depth-sdxl-1.0",
description="Uses depth information in the image to control the depth in the generation.",
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
type=ModelType.ControlNet,
previous_names=["depth-sdxl"],
)
softedge_sdxl = StarterModel(
name="Soft Edge Detection (softedge)",
name="softedge-dexined-sdxl",
base=BaseModelType.StableDiffusionXL,
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
description="Uses a soft edge detection map to control composition.",
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
type=ModelType.ControlNet,
)
depth_zoe_16_sdxl = StarterModel(
name="depth-16bit-zoe-sdxl",
base=BaseModelType.StableDiffusionXL,
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
type=ModelType.ControlNet,
)
depth_zoe_32_sdxl = StarterModel(
name="depth-zoe-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
type=ModelType.ControlNet,
previous_names=["softedge-dexined-sdxl"],
)
openpose_sdxl = StarterModel(
name="Pose Detection (openpose)",
name="openpose-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-openpose-sdxl-1.0",
description="Uses pose information to control the pose of human characters in the generation.",
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
type=ModelType.ControlNet,
previous_names=["openpose-sdxl", "controlnet-openpose-sdxl"],
)
scribble_sdxl = StarterModel(
name="Contour Detection (scribble)",
name="scribble-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-scribble-sdxl-1.0",
description="Uses edges, contours, or line art in the image to control composition.",
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
previous_names=["scribble-sdxl", "controlnet-scribble-sdxl"],
)
tile_sdxl = StarterModel(
name="Tile",
name="tile-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-tile-sdxl-1.0",
description="Uses image data to replicate exact colors/structure in the resulting generation.",
type=ModelType.ControlNet,
previous_names=["tile-sdxl"],
)
union_cnet_sdxl = StarterModel(
name="Multi-Guidance Detection (Union Pro)",
base=BaseModelType.StableDiffusionXL,
source="InvokeAI/Xinsir-SDXL_Controlnet_Union",
description="A unified ControlNet for SDXL model that supports 10+ control types",
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
type=ModelType.ControlNet,
)
union_cnet_flux = StarterModel(
@@ -475,52 +462,60 @@ union_cnet_flux = StarterModel(
# endregion
# region T2I Adapter
t2i_canny_sd1 = StarterModel(
name="Hard Edge Detection (canny)",
name="canny-sd15",
base=BaseModelType.StableDiffusion1,
source="TencentARC/t2iadapter_canny_sd15v2",
description="Uses detected edges in the image to control composition",
description="T2I Adapter weights trained on sd-1.5 with canny conditioning.",
type=ModelType.T2IAdapter,
previous_names=["canny-sd15"],
)
t2i_sketch_sd1 = StarterModel(
name="Sketch",
name="sketch-sd15",
base=BaseModelType.StableDiffusion1,
source="TencentARC/t2iadapter_sketch_sd15v2",
description="Uses a sketch to control composition",
description="T2I Adapter weights trained on sd-1.5 with sketch conditioning.",
type=ModelType.T2IAdapter,
previous_names=["sketch-sd15"],
)
t2i_depth_sd1 = StarterModel(
name="Depth Map",
name="depth-sd15",
base=BaseModelType.StableDiffusion1,
source="TencentARC/t2iadapter_depth_sd15v2",
description="Uses depth information in the image to control the depth in the generation.",
description="T2I Adapter weights trained on sd-1.5 with depth conditioning.",
type=ModelType.T2IAdapter,
)
t2i_zoe_depth_sd1 = StarterModel(
name="zoedepth-sd15",
base=BaseModelType.StableDiffusion1,
source="TencentARC/t2iadapter_zoedepth_sd15v1",
description="T2I Adapter weights trained on sd-1.5 with zoe depth conditioning.",
type=ModelType.T2IAdapter,
previous_names=["depth-sd15"],
)
t2i_canny_sdxl = StarterModel(
name="Hard Edge Detection (canny)",
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="TencentARC/t2i-adapter-canny-sdxl-1.0",
description="Uses detected edges in the image to control composition",
description="T2I Adapter weights trained on sdxl-1.0 with canny conditioning.",
type=ModelType.T2IAdapter,
)
t2i_zoe_depth_sdxl = StarterModel(
name="zoedepth-sdxl",
base=BaseModelType.StableDiffusionXL,
source="TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
description="T2I Adapter weights trained on sdxl-1.0 with zoe depth conditioning.",
type=ModelType.T2IAdapter,
previous_names=["canny-sdxl"],
)
t2i_lineart_sdxl = StarterModel(
name="Lineart",
name="lineart-sdxl",
base=BaseModelType.StableDiffusionXL,
source="TencentARC/t2i-adapter-lineart-sdxl-1.0",
description="Uses lineart detection to guide the lighting of the composition.",
description="T2I Adapter weights trained on sdxl-1.0 with lineart conditioning.",
type=ModelType.T2IAdapter,
previous_names=["lineart-sdxl"],
)
t2i_sketch_sdxl = StarterModel(
name="Sketch",
name="sketch-sdxl",
base=BaseModelType.StableDiffusionXL,
source="TencentARC/t2i-adapter-sketch-sdxl-1.0",
description="Uses a sketch to control composition",
description="T2I Adapter weights trained on sdxl-1.0 with sketch conditioning.",
type=ModelType.T2IAdapter,
previous_names=["sketch-sdxl"],
)
# endregion
# region SpandrelImageToImage
@@ -605,18 +600,22 @@ STARTER_MODELS: list[StarterModel] = [
softedge_sd1,
shuffle_sd1,
tile_sd1,
ip2p_sd1,
canny_sdxl,
depth_sdxl,
softedge_sdxl,
depth_zoe_16_sdxl,
depth_zoe_32_sdxl,
openpose_sdxl,
scribble_sdxl,
tile_sdxl,
union_cnet_sdxl,
union_cnet_flux,
t2i_canny_sd1,
t2i_sketch_sd1,
t2i_depth_sd1,
t2i_zoe_depth_sd1,
t2i_canny_sdxl,
t2i_zoe_depth_sdxl,
t2i_lineart_sdxl,
t2i_sketch_sdxl,
realesrgan_x4,
@@ -647,6 +646,7 @@ sd1_bundle: list[StarterModel] = [
softedge_sd1,
shuffle_sd1,
tile_sd1,
ip2p_sd1,
swinir,
]
@@ -657,6 +657,8 @@ sdxl_bundle: list[StarterModel] = [
canny_sdxl,
depth_sdxl,
softedge_sdxl,
depth_zoe_16_sdxl,
depth_zoe_32_sdxl,
openpose_sdxl,
scribble_sdxl,
tile_sdxl,
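As the renames above show, a friendly display name keeps the model's old identifier in `previous_names` (e.g. `name="Hard Edge Detection (canny)"` with `previous_names=["canny"]`), so a model installed under either name can still be matched to its starter entry. A toy sketch of that backwards-compatible lookup — the resolver below is illustrative only, not InvokeAI's actual code:

```py
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StarterModelStub:
    # Pared-down stand-in for StarterModel, only for this illustration.
    name: str
    previous_names: list[str] = field(default_factory=list)


STARTERS = [
    StarterModelStub(name="Hard Edge Detection (canny)", previous_names=["canny"]),
    StarterModelStub(name="Depth Map", previous_names=["depth"]),
]


def find_starter(installed_name: str) -> Optional[StarterModelStub]:
    """Match a model by its current display name or any previous name."""
    for starter in STARTERS:
        if installed_name == starter.name or installed_name in starter.previous_names:
            return starter
    return None


assert find_starter("canny") is STARTERS[0]
assert find_starter("Depth Map") is STARTERS[1]
```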


@@ -129,8 +129,10 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
if (
candidate_variant_label
and candidate_variant_label.startswith(f".{variant.value}")
or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default])
):
score += 1
@@ -146,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
at_least_one_fp16 = True
break
@@ -162,7 +164,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
match = re.match(pattern, highest_score_candidate.path.as_posix())
if match:
for candidate in candidate_list:
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
match.group(2)
):
result.add(candidate.path)
else:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list
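The two fixes above work together: a `.fp16` variant may carry a sharded suffix such as `.fp16-00001-of-00002`, so the suffix check must use `startswith`, and once the best-scoring shard is chosen, the `-\d+-of-\d+` pattern pulls in its sibling shards so no part of the model is skipped. A minimal standalone sketch of that grouping logic, using hypothetical file names rather than a real repo:

```py
import re
from pathlib import Path

# Hypothetical sharded fp16 weights plus an unrelated fp32 shard.
files = [
    Path("transformer/model.fp16-00001-of-00002.safetensors"),
    Path("transformer/model.fp16-00002-of-00002.safetensors"),
    Path("transformer/model-00001-of-00002.safetensors"),
]

# Path.suffixes for the first file is ['.fp16-00001-of-00002', '.safetensors'],
# which is why the variant check uses startswith(".fp16") rather than == ".fp16".
assert files[0].suffixes[0].startswith(".fp16")

# Given the best-scoring candidate, recover its common stem and extension,
# then select every shard that shares them.
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
best = files[0]
match = re.match(pattern, best.as_posix())
assert match is not None
shards = {
    f
    for f in files
    if f.as_posix().startswith(match.group(1)) and f.as_posix().endswith(match.group(2))
}
print(sorted(f.name for f in shards))
# ['model.fp16-00001-of-00002.safetensors', 'model.fp16-00002-of-00002.safetensors']
```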


@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass
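For reference, `SD3ConditioningInfo.to()` mirrors `torch.Tensor.to`, relocating all four CLIP embeddings (and the optional T5 embeddings) in a single call; `t5_embeds` is `None` when the T5 encoder is skipped. A usage sketch with made-up shapes, assuming the dataclass defined above is in scope:

```py
import torch

# Shapes here are illustrative only; the real values come from the CLIP-L,
# CLIP-G, and (optionally) T5 text encoders.
cond = SD3ConditioningInfo(
    clip_l_pooled_embeds=torch.zeros(1, 768),
    clip_l_embeds=torch.zeros(1, 77, 768),
    clip_g_pooled_embeds=torch.zeros(1, 1280),
    clip_g_embeds=torch.zeros(1, 77, 1280),
    t5_embeds=None,  # T5 encoder skipped
)

# Move every tensor to the execution device/dtype; a None t5_embeds is left alone.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cond = cond.to(device=device, dtype=torch.float16)
```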


@@ -33,7 +33,7 @@ class PreviewExt(ExtensionBase):
def initial_preview(self, ctx: DenoiseContext):
self.callback(
PipelineIntermediateState(
step=0,
step=-1,
order=ctx.scheduler.order,
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?


@@ -58,7 +58,7 @@
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.1.0",
"@invoke-ai/ui-library": "^0.0.43",
"@invoke-ai/ui-library": "^0.0.42",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",


@@ -24,8 +24,8 @@ dependencies:
specifier: ^5.1.0
version: 5.1.0
'@invoke-ai/ui-library':
specifier: ^0.0.43
version: 0.0.43(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1)
specifier: ^0.0.42
version: 0.0.42(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1)
'@nanostores/react':
specifier: ^0.7.3
version: 0.7.3(nanostores@0.11.3)(react@18.3.1)
@@ -1696,20 +1696,20 @@ packages:
prettier: 3.3.3
dev: true
/@invoke-ai/ui-library@0.0.43(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-t3fPYyks07ue3dEBPJuTHbeDLnDckDCOrtvc07mMDbLOnlPEZ0StaeiNGH+oO8qLzAuMAlSTdswgHfzTc2MmPw==}
/@invoke-ai/ui-library@0.0.42(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-OuDXRipBO5mu+Nv4qN8cd8MiwiGBdq6h4PirVgPI9/ltbdcIzePgUJ0dJns26lflHSTRWW38I16wl4YTw3mNWA==}
peerDependencies:
'@fontsource-variable/inter': ^5.0.16
react: ^18.2.0
react-dom: ^18.2.0
dependencies:
'@chakra-ui/anatomy': 2.3.4
'@chakra-ui/anatomy': 2.2.2
'@chakra-ui/icons': 2.2.4(@chakra-ui/react@2.10.2)(react@18.3.1)
'@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.3.1)
'@chakra-ui/portal': 2.1.0(react-dom@18.3.1)(react@18.3.1)
'@chakra-ui/react': 2.10.2(@emotion/react@11.13.3)(@emotion/styled@11.13.0)(@types/react@18.3.11)(framer-motion@11.10.0)(react-dom@18.3.1)(react@18.3.1)
'@chakra-ui/styled-system': 2.11.2(react@18.3.1)
'@chakra-ui/theme-tools': 2.2.6(@chakra-ui/styled-system@2.11.2)(react@18.3.1)
'@chakra-ui/styled-system': 2.9.2
'@chakra-ui/theme-tools': 2.1.2(@chakra-ui/styled-system@2.9.2)
'@emotion/react': 11.13.3(@types/react@18.3.11)(react@18.3.1)
'@emotion/styled': 11.13.0(@emotion/react@11.13.3)(@types/react@18.3.11)(react@18.3.1)
'@fontsource-variable/inter': 5.1.0


@@ -94,7 +94,6 @@
"close": "Close",
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
"clipboard": "Clipboard",
"on": "On",
"off": "Off",
"or": "or",
@@ -682,8 +681,7 @@
"recallParameters": "Recall Parameters",
"recallParameter": "Recall {{label}}",
"scheduler": "Scheduler",
"seamlessXAxis": "Seamless X Axis",
"seamlessYAxis": "Seamless Y Axis",
"seamless": "Seamless",
"seed": "Seed",
"steps": "Steps",
"strength": "Image to image strength",
@@ -714,12 +712,8 @@
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"noDefaultSettings": "No default settings configured for this model. Visit the Model Manager to add default settings.",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"defaultSettingsOutOfSync": "Some settings do not match the model's defaults:",
"restoreDefaultSettings": "Click to use the model's default settings.",
"usingDefaultSettings": "Using model's default settings",
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
@@ -804,6 +798,7 @@
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",
"urlOrLocalPathHelper": "URLs should point to a single file. Local paths can point to a single file or folder for a single diffusers model.",
"useDefaultSettings": "Use Default Settings",
"vae": "VAE",
"vaePrecision": "VAE Precision",
"variant": "Variant",
@@ -1113,9 +1108,6 @@
"enableInformationalPopovers": "Enable Informational Popovers",
"informationalPopoversDisabled": "Informational Popovers Disabled",
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
"enableModelDescriptions": "Enable Model Descriptions in Dropdowns",
"modelDescriptionsDisabled": "Model Descriptions in Dropdowns Disabled",
"modelDescriptionsDisabledDesc": "Model descriptions in dropdowns have been disabled. Enable them in Settings.",
"enableInvisibleWatermark": "Enable Invisible Watermark",
"enableNSFWChecker": "Enable NSFW Checker",
"general": "General",
@@ -1259,33 +1251,6 @@
"heading": "Mask Adjustments",
"paragraphs": ["Adjust the mask."]
},
"inpainting": {
"heading": "Inpainting",
"paragraphs": ["Controls which area is modified, guided by Denoising Strength."]
},
"rasterLayer": {
"heading": "Raster Layer",
"paragraphs": ["Pixel-based content of your canvas, used during image generation."]
},
"regionalGuidance": {
"heading": "Regional Guidance",
"paragraphs": ["Brush to guide where elements from global prompts should appear."]
},
"regionalGuidanceAndReferenceImage": {
"heading": "Regional Guidance and Regional Reference Image",
"paragraphs": [
"For Regional Guidance, brush to guide where elements from global prompts should appear.",
"For Regional Reference Image, brush to apply a reference image to specific areas."
]
},
"globalReferenceImage": {
"heading": "Global Reference Image",
"paragraphs": ["Applies a reference image to influence the entire generation."]
},
"regionalReferenceImage": {
"heading": "Regional Reference Image",
"paragraphs": ["Brush to apply a reference image to specific areas."]
},
"controlNet": {
"heading": "ControlNet",
"paragraphs": [
@@ -1683,8 +1648,6 @@
"controlLayer": "Control Layer",
"inpaintMask": "Inpaint Mask",
"regionalGuidance": "Regional Guidance",
"canvasAsRasterLayer": "$t(controlLayers.canvas) as $t(controlLayers.rasterLayer)",
"canvasAsControlLayer": "$t(controlLayers.canvas) as $t(controlLayers.controlLayer)",
"referenceImage": "Reference Image",
"regionalReferenceImage": "Regional Reference Image",
"globalReferenceImage": "Global Reference Image",
@@ -1725,18 +1688,8 @@
"layer_other": "Layers",
"layer_withCount_one": "Layer ({{count}})",
"layer_withCount_other": "Layers ({{count}})",
"convertRasterLayerTo": "Convert $t(controlLayers.rasterLayer) To",
"convertControlLayerTo": "Convert $t(controlLayers.controlLayer) To",
"convertInpaintMaskTo": "Convert $t(controlLayers.inpaintMask) To",
"convertRegionalGuidanceTo": "Convert $t(controlLayers.regionalGuidance) To",
"copyRasterLayerTo": "Copy $t(controlLayers.rasterLayer) To",
"copyControlLayerTo": "Copy $t(controlLayers.controlLayer) To",
"copyInpaintMaskTo": "Copy $t(controlLayers.inpaintMask) To",
"copyRegionalGuidanceTo": "Copy $t(controlLayers.regionalGuidance) To",
"newRasterLayer": "New $t(controlLayers.rasterLayer)",
"newControlLayer": "New $t(controlLayers.controlLayer)",
"newInpaintMask": "New $t(controlLayers.inpaintMask)",
"newRegionalGuidance": "New $t(controlLayers.regionalGuidance)",
"convertToControlLayer": "Convert to Control Layer",
"convertToRasterLayer": "Convert to Raster Layer",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
@@ -1760,7 +1713,6 @@
"newGallerySessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be sent to the gallery.",
"newCanvasSession": "New Canvas Session",
"newCanvasSessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be staged on the canvas.",
"replaceCurrent": "Replace Current",
"controlMode": {
"controlMode": "Control Mode",
"balanced": "Balanced",
@@ -1890,24 +1842,16 @@
"apply": "Apply",
"cancel": "Cancel"
},
"selectObject": {
"selectObject": "Select Object",
"segment": {
"autoMask": "Auto Mask",
"pointType": "Point Type",
"invertSelection": "Invert Selection",
"include": "Include",
"exclude": "Exclude",
"foreground": "Foreground",
"background": "Background",
"neutral": "Neutral",
"apply": "Apply",
"reset": "Reset",
"saveAs": "Save As",
"apply": "Apply",
"cancel": "Cancel",
"process": "Process",
"help1": "Select a single target object. Add <Bold>Include</Bold> and <Bold>Exclude</Bold> points to indicate which parts of the layer are part of the target object.",
"help2": "Start with one <Bold>Include</Bold> point within the target object. Add more points to refine the selection. Fewer points typically produce better results.",
"help3": "Invert the selection to select everything except the target object.",
"clickToAdd": "Click on the layer to add a point",
"dragToMove": "Drag a point to move it",
"clickToRemove": "Click on a point to remove it"
"process": "Process"
},
"settings": {
"snapToGrid": {
@@ -1948,8 +1892,6 @@
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
"cropCanvasToBbox": "Crop Canvas to Bbox"
},
"stagingArea": {
@@ -2082,11 +2024,13 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"line1": "<ItalicComponent>Select Object</ItalicComponent> tool for precise object selection and editing",
"line2": "Expanded Flux support, now with Global Reference Images",
"line3": "Improved tooltips and context menus",
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"
"canvasV2Announcement": {
"newCanvas": "A powerful new control canvas",
"newLayerTypes": "New layer types for even more control",
"fluxSupport": "Support for the Flux family of models",
"readReleaseNotes": "Read Release Notes",
"watchReleaseVideo": "Watch Release Video",
"watchUiUpdatesOverview": "Watch UI Updates Overview"
}
}
}


@@ -8,7 +8,6 @@ import {
controlLayerAdded,
entityRasterized,
entitySelected,
inpaintMaskAdded,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
@@ -18,7 +17,6 @@ import {
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
@@ -112,46 +110,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
return;
}
/**
/**
* Image dropped on Inpaint Mask
*/
if (
overData.actionType === 'ADD_INPAINT_MASK_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasInpaintMaskState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
return;
}
/**
/**
* Image dropped on Regional Guidance
*/
if (
overData.actionType === 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRegionalGuidanceState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Raster layer
*/


@@ -26,9 +26,5 @@ export const IconMenuItem = ({ tooltip, icon, ...props }: Props) => {
};
export const IconMenuItemGroup = ({ children }: { children: ReactNode }) => {
return (
<Flex gap={2} justifyContent="space-between">
{children}
</Flex>
);
return <Flex gap={2}>{children}</Flex>;
};


@@ -23,10 +23,8 @@ export type Feature =
| 'dynamicPrompts'
| 'dynamicPromptsMaxPrompts'
| 'dynamicPromptsSeedBehaviour'
| 'globalReferenceImage'
| 'imageFit'
| 'infillMethod'
| 'inpainting'
| 'ipAdapterMethod'
| 'lora'
| 'loraWeight'
@@ -48,7 +46,6 @@ export type Feature =
| 'paramVAEPrecision'
| 'paramWidth'
| 'patchmatchDownScaleSize'
| 'rasterLayer'
| 'refinerModel'
| 'refinerNegativeAestheticScore'
| 'refinerPositiveAestheticScore'
@@ -56,9 +53,6 @@ export type Feature =
| 'refinerStart'
| 'refinerSteps'
| 'refinerCfgScale'
| 'regionalGuidance'
| 'regionalGuidanceAndReferenceImage'
| 'regionalReferenceImage'
| 'scaleBeforeProcessing'
| 'seamlessTilingXAxis'
| 'seamlessTilingYAxis'
@@ -82,24 +76,6 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
clipSkip: {
href: 'https://support.invoke.ai/support/solutions/articles/151000178161-advanced-settings',
},
inpainting: {
href: 'https://support.invoke.ai/support/solutions/articles/151000096702-inpainting-outpainting-and-bounding-box',
},
rasterLayer: {
href: 'https://support.invoke.ai/support/solutions/articles/151000094998-raster-layers-and-initial-images',
},
regionalGuidance: {
href: 'https://support.invoke.ai/support/solutions/articles/151000165024-regional-guidance-layers',
},
regionalGuidanceAndReferenceImage: {
href: 'https://support.invoke.ai/support/solutions/articles/151000165024-regional-guidance-layers',
},
globalReferenceImage: {
href: 'https://support.invoke.ai/support/solutions/articles/151000159340-global-and-regional-reference-images-ip-adapters-',
},
regionalReferenceImage: {
href: 'https://support.invoke.ai/support/solutions/articles/151000159340-global-and-regional-reference-images-ip-adapters-',
},
controlNet: {
href: 'https://support.invoke.ai/support/solutions/articles/151000105880',
},


@@ -127,6 +127,8 @@ export const buildUseDisclosure = (defaultIsOpen: boolean): [() => UseDisclosure
*
* Hook to manage a boolean state. Use this for a local boolean state.
* @param defaultIsOpen Initial state of the disclosure
*
* @knipignore
*/
export const useDisclosure = (defaultIsOpen: boolean): UseDisclosure => {
const [isOpen, set] = useState(defaultIsOpen);


@@ -4,7 +4,6 @@ import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -38,7 +37,6 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base = useAppSelector(selectBaseWithSDXLFallback);
const shouldShowModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelConfigs) {
@@ -53,7 +51,6 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
options: val.map((model) => ({
label: model.name,
value: model.key,
description: (shouldShowModelDescriptions && model.description) || undefined,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
})),
});
@@ -63,7 +60,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
);
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base) ? -1 : 1));
return _options;
}, [modelConfigs, groupByType, getIsDisabled, base, shouldShowModelDescriptions]);
}, [modelConfigs, groupByType, getIsDisabled, base]);
const value = useMemo(
() =>


@@ -1,7 +1,5 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
@@ -26,16 +24,13 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const shouldShowModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
const options = useMemo<ComboboxOption[]>(() => {
return modelConfigs.filter(optionsFilter).map((model) => ({
label: model.name,
value: model.key,
description: (shouldShowModelDescriptions && model.description) || undefined,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelConfigs, shouldShowModelDescriptions]);
}, [optionsFilter, getIsDisabled, modelConfigs]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),


@@ -1,161 +0,0 @@
import type { MenuButtonProps, MenuItemProps, MenuListProps, MenuProps } from '@invoke-ai/ui-library';
import { Box, Flex, Icon, Text } from '@invoke-ai/ui-library';
import { useDisclosure } from 'common/hooks/useBoolean';
import type { FocusEventHandler, PointerEvent, RefObject } from 'react';
import { useCallback, useEffect, useRef } from 'react';
import { PiCaretRightBold } from 'react-icons/pi';
import { useDebouncedCallback } from 'use-debounce';
const offset: [number, number] = [0, 8];
type UseSubMenuReturn = {
parentMenuItemProps: Partial<MenuItemProps>;
menuProps: Partial<MenuProps>;
menuButtonProps: Partial<MenuButtonProps>;
menuListProps: Partial<MenuListProps> & { ref: RefObject<HTMLDivElement> };
};
/**
* A hook that provides the necessary props to create a sub-menu within a menu.
*
* The sub-menu should be wrapped inside a parent `MenuItem` component.
*
* Use SubMenuButtonContent to render a button with a label and a right caret icon.
*
* TODO(psyche): Add keyboard handling for sub-menu.
*
* @example
* ```tsx
* const SubMenuExample = () => {
* const subMenu = useSubMenu();
* return (
* <Menu>
* <MenuButton>Open Parent Menu</MenuButton>
* <MenuList>
* <MenuItem>Parent Item 1</MenuItem>
* <MenuItem>Parent Item 2</MenuItem>
* <MenuItem>Parent Item 3</MenuItem>
* <MenuItem {...subMenu.parentMenuItemProps} icon={<PiImageBold />}>
* <Menu {...subMenu.menuProps}>
* <MenuButton {...subMenu.menuButtonProps}>
* <SubMenuButtonContent label="Open Sub Menu" />
* </MenuButton>
* <MenuList {...subMenu.menuListProps}>
* <MenuItem>Sub Item 1</MenuItem>
* <MenuItem>Sub Item 2</MenuItem>
* <MenuItem>Sub Item 3</MenuItem>
* </MenuList>
* </Menu>
* </MenuItem>
* </MenuList>
* </Menu>
* );
* };
* ```
*/
export const useSubMenu = (): UseSubMenuReturn => {
const subMenu = useDisclosure(false);
const menuListRef = useRef<HTMLDivElement>(null);
const closeDebounced = useDebouncedCallback(subMenu.close, 300);
const openAndCancelPendingClose = useCallback(() => {
closeDebounced.cancel();
subMenu.open();
}, [closeDebounced, subMenu]);
const toggleAndCancelPendingClose = useCallback(() => {
if (subMenu.isOpen) {
subMenu.close();
return;
} else {
closeDebounced.cancel();
subMenu.toggle();
}
}, [closeDebounced, subMenu]);
const onBlurMenuList = useCallback<FocusEventHandler<HTMLDivElement>>(
(e) => {
// Don't trigger blur if focus is moving to a child element - e.g. from a sub-menu item to another sub-menu item
if (e.currentTarget.contains(e.relatedTarget)) {
closeDebounced.cancel();
return;
}
subMenu.close();
},
[closeDebounced, subMenu]
);
const onParentMenuItemPointerLeave = useCallback(
(e: PointerEvent<HTMLButtonElement>) => {
/**
* The pointerleave event is triggered when the pen or touch device is lifted, which would close the sub-menu.
* However, we want to keep the sub-menu open until the pen or touch device presses some other element. This
* will be handled in the useEffect below - just ignore the pointerleave event for pen and touch devices.
*/
if (e.pointerType === 'pen' || e.pointerType === 'touch') {
return;
}
subMenu.close();
},
[subMenu]
);
/**
* When using a mouse, the pointerleave events close the menu. But when using a pen or touch device, we need to close
* the sub-menu when the user taps outside of the menu list. So we need to listen for clicks outside of the menu list
* and close the menu accordingly.
*/
useEffect(() => {
const el = menuListRef.current;
if (!el) {
return;
}
const controller = new AbortController();
window.addEventListener(
'click',
(e) => {
if (menuListRef.current?.contains(e.target as Node)) {
return;
}
subMenu.close();
},
{ signal: controller.signal }
);
return () => {
controller.abort();
};
}, [subMenu]);
return {
parentMenuItemProps: {
onClick: toggleAndCancelPendingClose,
onPointerEnter: openAndCancelPendingClose,
onPointerLeave: onParentMenuItemPointerLeave,
closeOnSelect: false,
},
menuProps: {
isOpen: subMenu.isOpen,
onClose: subMenu.close,
placement: 'right',
offset: offset,
closeOnBlur: false,
},
menuButtonProps: {
as: Box,
width: 'full',
height: 'full',
},
menuListProps: {
ref: menuListRef,
onPointerEnter: openAndCancelPendingClose,
onPointerLeave: closeDebounced,
onBlur: onBlurMenuList,
},
};
};
export const SubMenuButtonContent = ({ label }: { label: string }) => {
return (
<Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
<Text>{label}</Text>
<Icon as={PiCaretRightBold} />
</Flex>
);
};


@@ -1,6 +1,5 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
@@ -29,80 +28,69 @@ export const CanvasAddEntityButtons = memo(() => {
<Flex position="relative" flexDir="column" gap={4} top="20%">
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.global')}</Heading>
<InformationalPopover feature="globalReferenceImage">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</InformationalPopover>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" gap={2}>
<Heading size="xs">{t('controlLayers.regional')}</Heading>
<InformationalPopover feature="inpainting">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
>
{t('controlLayers.inpaintMask')}
</Button>
</InformationalPopover>
<InformationalPopover feature="regionalGuidance">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
>
{t('controlLayers.regionalGuidance')}
</Button>
</InformationalPopover>
<InformationalPopover feature="regionalReferenceImage">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
</InformationalPopover>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
>
{t('controlLayers.inpaintMask')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
>
{t('controlLayers.regionalGuidance')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.layer_other')}</Heading>
<InformationalPopover feature="controlNet">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
>
{t('controlLayers.controlLayer')}
</Button>
</InformationalPopover>
<InformationalPopover feature="rasterLayer">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRasterLayer}
>
{t('controlLayers.rasterLayer')}
</Button>
</InformationalPopover>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
>
{t('controlLayers.controlLayer')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRasterLayer}
>
{t('controlLayers.rasterLayer')}
</Button>
</Flex>
</Flex>
</Flex>


@@ -13,7 +13,7 @@ export const CanvasAlertsPreserveMask = memo(() => {
}
return (
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.preserveMask.alert')}</AlertTitle>
</Alert>


@@ -98,7 +98,7 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
}
return (
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<AlertIcon />
<AlertTitle>{alert.title}</AlertTitle>
</Alert>


@@ -132,6 +132,7 @@ const AlertWrapper = ({
fontSize="sm"
shadow="md"
w="fit-content"
alignSelf="flex-end"
>
<Flex w="full" alignItems="center">
<AlertIcon />


@@ -1,5 +1,4 @@
import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import {
@@ -17,8 +16,6 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
export const CanvasContextMenuGlobalMenuItems = memo(() => {
const { t } = useTranslation();
const saveSubMenu = useSubMenu();
const newSubMenu = useSubMenu();
const isBusy = useCanvasIsBusy();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
@@ -31,41 +28,27 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
<>
<MenuGroup title={t('controlLayers.canvasContextMenu.canvasGroup')}>
<CanvasContextMenuItemsCropCanvasToBbox />
<MenuItem {...saveSubMenu.parentMenuItemProps} icon={<PiFloppyDiskBold />}>
<Menu {...saveSubMenu.menuProps}>
<MenuButton {...saveSubMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.canvasContextMenu.saveToGalleryGroup')} />
</MenuButton>
<MenuList {...saveSubMenu.menuListProps}>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
{t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
</MenuItem>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
{t('controlLayers.canvasContextMenu.saveBboxToGallery')}
</MenuItem>
</MenuList>
</Menu>
</MenuGroup>
<MenuGroup title={t('controlLayers.canvasContextMenu.saveToGalleryGroup')}>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
{t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
</MenuItem>
<MenuItem {...newSubMenu.parentMenuItemProps} icon={<NewLayerIcon />}>
<Menu {...newSubMenu.menuProps}>
<MenuButton {...newSubMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.canvasContextMenu.bboxGroup')} />
</MenuButton>
<MenuList {...newSubMenu.menuListProps}>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newControlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newRasterLayer')}
</MenuItem>
</MenuList>
</Menu>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
{t('controlLayers.canvasContextMenu.saveBboxToGallery')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.canvasContextMenu.bboxGroup')}>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newControlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newRasterLayer')}
</MenuItem>
</MenuGroup>
</>


@@ -1,43 +1,42 @@
import { MenuGroup } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ControlLayerMenuItems } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItems';
import { InpaintMaskMenuItems } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItems';
import { IPAdapterMenuItems } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItems';
import { RasterLayerMenuItems } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItems';
import { RegionalGuidanceMenuItems } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItems';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import {
EntityIdentifierContext,
useEntityIdentifierContext,
} from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityTypeString } from 'features/controlLayers/hooks/useEntityTypeString';
import { useEntityTitle } from 'features/controlLayers/hooks/useEntityTitle';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { PropsWithChildren } from 'react';
import {
isFilterableEntityIdentifier,
isSaveableEntityIdentifier,
isSegmentableEntityIdentifier,
isTransformableEntityIdentifier,
} from 'features/controlLayers/store/types';
import { memo } from 'react';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const title = useEntityTitle(entityIdentifier);
if (entityIdentifier.type === 'raster_layer') {
return <RasterLayerMenuItems />;
}
if (entityIdentifier.type === 'control_layer') {
return <ControlLayerMenuItems />;
}
if (entityIdentifier.type === 'inpaint_mask') {
return <InpaintMaskMenuItems />;
}
if (entityIdentifier.type === 'regional_guidance') {
return <RegionalGuidanceMenuItems />;
}
if (entityIdentifier.type === 'reference_image') {
return <IPAdapterMenuItems />;
}
assert<Equals<typeof entityIdentifier.type, never>>(false);
return (
<MenuGroup title={title}>
{isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
{isSegmentableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSegment />}
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCopyToClipboard />}
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSave />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCropToBbox />}
<CanvasEntityMenuItemsDelete />
</MenuGroup>
);
});
CanvasContextMenuSelectedEntityMenuItemsContent.displayName = 'CanvasContextMenuSelectedEntityMenuItemsContent';
export const CanvasContextMenuSelectedEntityMenuItems = memo(() => {
@@ -49,20 +48,9 @@ export const CanvasContextMenuSelectedEntityMenuItems = memo(() => {
return (
<EntityIdentifierContext.Provider value={selectedEntityIdentifier}>
<CanvasContextMenuSelectedEntityMenuGroup>
<CanvasContextMenuSelectedEntityMenuItemsContent />
</CanvasContextMenuSelectedEntityMenuGroup>
<CanvasContextMenuSelectedEntityMenuItemsContent />
</EntityIdentifierContext.Provider>
);
});
CanvasContextMenuSelectedEntityMenuItems.displayName = 'CanvasContextMenuSelectedEntityMenuItems';
const CanvasContextMenuSelectedEntityMenuGroup = memo((props: PropsWithChildren) => {
const entityIdentifier = useEntityIdentifierContext();
const title = useEntityTypeString(entityIdentifier.type);
return <MenuGroup title={title}>{props.children}</MenuGroup>;
});
CanvasContextMenuSelectedEntityMenuGroup.displayName = 'CanvasContextMenuSelectedEntityMenuGroup';


@@ -62,7 +62,6 @@ export const CanvasDropArea = memo(() => {
data={addControlLayerFromImageDropData}
/>
</GridItem>
<GridItem position="relative">
<IAIDroppable
dropLabel={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}


@@ -29,7 +29,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<Menu>
<MenuButton
as={IconButton}
minW={8}
size="sm"
variant="link"
alignSelf="stretch"
tooltip={t('controlLayers.addLayer')}


@@ -4,7 +4,6 @@ import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/contr
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
import { EntityListSelectedEntityActionBarOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarOpacity';
import { EntityListSelectedEntityActionBarSelectObjectButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarSelectObjectButton';
import { EntityListSelectedEntityActionBarTransformButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarTransformButton';
import { memo } from 'react';
@@ -17,7 +16,6 @@ export const EntityListSelectedEntityActionBar = memo(() => {
<Spacer />
<EntityListSelectedEntityActionBarFill />
<Flex h="full">
<EntityListSelectedEntityActionBarSelectObjectButton />
<EntityListSelectedEntityActionBarFilterButton />
<EntityListSelectedEntityActionBarTransformButton />
<EntityListSelectedEntityActionBarSaveToAssetsButton />


@@ -23,7 +23,7 @@ export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
minW={8}
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.duplicate')}


@@ -5,7 +5,7 @@ import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/sel
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShootingStarFill } from 'react-icons/pi';
import { PiShootingStarBold } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
const { t } = useTranslation();
@@ -24,12 +24,12 @@ export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
<IconButton
onClick={filter.start}
isDisabled={filter.isDisabled}
minW={8}
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarFill />}
icon={<PiShootingStarBold />}
/>
);
});


@@ -31,7 +31,7 @@ export const EntityListSelectedEntityActionBarSaveToAssetsButton = memo(() => {
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
minW={8}
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.saveLayerToAssets')}


@@ -1,37 +0,0 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShapesFill } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarSelectObjectButton = memo(() => {
const { t } = useTranslation();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const segment = useEntitySegmentAnything(selectedEntityIdentifier);
if (!selectedEntityIdentifier) {
return null;
}
if (!isSegmentableEntityIdentifier(selectedEntityIdentifier)) {
return null;
}
return (
<IconButton
onClick={segment.start}
isDisabled={segment.isDisabled}
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.selectObject.selectObject')}
tooltip={t('controlLayers.selectObject.selectObject')}
icon={<PiShapesFill />}
/>
);
});
EntityListSelectedEntityActionBarSelectObjectButton.displayName = 'EntityListSelectedEntityActionBarSelectObjectButton';


@@ -24,7 +24,7 @@ export const EntityListSelectedEntityActionBarTransformButton = memo(() => {
<IconButton
onClick={transform.start}
isDisabled={transform.isDisabled}
minW={8}
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.transform.transform')}


@@ -10,7 +10,7 @@ import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { SegmentAnything } from 'features/controlLayers/components/SegmentAnything/SegmentAnything';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
@@ -25,8 +25,8 @@ const MenuContent = () => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuGlobalMenuItems />
<CanvasContextMenuSelectedEntityMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
@@ -71,16 +71,12 @@ export const CanvasMainPanelContent = memo(() => {
>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
<Flex
position="absolute"
flexDir="column"
top={1}
insetInlineStart={1}
pointerEvents="none"
gap={2}
alignItems="flex-start"
>
{showHUD && <CanvasHUD />}
{showHUD && (
<Flex position="absolute" top={1} insetInlineStart={1} pointerEvents="none">
<CanvasHUD />
</Flex>
)}
<Flex flexDir="column" position="absolute" top={1} insetInlineEnd={1} pointerEvents="none" gap={2}>
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
@@ -106,7 +102,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SelectObject />
<SegmentAnything />
</CanvasManagerProviderGate>
</Flex>
<CanvasDropArea />


@@ -21,7 +21,7 @@ import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/s
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
@@ -93,7 +93,7 @@ export const ControlLayerControlAdapter = memo(() => {
variant="link"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarFill />}
icon={<PiShootingStarBold />}
/>
<IconButton
onClick={pullBboxIntoLayer}


@@ -1,15 +1,15 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertToSubMenu';
import { ControlLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsCopyToSubMenu';
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
import { memo } from 'react';
@@ -24,12 +24,12 @@ export const ControlLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSelectObject />
<CanvasEntityMenuItemsSegment />
<ControlLayerMenuItemsConvertControlToRaster />
<ControlLayerMenuItemsTransparencyEffect />
<MenuDivider />
<ControlLayerMenuItemsCopyToSubMenu />
<ControlLayerMenuItemsConvertToSubMenu />
<CanvasEntityMenuItemsCropToBbox />
<CanvasEntityMenuItemsCopyToClipboard />
<CanvasEntityMenuItemsSave />
</>
);


@@ -0,0 +1,27 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { controlLayerConvertedToRasterLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';
export const ControlLayerMenuItemsConvertControlToRaster = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertControlLayerToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={!isInteractable}>
{t('controlLayers.convertToRasterLayer')}
</MenuItem>
);
});
ControlLayerMenuItemsConvertControlToRaster.displayName = 'ControlLayerMenuItemsConvertControlToRaster';
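`ControlLayerMenuItemsConvertControlToRaster` only dispatches `controlLayerConvertedToRasterLayer`; the reducer behind that action lives in `canvasSlice` and is not part of this diff. A minimal sketch of what such a conversion reducer could look like, assuming Redux Toolkit and deliberately simplified layer shapes (everything below except the action name is illustrative):

```ts
import { createSlice, nanoid } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';

// Simplified stand-ins for the real canvas layer state types.
type LayerBase = { id: string; position: { x: number; y: number }; objects: unknown[] };
type ControlLayer = LayerBase & { type: 'control_layer'; controlAdapter: unknown };
type RasterLayer = LayerBase & { type: 'raster_layer' };
type CanvasState = { controlLayers: ControlLayer[]; rasterLayers: RasterLayer[] };

const initialState: CanvasState = { controlLayers: [], rasterLayers: [] };

const canvasSlice = createSlice({
  name: 'canvas',
  initialState,
  reducers: {
    controlLayerConvertedToRasterLayer: (
      state,
      action: PayloadAction<{ entityIdentifier: { id: string }; replace?: boolean }>
    ) => {
      const { entityIdentifier, replace } = action.payload;
      const source = state.controlLayers.find((l) => l.id === entityIdentifier.id);
      if (!source) {
        return;
      }
      // Keep the drawable content, drop the control-specific config.
      const { controlAdapter: _dropped, ...rest } = source;
      state.rasterLayers.push({ ...rest, type: 'raster_layer', id: nanoid() });
      if (replace) {
        state.controlLayers = state.controlLayers.filter((l) => l.id !== entityIdentifier.id);
      }
    },
  },
});

export const { controlLayerConvertedToRasterLayer } = canvasSlice.actions;
```

The `replace` flag is what distinguishes the convert menu items from the copy ones elsewhere in this diff: convert passes `replace: true`, while copy omits it and leaves the source layer in place.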


@@ -1,56 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
controlLayerConvertedToInpaintMask,
controlLayerConvertedToRasterLayer,
controlLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
export const ControlLayerMenuItemsConvertToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertToInpaintMask = useCallback(() => {
dispatch(controlLayerConvertedToInpaintMask({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
const convertToRegionalGuidance = useCallback(() => {
dispatch(controlLayerConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
const convertToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.convertControlLayerTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem onClick={convertToRasterLayer} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.rasterLayer')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
ControlLayerMenuItemsConvertToSubMenu.displayName = 'ControlLayerMenuItemsConvertToSubMenu';


@@ -1,58 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
controlLayerConvertedToInpaintMask,
controlLayerConvertedToRasterLayer,
controlLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
export const ControlLayerMenuItemsCopyToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const copyToInpaintMask = useCallback(() => {
dispatch(controlLayerConvertedToInpaintMask({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
const copyToRegionalGuidance = useCallback(() => {
dispatch(controlLayerConvertedToRegionalGuidance({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
const copyToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.copyControlLayerTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<CanvasEntityMenuItemsCopyToClipboard />
<MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newInpaintMask')}
</MenuItem>
<MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newRegionalGuidance')}
</MenuItem>
<MenuItem onClick={copyToRasterLayer} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newRasterLayer')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
ControlLayerMenuItemsCopyToSubMenu.displayName = 'ControlLayerMenuItemsCopyToSubMenu';


@@ -1,15 +1,4 @@
import {
Button,
ButtonGroup,
Flex,
Heading,
Menu,
MenuButton,
MenuItem,
MenuList,
Spacer,
Spinner,
} from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
@@ -26,7 +15,7 @@ import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi';
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
@@ -36,7 +25,7 @@ const FilterContent = memo(
const config = useStore(adapter.filterer.$filterConfig);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const hasProcessed = useStore(adapter.filterer.$hasProcessed);
const autoProcess = useAppSelector(selectAutoProcess);
const onChangeFilterConfig = useCallback(
@@ -57,22 +46,6 @@ const FilterContent = memo(
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
const saveAsInpaintMask = useCallback(() => {
adapter.filterer.saveAs('inpaint_mask');
}, [adapter.filterer]);
const saveAsRegionalGuidance = useCallback(() => {
adapter.filterer.saveAs('regional_guidance');
}, [adapter.filterer]);
const saveAsRasterLayer = useCallback(() => {
adapter.filterer.saveAs('raster_layer');
}, [adapter.filterer]);
const saveAsControlLayer = useCallback(() => {
adapter.filterer.saveAs('control_layer');
}, [adapter.filterer]);
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
@@ -116,56 +89,40 @@ const FilterContent = memo(
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
variant="ghost"
leftIcon={<PiShootingStarBold />}
onClick={adapter.filterer.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.process')}
isDisabled={isProcessing || !isValid || (autoProcess && hasImageState)}
isDisabled={!isValid || autoProcess}
>
{t('controlLayers.filter.process')}
{isProcessing && <Spinner ms={3} boxSize={5} color="base.600" />}
</Button>
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.filterer.reset}
isDisabled={isProcessing}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.reset')}
variant="ghost"
>
{t('controlLayers.filter.reset')}
</Button>
<Button
onClick={adapter.filterer.apply}
loadingText={t('controlLayers.filter.apply')}
variant="ghost"
isDisabled={isProcessing || !isValid || !hasImageState}
leftIcon={<PiCheckBold />}
onClick={adapter.filterer.apply}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.apply')}
isDisabled={!isValid || !hasProcessed}
>
{t('controlLayers.filter.apply')}
</Button>
<Menu>
<MenuButton
as={Button}
loadingText={t('controlLayers.selectObject.saveAs')}
variant="ghost"
isDisabled={isProcessing || !isValid || !hasImageState}
rightIcon={<PiCaretDownBold />}
>
{t('controlLayers.selectObject.saveAs')}
</MenuButton>
<MenuList>
<MenuItem isDisabled={isProcessing || !isValid || !hasImageState} onClick={saveAsInpaintMask}>
{t('controlLayers.newInpaintMask')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !isValid || !hasImageState} onClick={saveAsRegionalGuidance}>
{t('controlLayers.newRegionalGuidance')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !isValid || !hasImageState} onClick={saveAsControlLayer}>
{t('controlLayers.newControlLayer')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !isValid || !hasImageState} onClick={saveAsRasterLayer}>
{t('controlLayers.newRasterLayer')}
</MenuItem>
</MenuList>
</Menu>
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
<Button
variant="ghost"
leftIcon={<PiXBold />}
onClick={adapter.filterer.cancel}
loadingText={t('controlLayers.filter.cancel')}
>
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
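The reworked filter panel gates its buttons on `$isProcessing` and `$hasProcessed`, which the component reads with `useStore`. A minimal sketch of that observable-state pattern, assuming nanostores; the class and the fake request below are illustrative stand-ins, not the real `filterer`:

```ts
import { atom } from 'nanostores';

// Illustrative stand-in for the adapter's filterer; only the atoms that
// FilterContent subscribes to are modeled here.
export class FiltererSketch {
  $isProcessing = atom(false);
  $hasProcessed = atom(false);

  processImmediate = async (): Promise<void> => {
    this.$isProcessing.set(true);
    try {
      // Stand-in for the real filter request round-trip.
      await new Promise<void>((resolve) => setTimeout(resolve, 250));
      this.$hasProcessed.set(true);
    } finally {
      this.$isProcessing.set(false);
    }
  };

  reset = (): void => {
    this.$hasProcessed.set(false);
  };
}
```

In a component, `const hasProcessed = useStore(adapter.filterer.$hasProcessed)` (from `@nanostores/react`) re-renders on every `set`, which is what lets the Apply button enable itself as soon as processing finishes.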


@@ -1,22 +0,0 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
export const IPAdapterMenuItemPullBbox = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
return (
<MenuItem onClick={pullBboxIntoIPAdapter} icon={<PiBoundingBoxBold />} isDisabled={isBusy}>
{t('controlLayers.pullBboxIntoReferenceImage')}
</MenuItem>
);
});
IPAdapterMenuItemPullBbox.displayName = 'IPAdapterMenuItemPullBbox';


@@ -1,22 +1,16 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { IPAdapterMenuItemPullBbox } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItemPullBbox';
import { memo } from 'react';
export const IPAdapterMenuItems = memo(() => {
return (
<>
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<IPAdapterMenuItemPullBbox />
</>
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
);
});


@@ -14,7 +14,7 @@ type Props = {
};
export const InpaintMask = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier<'inpaint_mask'>>(() => ({ id, type: 'inpaint_mask' }), [id]);
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'inpaint_mask' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
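This hunk widens `CanvasEntityIdentifier<'inpaint_mask'>` to the bare `CanvasEntityIdentifier`. The real type lives in `store/types` and is not shown here; a plausible sketch of such a discriminated identifier, to illustrate what the generic parameter buys:

```ts
// Assumed shape; the actual definition in features/controlLayers/store/types
// is not part of this diff.
type CanvasEntityType =
  | 'raster_layer'
  | 'control_layer'
  | 'inpaint_mask'
  | 'regional_guidance'
  | 'reference_image';

type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = {
  id: string;
  type: T;
};

// Parameterized, `type` narrows to the literal and mismatches fail to compile:
const narrow: CanvasEntityIdentifier<'inpaint_mask'> = { id: 'a', type: 'inpaint_mask' };
// Unparameterized, any entity type is accepted, which is what this hunk switches to:
const wide: CanvasEntityIdentifier = { id: 'b', type: 'regional_guidance' };
```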


@@ -5,8 +5,6 @@ import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/componen
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { InpaintMaskMenuItemsConvertToSubMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItemsConvertToSubMenu';
import { InpaintMaskMenuItemsCopyToSubMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItemsCopyToSubMenu';
import { memo } from 'react';
export const InpaintMaskMenuItems = memo(() => {
@@ -20,8 +18,6 @@ export const InpaintMaskMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<MenuDivider />
<InpaintMaskMenuItemsCopyToSubMenu />
<InpaintMaskMenuItemsConvertToSubMenu />
<CanvasEntityMenuItemsCropToBbox />
</>
);


@@ -1,38 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { inpaintMaskConvertedToRegionalGuidance } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
export const InpaintMaskMenuItemsConvertToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertToRegionalGuidance = useCallback(() => {
dispatch(inpaintMaskConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.convertInpaintMaskTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
InpaintMaskMenuItemsConvertToSubMenu.displayName = 'InpaintMaskMenuItemsConvertToSubMenu';


@@ -1,40 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { inpaintMaskConvertedToRegionalGuidance } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
export const InpaintMaskMenuItemsCopyToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const copyToRegionalGuidance = useCallback(() => {
dispatch(inpaintMaskConvertedToRegionalGuidance({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.copyInpaintMaskTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<CanvasEntityMenuItemsCopyToClipboard />
<MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newRegionalGuidance')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
InpaintMaskMenuItemsCopyToSubMenu.displayName = 'InpaintMaskMenuItemsCopyToSubMenu';


@@ -1,15 +1,15 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
import { memo } from 'react';
export const RasterLayerMenuItems = memo(() => {
@@ -23,11 +23,11 @@ export const RasterLayerMenuItems = memo(() => {
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsSelectObject />
<CanvasEntityMenuItemsSegment />
<RasterLayerMenuItemsConvertRasterToControl />
<MenuDivider />
<RasterLayerMenuItemsCopyToSubMenu />
<RasterLayerMenuItemsConvertToSubMenu />
<CanvasEntityMenuItemsCropToBbox />
<CanvasEntityMenuItemsCopyToClipboard />
<CanvasEntityMenuItemsSave />
</>
);


@@ -0,0 +1,36 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rasterLayerConvertedToControlLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';
export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const onClick = useCallback(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
overrides: {
controlAdapter: defaultControlAdapter,
},
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
return (
<MenuItem onClick={onClick} icon={<PiLightningBold />} isDisabled={!isInteractable}>
{t('controlLayers.convertToControlLayer')}
</MenuItem>
);
});
RasterLayerMenuItemsConvertRasterToControl.displayName = 'RasterLayerMenuItemsConvertRasterToControl';


@@ -1,65 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertToInpaintMask = useCallback(() => {
dispatch(rasterLayerConvertedToInpaintMask({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
const convertToRegionalGuidance = useCallback(() => {
dispatch(rasterLayerConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
const convertToControlLayer = useCallback(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
replace: true,
overrides: { controlAdapter: defaultControlAdapter },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.convertRasterLayerTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem onClick={convertToControlLayer} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.controlLayer')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
RasterLayerMenuItemsConvertToSubMenu.displayName = 'RasterLayerMenuItemsConvertToSubMenu';


@@ -1,66 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const copyToInpaintMask = useCallback(() => {
dispatch(rasterLayerConvertedToInpaintMask({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
const copyToRegionalGuidance = useCallback(() => {
dispatch(rasterLayerConvertedToRegionalGuidance({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
const copyToControlLayer = useCallback(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
overrides: { controlAdapter: defaultControlAdapter },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.copyRasterLayerTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<CanvasEntityMenuItemsCopyToClipboard />
<MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newInpaintMask')}
</MenuItem>
<MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newRegionalGuidance')}
</MenuItem>
<MenuItem onClick={copyToControlLayer} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newControlLayer')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
RasterLayerMenuItemsCopyToSubMenu.displayName = 'RasterLayerMenuItemsCopyToSubMenu';


@@ -16,10 +16,7 @@ type Props = {
};
export const RegionalGuidance = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier<'regional_guidance'>>(
() => ({ id, type: 'regional_guidance' }),
[id]
);
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'regional_guidance' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>


@@ -1,5 +1,4 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { Flex, MenuDivider } from '@invoke-ai/ui-library';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
@@ -7,26 +6,22 @@ import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/component
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RegionalGuidanceMenuItemsAddPromptsAndIPAdapter } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsAddPromptsAndIPAdapter';
import { RegionalGuidanceMenuItemsAutoNegative } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsAutoNegative';
import { RegionalGuidanceMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsConvertToSubMenu';
import { RegionalGuidanceMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsCopyToSubMenu';
import { memo } from 'react';
export const RegionalGuidanceMenuItems = memo(() => {
return (
<>
<IconMenuItemGroup>
<Flex gap={2}>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
</Flex>
<MenuDivider />
<RegionalGuidanceMenuItemsAddPromptsAndIPAdapter />
<MenuDivider />
<CanvasEntityMenuItemsTransform />
<RegionalGuidanceMenuItemsAutoNegative />
<MenuDivider />
<RegionalGuidanceMenuItemsCopyToSubMenu />
<RegionalGuidanceMenuItemsConvertToSubMenu />
<CanvasEntityMenuItemsCropToBbox />
</>
);


@@ -1,38 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rgConvertedToInpaintMask } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
export const RegionalGuidanceMenuItemsConvertToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertToInpaintMask = useCallback(() => {
dispatch(rgConvertedToInpaintMask({ entityIdentifier, replace: true }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.convertRegionalGuidanceTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
{t('controlLayers.inpaintMask')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
RegionalGuidanceMenuItemsConvertToSubMenu.displayName = 'RegionalGuidanceMenuItemsConvertToSubMenu';


@@ -1,40 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rgConvertedToInpaintMask } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
export const RegionalGuidanceMenuItemsCopyToSubMenu = memo(() => {
const { t } = useTranslation();
const subMenu = useSubMenu();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const copyToInpaintMask = useCallback(() => {
dispatch(rgConvertedToInpaintMask({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.copyRegionalGuidanceTo')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<CanvasEntityMenuItemsCopyToClipboard />
<MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.newInpaintMask')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
);
});
RegionalGuidanceMenuItemsCopyToSubMenu.displayName = 'RegionalGuidanceMenuItemsCopyToSubMenu';


@@ -0,0 +1,124 @@
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';
const SegmentAnythingContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
const autoProcess = useAppSelector(selectAutoProcess);
useRegisteredHotkeys({
id: 'applySegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelSegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
minW={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.segment.autoMask')}
</Heading>
<Spacer />
<CanvasAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<SegmentAnythingPointType adapter={adapter} />
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
leftIcon={<PiStarBold />}
onClick={adapter.segmentAnything.processImmediate}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.process')}
variant="ghost"
isDisabled={!hasPoints || autoProcess}
>
{t('controlLayers.segment.process')}
</Button>
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.segmentAnything.reset}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.reset')}
variant="ghost"
>
{t('controlLayers.segment.reset')}
</Button>
<Button
leftIcon={<PiCheckBold />}
onClick={adapter.segmentAnything.apply}
isLoading={isProcessing}
loadingText={t('controlLayers.segment.apply')}
variant="ghost"
>
{t('controlLayers.segment.apply')}
</Button>
<Button
leftIcon={<PiXBold />}
onClick={adapter.segmentAnything.cancel}
isLoading={isProcessing}
loadingText={t('common.cancel')}
variant="ghost"
>
{t('controlLayers.segment.cancel')}
</Button>
</ButtonGroup>
</Flex>
);
}
);
SegmentAnythingContent.displayName = 'SegmentAnythingContent';
export const SegmentAnything = () => {
const canvasManager = useCanvasManager();
const adapter = useStore(canvasManager.stateApi.$segmentingAdapter);
if (!adapter) {
return null;
}
return <SegmentAnythingContent adapter={adapter} />;
};
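The panel drives everything through `adapter.segmentAnything`: it subscribes to `$isProcessing` and `$hasPoints` and calls `processImmediate`, `reset`, `apply`, and `cancel`. A rough sketch of such a controller, assuming nanostores; the real `CanvasEntityAdapter` implementation is not in this diff:

```ts
import { atom, computed } from 'nanostores';

// SAM-style click prompts; the numeric labels here are an assumption.
type SAMPoint = { x: number; y: number; label: 1 | 0 | -1 };

export class SegmentAnythingSketch {
  $points = atom<SAMPoint[]>([]);
  $isProcessing = atom(false);
  // Derived state: recomputed whenever $points changes.
  $hasPoints = computed(this.$points, (points) => points.length > 0);

  addPoint = (point: SAMPoint): void => {
    this.$points.set([...this.$points.get(), point]);
  };

  reset = (): void => {
    this.$points.set([]);
  };

  cancel = (): void => {
    this.reset();
    // The real controller would also tear down the masked preview layer.
  };
}
```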


@@ -6,7 +6,7 @@ import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SelectObjectPointType = memo(
export const SegmentAnythingPointType = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const pointType = useStore(adapter.segmentAnything.$pointTypeString);
@@ -21,15 +21,18 @@ export const SelectObjectPointType = memo(
);
return (
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.selectObject.pointType')}</FormLabel>
<FormControl w="full">
<FormLabel>{t('controlLayers.segment.pointType')}</FormLabel>
<RadioGroup value={pointType} onChange={onChange} w="full" size="md">
<Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
<Radio value="foreground">
<Text>{t('controlLayers.selectObject.include')}</Text>
<Text>{t('controlLayers.segment.foreground')}</Text>
</Radio>
<Radio value="background">
<Text>{t('controlLayers.selectObject.exclude')}</Text>
<Text>{t('controlLayers.segment.background')}</Text>
</Radio>
<Radio value="neutral">
<Text>{t('controlLayers.segment.neutral')}</Text>
</Radio>
</Flex>
</RadioGroup>
@@ -38,4 +41,4 @@ export const SelectObjectPointType = memo(
}
);
SelectObjectPointType.displayName = 'SelectObject';
SegmentAnythingPointType.displayName = 'SegmentAnythingPointType';
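The point-type radio group gains a third `neutral` option alongside `foreground` and `background`. `SAM_POINT_LABEL_STRING_TO_NUMBER` is imported but not shown in this diff; a plausible mapping, assuming the usual Segment Anything convention of 1 for include and 0 for exclude, with -1 assumed here for neutral:

```ts
// Assumed mapping; the real table lives in the controlLayers store types.
const SAM_POINT_LABEL_STRING_TO_NUMBER = {
  foreground: 1,
  background: 0,
  neutral: -1,
} as const;

type SAMPointLabelString = keyof typeof SAM_POINT_LABEL_STRING_TO_NUMBER;

const toNumericLabel = (label: SAMPointLabelString): number =>
  SAM_POINT_LABEL_STRING_TO_NUMBER[label];

console.log(toNumericLabel('foreground')); // 1
```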


@@ -1,223 +0,0 @@
import {
Button,
ButtonGroup,
Flex,
Heading,
Icon,
ListItem,
Menu,
MenuButton,
MenuItem,
MenuList,
Spacer,
Spinner,
Text,
Tooltip,
UnorderedList,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { SelectObjectInvert } from 'features/controlLayers/components/SelectObject/SelectObjectInvert';
import { SelectObjectPointType } from 'features/controlLayers/components/SelectObject/SelectObjectPointType';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useRef } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiInfoBold } from 'react-icons/pi';
const SelectObjectContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
const hasImageState = useStore(adapter.segmentAnything.$hasImageState);
const autoProcess = useAppSelector(selectAutoProcess);
const saveAsInpaintMask = useCallback(() => {
adapter.segmentAnything.saveAs('inpaint_mask');
}, [adapter.segmentAnything]);
const saveAsRegionalGuidance = useCallback(() => {
adapter.segmentAnything.saveAs('regional_guidance');
}, [adapter.segmentAnything]);
const saveAsRasterLayer = useCallback(() => {
adapter.segmentAnything.saveAs('raster_layer');
}, [adapter.segmentAnything]);
const saveAsControlLayer = useCallback(() => {
adapter.segmentAnything.saveAs('control_layer');
}, [adapter.segmentAnything]);
useRegisteredHotkeys({
id: 'applySegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelSegmentAnything',
category: 'canvas',
callback: adapter.segmentAnything.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
minW={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<Flex w="full" gap={4} alignItems="center">
<Flex gap={2}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.selectObject.selectObject')}
</Heading>
<Tooltip label={<SelectObjectHelpTooltipContent />}>
<Flex alignItems="center">
<Icon as={PiInfoBold} color="base.500" />
</Flex>
</Tooltip>
</Flex>
<Spacer />
<CanvasAutoProcessSwitch />
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>
<Flex w="full" justifyContent="space-between" py={2}>
<SelectObjectPointType adapter={adapter} />
<SelectObjectInvert adapter={adapter} />
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
onClick={adapter.segmentAnything.processImmediate}
loadingText={t('controlLayers.selectObject.process')}
variant="ghost"
isDisabled={isProcessing || !hasPoints || (autoProcess && hasImageState)}
>
{t('controlLayers.selectObject.process')}
{isProcessing && <Spinner ms={3} boxSize={5} color="base.600" />}
</Button>
<Spacer />
<Button
onClick={adapter.segmentAnything.reset}
isDisabled={isProcessing || !hasPoints}
loadingText={t('controlLayers.selectObject.reset')}
variant="ghost"
>
{t('controlLayers.selectObject.reset')}
</Button>
<Button
onClick={adapter.segmentAnything.apply}
loadingText={t('controlLayers.selectObject.apply')}
variant="ghost"
isDisabled={isProcessing || !hasImageState}
>
{t('controlLayers.selectObject.apply')}
</Button>
<Menu>
<MenuButton
as={Button}
loadingText={t('controlLayers.selectObject.saveAs')}
variant="ghost"
isDisabled={isProcessing || !hasImageState}
rightIcon={<PiCaretDownBold />}
>
{t('controlLayers.selectObject.saveAs')}
</MenuButton>
<MenuList>
<MenuItem isDisabled={isProcessing || !hasImageState} onClick={saveAsInpaintMask}>
{t('controlLayers.newInpaintMask')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !hasImageState} onClick={saveAsRegionalGuidance}>
{t('controlLayers.newRegionalGuidance')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !hasImageState} onClick={saveAsControlLayer}>
{t('controlLayers.newControlLayer')}
</MenuItem>
<MenuItem isDisabled={isProcessing || !hasImageState} onClick={saveAsRasterLayer}>
{t('controlLayers.newRasterLayer')}
</MenuItem>
</MenuList>
</Menu>
<Button
onClick={adapter.segmentAnything.cancel}
isDisabled={isProcessing}
loadingText={t('common.cancel')}
variant="ghost"
>
{t('controlLayers.selectObject.cancel')}
</Button>
</ButtonGroup>
</Flex>
);
}
);
SelectObjectContent.displayName = 'SegmentAnythingContent';
export const SelectObject = memo(() => {
const canvasManager = useCanvasManager();
const adapter = useStore(canvasManager.stateApi.$segmentingAdapter);
if (!adapter) {
return null;
}
return <SelectObjectContent adapter={adapter} />;
});
SelectObject.displayName = 'SelectObject';
const Bold = (props: PropsWithChildren) => (
<Text as="span" fontWeight="semibold">
{props.children}
</Text>
);
const SelectObjectHelpTooltipContent = memo(() => {
const { t } = useTranslation();
return (
<Flex gap={3} flexDir="column">
<Text>
<Trans i18nKey="controlLayers.selectObject.help1" components={{ Bold: <Bold /> }} />
</Text>
<Text>
<Trans i18nKey="controlLayers.selectObject.help2" components={{ Bold: <Bold /> }} />
</Text>
<Text>
<Trans i18nKey="controlLayers.selectObject.help3" />
</Text>
<UnorderedList>
<ListItem>{t('controlLayers.selectObject.clickToAdd')}</ListItem>
<ListItem>{t('controlLayers.selectObject.dragToMove')}</ListItem>
<ListItem>{t('controlLayers.selectObject.clickToRemove')}</ListItem>
</UnorderedList>
</Flex>
);
});
SelectObjectHelpTooltipContent.displayName = 'SelectObjectHelpTooltipContent';


@@ -1,26 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const SelectObjectInvert = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const invert = useStore(adapter.segmentAnything.$invert);
const onChange = useCallback(() => {
adapter.segmentAnything.$invert.set(!adapter.segmentAnything.$invert.get());
}, [adapter.segmentAnything.$invert]);
return (
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.selectObject.invertSelection')}</FormLabel>
<Switch size="sm" isChecked={invert} onChange={onChange} />
</FormControl>
);
}
);
SelectObjectInvert.displayName = 'SelectObjectInvert';


@@ -1,4 +1,4 @@
import { Button, ButtonGroup, Flex, Heading, Spacer, Spinner } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
@@ -8,6 +8,7 @@ import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEnt
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiXBold } from 'react-icons/pi';
const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
const { t } = useTranslation();
@@ -61,28 +62,30 @@ const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) =>
<TransformFitToBboxButtons adapter={adapter} />
<ButtonGroup isAttached={false} size="sm" w="full" alignItems="center">
{isProcessing && <Spinner ms={3} boxSize={5} color="base.600" />}
<ButtonGroup isAttached={false} size="sm" w="full">
<Spacer />
<Button
leftIcon={<PiArrowsCounterClockwiseBold />}
onClick={adapter.transformer.resetTransform}
isDisabled={isProcessing}
isLoading={isProcessing}
loadingText={t('controlLayers.transform.reset')}
variant="ghost"
>
{t('controlLayers.transform.reset')}
</Button>
<Button
leftIcon={<PiCheckBold />}
onClick={adapter.transformer.applyTransform}
isDisabled={isProcessing}
isLoading={isProcessing}
loadingText={t('controlLayers.transform.apply')}
variant="ghost"
>
{t('controlLayers.transform.apply')}
</Button>
<Button
leftIcon={<PiXBold />}
onClick={adapter.transformer.stopTransform}
isDisabled={isProcessing}
isLoading={isProcessing}
loadingText={t('common.cancel')}
variant="ghost"
>


@@ -4,6 +4,7 @@ import { useStore } from '@nanostores/react';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsOutBold } from 'react-icons/pi';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { z } from 'zod';
@@ -59,9 +60,10 @@ export const TransformFitToBboxButtons = memo(({ adapter }: { adapter: CanvasEnt
<Combobox options={options} value={value} onChange={onChange} isSearchable={false} isClearable={false} />
</FormControl>
<Button
leftIcon={<PiArrowsOutBold />}
size="sm"
onClick={onClick}
isDisabled={isProcessing}
isLoading={isProcessing}
loadingText={t('controlLayers.transform.fitToBbox')}
variant="ghost"
>


@@ -1,11 +1,9 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
@@ -23,7 +21,6 @@ const _hover: SystemStyleObject = {
export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props) => {
const title = useEntityTypeTitle(type);
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
const collapse = useBoolean(true);
const canMergeVisible = useMemo(() => type === 'raster_layer' || type === 'inpaint_mask', [type]);
const canHideAll = useMemo(() => type !== 'reference_image', [type]);
@@ -50,30 +47,15 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
transitionProperty="common"
transitionDuration="fast"
/>
{informationalPopoverFeature ? (
<InformationalPopover feature={informationalPopoverFeature}>
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
</InformationalPopover>
) : (
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
)}
<Text
fontWeight="semibold"
color={isSelected ? 'base.200' : 'base.500'}
userSelect="none"
transitionProperty="common"
transitionDuration="fast"
>
{title}
</Text>
<Spacer />
</Flex>
{canMergeVisible && <CanvasEntityMergeVisibleButton type={type} />}


@@ -2,7 +2,6 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/useCopyLayerToClipboard';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -13,7 +12,6 @@ export const CanvasEntityMenuItemsCopyToClipboard = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const adapter = useEntityAdapterSafe(entityIdentifier);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const isEmpty = useEntityIsEmpty(entityIdentifier);
const copyLayerToClipboard = useCopyLayerToClipboard();
const onClick = useCallback(() => {
@@ -21,8 +19,8 @@ export const CanvasEntityMenuItemsCopyToClipboard = memo(() => {
}, [copyLayerToClipboard, adapter]);
return (
<MenuItem onClick={onClick} icon={<PiCopyBold />} isDisabled={!isInteractable || isEmpty}>
{t('common.clipboard')}
<MenuItem onClick={onClick} icon={<PiCopyBold />} isDisabled={!isInteractable}>
{t('controlLayers.copyToClipboard')}
</MenuItem>
);
});
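`useCopyLayerToClipboard` is not expanded in this diff, but the browser side of such a hook can be sketched with standard Web APIs; how the layer is rasterized to a canvas is project-specific and assumed here:

```ts
// Copy a rasterized layer to the clipboard as a PNG. Requires a secure
// (https) context; canvas.toBlob, ClipboardItem, and
// navigator.clipboard.write are all standard browser APIs.
const copyCanvasToClipboard = async (canvas: HTMLCanvasElement): Promise<void> => {
  const blob = await new Promise<Blob>((resolve, reject) => {
    canvas.toBlob((b) => (b ? resolve(b) : reject(new Error('toBlob returned null'))), 'image/png');
  });
  await navigator.clipboard.write([new ClipboardItem({ 'image/png': blob })]);
};
```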


@@ -3,7 +3,7 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShootingStarFill } from 'react-icons/pi';
import { PiShootingStarBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsFilter = memo(() => {
const { t } = useTranslation();
@@ -11,7 +11,7 @@ export const CanvasEntityMenuItemsFilter = memo(() => {
const filter = useEntityFilter(entityIdentifier);
return (
<MenuItem onClick={filter.start} icon={<PiShootingStarFill />} isDisabled={filter.isDisabled}>
<MenuItem onClick={filter.start} icon={<PiShootingStarBold />} isDisabled={filter.isDisabled}>
{t('controlLayers.filter.filter')}
</MenuItem>
);


@@ -3,18 +3,18 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiShapesFill } from 'react-icons/pi';
import { PiMaskHappyBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsSelectObject = memo(() => {
export const CanvasEntityMenuItemsSegment = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const segmentAnything = useEntitySegmentAnything(entityIdentifier);
return (
<MenuItem onClick={segmentAnything.start} icon={<PiShapesFill />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.selectObject.selectObject')}
<MenuItem onClick={segmentAnything.start} icon={<PiMaskHappyBold />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.segment.autoMask')}
</MenuItem>
);
});
CanvasEntityMenuItemsSelectObject.displayName = 'CanvasEntityMenuItemsSelectObject';
CanvasEntityMenuItemsSegment.displayName = 'CanvasEntityMenuItemsSegment';


@@ -24,9 +24,7 @@ import {
selectEntityOrThrow,
} from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
ControlNetConfig,
@@ -46,8 +44,6 @@ import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
@@ -128,60 +124,6 @@ export const useNewRasterLayerFromImage = () => {
return func;
};
export const useNewControlLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasControlLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useNewInpaintMaskFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasInpaintMaskState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useNewRegionalGuidanceFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rgAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
/**
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
* - Reset the canvas
@@ -196,31 +138,18 @@ export const useNewCanvasFromImage = () => {
const bboxRect = useAppSelector(selectBboxRect);
const base = useAppSelector(selectBboxModelBase);
const func = useCallback(
(imageDTO: ImageDTO, type: CanvasRasterLayerState['type'] | CanvasControlLayerState['type']) => {
(imageDTO: ImageDTO) => {
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
// The overrides need to include the layer's ID so we can transform the layer after it is initialized
let overrides: Partial<CanvasRasterLayerState> | Partial<CanvasControlLayerState>;
if (type === 'raster_layer') {
overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;
} else if (type === 'control_layer') {
overrides = {
id: getPrefixedId('control_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasControlLayerState>;
} else {
// Catch unhandled types
assert<Equals<typeof type, never>>(false);
}
const overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;
CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
// Skip the callback if the adapter is not the one we are creating
@@ -237,16 +166,7 @@ export const useNewCanvasFromImage = () => {
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
// The type casts are safe because the type is checked above
if (type === 'raster_layer') {
dispatch(rasterLayerAdded({ overrides: overrides as Partial<CanvasRasterLayerState>, isSelected: true }));
} else if (type === 'control_layer') {
dispatch(controlLayerAdded({ overrides: overrides as Partial<CanvasControlLayerState>, isSelected: true }));
} else {
// Catch unhandled types
assert<Equals<typeof type, never>>(false);
}
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[base, bboxRect.x, bboxRect.y, dispatch]
);
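
The bbox resize above fits the image's aspect ratio at the model's optimal pixel area (`optimalDimension ** 2`). A minimal sketch of that calculation, assuming both dimensions must snap to a latent grid; `fitAreaToRatio` is a hypothetical stand-in for `calculateNewSize`, not the app's implementation:

```ts
// Fit a target pixel area at a given aspect ratio, snapping both
// dimensions to a grid (e.g. the 8px latent grid of SD-class models).
const fitAreaToRatio = (ratio: number, area: number, grid = 8): { width: number; height: number } => {
  // From width / height = ratio and width * height = area:
  //   width = sqrt(area * ratio), height = sqrt(area / ratio)
  const width = Math.round(Math.sqrt(area * ratio) / grid) * grid;
  const height = Math.round(Math.sqrt(area / ratio) / grid) * grid;
  return { width, height };
};

// A 3:2 image on a model whose optimal dimension is 1024:
fitAreaToRatio(3 / 2, 1024 ** 2); // => { width: 1256, height: 840 }
```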


@@ -1,4 +1,3 @@
import { logger } from 'app/logging/logger';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
@@ -8,9 +7,6 @@ import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
const log = logger('canvas');
export const useCopyLayerToClipboard = () => {
const { t } = useTranslation();
@@ -30,13 +26,11 @@ export const useCopyLayerToClipboard = () => {
const canvas = adapter.getCanvas();
const blob = await canvasToBlob(canvas);
copyBlobToClipboard(blob);
log.trace('Layer copied to clipboard');
toast({
status: 'info',
title: t('toast.layerCopiedToClipboard'),
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemCopyingLayer'),
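
The imported `canvasToBlob` and `copyBlobToClipboard` utilities wrap standard browser APIs. Minimal versions might look like the sketch below (an assumption, not the app's implementations); note the async Clipboard API requires a secure context and, in some browsers, a user gesture:

```ts
// Convert a canvas to a PNG blob; toBlob is callback-based, so wrap it.
const canvasToBlob = (canvas: HTMLCanvasElement): Promise<Blob> =>
  new Promise((resolve, reject) => {
    canvas.toBlob((blob) => (blob ? resolve(blob) : reject(new Error('toBlob returned null'))), 'image/png');
  });

// Write the blob to the system clipboard as an image.
const copyBlobToClipboard = (blob: Blob): Promise<void> =>
  navigator.clipboard.write([new ClipboardItem({ [blob.type || 'image/png']: blob })]);
```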


@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';
export const useEntityFilter = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
const isInteractable = useStore(adapter?.$isInteractable ?? $false);
const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -52,9 +50,8 @@ export const useEntityFilter = (entityIdentifier: CanvasEntityIdentifier | null)
if (!adapter) {
return;
}
imageViewer.close();
adapter.filterer.start();
}, [isDisabled, entityIdentifier, canvasManager, imageViewer]);
}, [isDisabled, entityIdentifier, canvasManager]);
return { isDisabled, start } as const;
};
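
This hook (and the segment/transform hooks below) reads `useStore(adapter?.$isInteractable ?? $false)`. React hooks cannot be called conditionally, so when no adapter exists, `useStore` is pointed at a shared constant atom instead. A sketch of the pattern, with the `Adapter` shape assumed for illustration:

```ts
import { atom } from 'nanostores';
import type { WritableAtom } from 'nanostores';
import { useStore } from '@nanostores/react';

// A shared, never-written fallback atom; subscribing to it is cheap and
// keeps the useStore call unconditional when there is no adapter.
const $false = atom(false);

type Adapter = { $isInteractable: WritableAtom<boolean> }; // assumed shape

const useIsInteractable = (adapter: Adapter | null): boolean =>
  useStore(adapter?.$isInteractable ?? $false);
```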


@@ -1,11 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { buildSelectHasObjects } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
export const useEntityIsEmpty = (entityIdentifier: CanvasEntityIdentifier) => {
const selectHasObjects = useMemo(() => buildSelectHasObjects(entityIdentifier), [entityIdentifier]);
const hasObjects = useAppSelector(selectHasObjects);
return !hasObjects;
};


@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';
export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
const isInteractable = useStore(adapter?.$isInteractable ?? $false);
const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -52,9 +50,8 @@ export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifie
if (!adapter) {
return;
}
imageViewer.close();
adapter.segmentAnything.start();
}, [isDisabled, entityIdentifier, canvasManager, imageViewer]);
}, [isDisabled, entityIdentifier, canvasManager]);
return { isDisabled, start } as const;
};


@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';
export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
const isInteractable = useStore(adapter?.$isInteractable ?? $false);
const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -52,9 +50,8 @@ export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | nu
if (!adapter) {
return;
}
imageViewer.close();
await adapter.transformer.startTransform();
}, [isDisabled, entityIdentifier, canvasManager, imageViewer]);
}, [isDisabled, entityIdentifier, canvasManager]);
const fitToBbox = useCallback(async () => {
if (isDisabled) {
@@ -70,11 +67,10 @@ export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | nu
if (!adapter) {
return;
}
imageViewer.close();
await adapter.transformer.startTransform({ silent: true });
adapter.transformer.fitToBboxContain();
await adapter.transformer.applyTransform();
}, [canvasManager, entityIdentifier, imageViewer, isDisabled]);
}, [canvasManager, entityIdentifier, isDisabled]);
return { isDisabled, start, fitToBbox } as const;
};


@@ -1,25 +0,0 @@
import type { Feature } from 'common/components/InformationalPopover/constants';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
export const useEntityTypeInformationalPopover = (type: CanvasEntityIdentifier['type']): Feature | undefined => {
const feature = useMemo(() => {
switch (type) {
case 'control_layer':
return 'controlNet';
case 'inpaint_mask':
return 'inpainting';
case 'raster_layer':
return 'rasterLayer';
case 'regional_guidance':
return 'regionalGuidanceAndReferenceImage';
case 'reference_image':
return 'globalReferenceImage';
default:
return undefined;
}
}, [type]);
return feature;
};


@@ -1,36 +1,27 @@
import { deepClone } from 'common/util/deepClone';
import { withResult, withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type { FilterConfig } from 'features/controlLayers/store/filters';
import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters';
import type { CanvasEntityType, CanvasImageState } from 'features/controlLayers/store/types';
import type { CanvasImageState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import Konva from 'konva';
import { debounce } from 'lodash-es';
import { atom, computed } from 'nanostores';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
type CanvasEntityFiltererConfig = {
/**
* The debounce time in milliseconds for processing the filter.
*/
PROCESS_DEBOUNCE_MS: number;
processDebounceMs: number;
};
const DEFAULT_CONFIG: CanvasEntityFiltererConfig = {
PROCESS_DEBOUNCE_MS: 1000,
processDebounceMs: 1000,
};
export class CanvasEntityFilterer extends CanvasModuleBase {
@@ -41,65 +32,20 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasEntityFiltererConfig = DEFAULT_CONFIG;
imageState: CanvasImageState | null = null;
subscriptions = new Set<() => void>();
config: CanvasEntityFiltererConfig = DEFAULT_CONFIG;
/**
* The AbortController used to cancel the filter processing.
*/
abortController: AbortController | null = null;
/**
* Whether the module is currently filtering an image.
*/
$isFiltering = atom<boolean>(false);
/**
* The hash of the last processed config. This is used to prevent re-processing the same config.
*/
$lastProcessedHash = atom<string>('');
/**
* Whether the module is currently processing the filter.
*/
$hasProcessed = atom<boolean>(false);
$isProcessing = atom<boolean>(false);
/**
* The config for the filter.
*/
$filterConfig = atom<FilterConfig>(IMAGE_FILTERS.canny_edge_detection.buildDefaults());
/**
* The initial filter config, used to reset the filter config.
*/
$initialFilterConfig = atom<FilterConfig | null>(null);
/**
* The ephemeral image state of the filtered image.
*/
$imageState = atom<CanvasImageState | null>(null);
/**
* Whether the module has an image state. This is a computed value based on $imageState.
*/
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
/**
* The filtered image object module, if it exists.
*/
imageModule: CanvasObjectImage | null = null;
/**
* The Konva nodes for the module.
*/
konva: {
/**
* The main Konva group node for the module. This is added to the parent layer on start, and removed on teardown.
*/
group: Konva.Group;
};
KONVA_GROUP_NAME = `${this.type}:group`;
constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) {
super();
this.id = getPrefixedId(this.type);
@@ -109,17 +55,9 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.log = this.manager.buildLogger(this);
this.log.debug('Creating filter module');
this.konva = {
group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
};
}
/**
* Adds event listeners needed while filtering the entity.
*/
subscribe = () => {
// As the filter config changes, process the filter
this.subscriptions.add(
this.$filterConfig.listen(() => {
if (this.manager.stateApi.getSettings().autoProcess && this.$isFiltering.get()) {
@@ -127,7 +65,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
}
})
);
// When auto-process is enabled, process the filter
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
if (autoProcess && this.$isFiltering.get()) {
@@ -137,18 +74,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
);
};
/**
* Removes event listeners used while filtering the entity.
*/
unsubscribe = () => {
this.subscriptions.forEach((unsubscribe) => unsubscribe());
this.subscriptions.clear();
};
/**
* Starts the filter module.
* @param config The filter config to start with. If omitted, the default filter config is used.
*/
start = (config?: FilterConfig) => {
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
if (filteringAdapter) {
@@ -158,57 +88,30 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.log.trace('Initializing filter');
// Reset any previous state
this.resetEphemeralState();
this.$isFiltering.set(true);
// Update the konva group's position to match the parent entity
const pixelRect = this.parent.transformer.$pixelRect.get();
const position = addCoords(this.parent.state.position, pixelRect);
this.konva.group.setAttrs(position);
// Add the group to the parent layer
this.parent.konva.layer.add(this.konva.group);
if (config) {
// If a config is provided, use it
this.$filterConfig.set(config);
this.$initialFilterConfig.set(config);
} else {
this.$filterConfig.set(this.createInitialFilterConfig());
}
this.$initialFilterConfig.set(this.$filterConfig.get());
this.subscribe();
this.manager.stateApi.$filteringAdapter.set(this.parent);
if (this.manager.stateApi.getSettings().autoProcess) {
this.processImmediate();
}
};
createInitialFilterConfig = (): FilterConfig => {
if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
if (config) {
this.$filterConfig.set(config);
} else if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
const selectModelConfig = buildSelectModelConfig(
this.parent.state.controlAdapter.model.key,
isControlNetOrT2IAdapterModelConfig
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
// This always returns a filter
const filter = getFilterForModel(modelConfig);
return filter.buildDefaults();
this.$filterConfig.set(filter.buildDefaults());
} else {
// Otherwise, use the default filter
return IMAGE_FILTERS.canny_edge_detection.buildDefaults();
// Otherwise, set the default filter
this.$filterConfig.set(IMAGE_FILTERS.canny_edge_detection.buildDefaults());
}
this.$isFiltering.set(true);
this.manager.stateApi.$filteringAdapter.set(this.parent);
if (this.manager.stateApi.getSettings().autoProcess) {
this.processImmediate();
}
};
/**
* Processes the filter, updating the module's state and rendering the filtered image.
*/
processImmediate = async () => {
const config = this.$filterConfig.get();
const filterData = IMAGE_FILTERS[config.type];
@@ -220,12 +123,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
return;
}
const hash = stableHash({ config });
if (hash === this.$lastProcessedHash.get()) {
this.log.trace('Already processed config');
return;
}
this.log.trace({ config }, 'Processing filter');
const rect = this.parent.transformer.getRelativeRect();
@@ -259,181 +156,91 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.manager.stateApi.runGraphAndReturnImageOutput({
graph,
outputNodeId,
// The filter graph should always be prepended to the queue so it's processed ASAP.
prepend: true,
/**
* The filter node may need to download a large model. Currently, the models required by the filter nodes are
* downloaded just-in-time, as required by the filter. If we use a timeout here, we might get into a catch-22
* where the filter node is waiting for the model to download, but the download gets canceled if the filter
* node times out.
*
* (I suspect the model download will actually _not_ be canceled if the graph is canceled, but let's not chance it!)
*
* TODO(psyche): Figure out a better way to handle this. Probably need to download the models ahead of time.
*/
// timeout: 5000,
/**
* The filter node should be able to cancel the request if it's taking too long. This will cancel the graph's
* queue item and clear any event listeners on the request.
*/
signal: controller.signal,
})
);
// If there is an error, log it and bail out of this processing run
if (filterResult.isErr()) {
this.log.error({ error: serializeError(filterResult.error) }, 'Error filtering');
this.log.error({ error: serializeError(filterResult.error) }, 'Error processing filter');
this.$isProcessing.set(false);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
return;
}
this.log.trace({ imageDTO: filterResult.value }, 'Filtered');
this.log.trace({ imageDTO: filterResult.value }, 'Filter processed');
this.imageState = imageDTOToImageObject(filterResult.value);
// Prepare the ephemeral image state
const imageState = imageDTOToImageObject(filterResult.value);
this.$imageState.set(imageState);
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
}
this.imageModule = new CanvasObjectImage(imageState, this);
// Force update the masked image - after awaiting, the image will be rendered (in memory)
await this.imageModule.update(imageState, true);
this.konva.group.add(this.imageModule.konva.group);
// The processing is complete, so we can set the last processed hash and isProcessing to false
this.$lastProcessedHash.set(hash);
await this.parent.bufferRenderer.setBuffer(this.imageState, true);
this.$isProcessing.set(false);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
}
this.$hasProcessed.set(true);
this.abortController = null;
};
/**
* Debounced version of processImmediate.
*/
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
process = debounce(this.processImmediate, this.config.processDebounceMs);
/**
* Applies the filter image to the entity, replacing the entity's objects with the filtered image.
*/
apply = () => {
const filteredImageObjectState = this.$imageState.get();
if (!filteredImageObjectState) {
const imageState = this.imageState;
if (!imageState) {
this.log.warn('No image state to apply filter to');
return;
}
this.log.trace('Applying');
// Have the parent adopt the image module - this prevents a flash of the original layer content before the filtered
// image is rendered
if (this.imageModule) {
this.parent.renderer.adoptObjectRenderer(this.imageModule);
}
// Rasterize the entity, replacing the objects with the filtered image
this.log.trace('Applying filter');
this.parent.bufferRenderer.commitBuffer();
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
imageObject: filteredImageObjectState,
imageObject: imageState,
position: {
x: Math.round(rect.x),
y: Math.round(rect.y),
},
replaceObjects: true,
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
/**
* Saves the filtered image as a new entity of the given type.
* @param type The type of entity to save the filtered image as.
*/
saveAs = (type: Exclude<CanvasEntityType, 'reference_image'>) => {
const imageState = this.$imageState.get();
if (!imageState) {
this.log.warn('No image state to apply filter to');
return;
}
this.log.trace(`Saving as ${type}`);
const rect = this.parent.transformer.getRelativeRect();
const arg = {
overrides: {
objects: [imageState],
position: {
x: Math.round(rect.x),
y: Math.round(rect.y),
},
},
isSelected: true,
};
switch (type) {
case 'raster_layer':
this.manager.stateApi.addRasterLayer(arg);
break;
case 'control_layer':
this.manager.stateApi.addControlLayer(arg);
break;
case 'inpaint_mask':
this.manager.stateApi.addInpaintMask(arg);
break;
case 'regional_guidance':
this.manager.stateApi.addRegionalGuidance(arg);
break;
default:
assert<Equals<typeof type, never>>(false);
}
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
resetEphemeralState = () => {
// First we need to bail out of any processing
if (this.abortController && !this.abortController.signal.aborted) {
this.abortController.abort();
}
this.abortController = null;
// If the image module exists, and is a child of the group, destroy it. It might not be a child of the group if
// the user has applied the filter and the image has been adopted by the parent entity.
if (this.imageModule && this.imageModule.konva.group.parent === this.konva.group) {
this.imageModule.destroy();
this.imageModule = null;
}
const initialFilterConfig = this.$initialFilterConfig.get() ?? this.createInitialFilterConfig();
this.$filterConfig.set(initialFilterConfig);
this.$imageState.set(null);
this.$lastProcessedHash.set('');
this.$isProcessing.set(false);
};
teardown = () => {
this.$initialFilterConfig.set(null);
this.konva.group.remove();
this.imageState = null;
this.unsubscribe();
this.$isFiltering.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
/**
* Resets the module (e.g. discards the filtered preview image).
*
* Does not cancel or otherwise complete the filtering process.
*/
reset = () => {
this.log.trace('Resetting');
this.resetEphemeralState();
this.log.trace('Resetting filter');
this.abortController?.abort();
this.abortController = null;
this.parent.bufferRenderer.clearBuffer();
this.parent.transformer.updatePosition();
this.parent.renderer.syncKonvaCache(true);
this.imageState = null;
this.$hasProcessed.set(false);
};
cancel = () => {
this.log.trace('Canceling');
this.resetEphemeralState();
this.teardown();
this.log.trace('Cancelling filter');
this.reset();
this.unsubscribe();
this.$isProcessing.set(false);
this.$isFiltering.set(false);
this.$hasProcessed.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
repr = () => {
@@ -441,14 +248,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
id: this.id,
type: this.type,
path: this.path,
parent: this.parent.id,
config: this.config,
imageState: deepClone(this.$imageState.get()),
$isFiltering: this.$isFiltering.get(),
$lastProcessedHash: this.$lastProcessedHash.get(),
$hasProcessed: this.$hasProcessed.get(),
$isProcessing: this.$isProcessing.get(),
$filterConfig: this.$filterConfig.get(),
konva: { group: getKonvaNodeDebugAttrs(this.konva.group) },
};
};
@@ -459,6 +263,5 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
}
this.abortController = null;
this.unsubscribe();
this.konva.group.destroy();
};
}
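
The filterer's processing pipeline combines three techniques visible above: a debounced entry point over `processImmediate`, a stable-hash gate so an unchanged config is never re-processed, and an `AbortController` so an in-flight run can be canceled. A condensed sketch of that shape; `runFilter` is a hypothetical stand-in for the graph execution, which must observe the signal for cancellation to take effect:

```ts
import { debounce } from 'lodash-es';
import stableHash from 'stable-hash';

type FilterConfig = { type: string; [key: string]: unknown };

class FilterProcessor {
  private lastProcessedHash = '';
  private abortController: AbortController | null = null;

  constructor(private runFilter: (config: FilterConfig, signal: AbortSignal) => Promise<void>) {}

  processImmediate = async (config: FilterConfig): Promise<void> => {
    const hash = stableHash({ config });
    if (hash === this.lastProcessedHash) {
      return; // same config as the last successful run - skip
    }
    // Cancel any in-flight run before starting a new one.
    this.abortController?.abort();
    const controller = new AbortController();
    this.abortController = controller;
    try {
      await this.runFilter(config, controller.signal);
      this.lastProcessedHash = hash; // only record successful runs
    } finally {
      if (this.abortController === controller) {
        this.abortController = null;
      }
    }
  };

  // Trailing-edge debounce: rapid config changes collapse into one run.
  process = debounce(this.processImmediate, 1000);
}
```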


@@ -1,7 +1,6 @@
import { Mutex } from 'async-mutex';
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
@@ -22,8 +21,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
| CanvasEntityObjectRenderer
| CanvasEntityBufferObjectRenderer
| CanvasStagingAreaModule
| CanvasSegmentAnythingModule
| CanvasEntityFilterer;
| CanvasSegmentAnythingModule;
readonly manager: CanvasManager;
readonly log: Logger;
@@ -45,7 +43,6 @@ export class CanvasObjectImage extends CanvasModuleBase {
| CanvasEntityBufferObjectRenderer
| CanvasStagingAreaModule
| CanvasSegmentAnythingModule
| CanvasEntityFilterer
) {
super();
this.id = state.id;


@@ -6,22 +6,15 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import {
addCoords,
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
roundCoord,
} from 'features/controlLayers/konva/util';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
CanvasEntityType,
CanvasImageState,
Coordinate,
RgbaColor,
SAMPoint,
SAMPointLabel,
SAMPointLabelString,
SAMPointWithId,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
@@ -34,9 +27,6 @@ import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
type CanvasSegmentAnythingModuleConfig = {
/**
@@ -80,7 +70,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
MASK_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
PROCESS_DEBOUNCE_MS: 1000,
};
@@ -95,7 +85,6 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
type SAMPointState = {
id: string;
label: SAMPointLabel;
coord: Coordinate;
konva: {
circle: Konva.Circle;
};
@@ -114,7 +103,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
subscriptions = new Set<() => void>();
/**
* The AbortController used to cancel the segment processing.
* The AbortController used to cancel the segmentation processing.
*/
abortController: AbortController | null = null;
@@ -124,9 +113,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$isSegmenting = atom<boolean>(false);
/**
* The hash of the last processed points. This is used to prevent re-processing the same points.
* Whether the current set of points has been processed.
*/
$lastProcessedHash = atom<string>('');
$hasProcessed = atom<boolean>(false);
/**
* Whether the module is currently processing the points.
@@ -155,15 +144,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* The ephemeral image state of the processed image. Only used while segmenting.
*/
$imageState = atom<CanvasImageState | null>(null);
imageState: CanvasImageState | null = null;
/**
* Whether the module has an image state. This is a computed value based on $imageState.
*/
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
/**
* The current input points. A listener is added to this atom to process the points when they change.
* The current input points.
*/
$points = atom<SAMPointState[]>([]);
@@ -173,21 +157,16 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
$hasPoints = computed(this.$points, (points) => points.length > 0);
/**
* Whether the module should invert the mask image.
* The masked image object, if it exists.
*/
$invert = atom<boolean>(false);
/**
* The masked image object module, if it exists.
*/
imageModule: CanvasObjectImage | null = null;
maskedImage: CanvasObjectImage | null = null;
/**
* The Konva nodes for the module.
*/
konva: {
/**
* The main Konva group node for the module. This is added to the parent layer on start, and removed on teardown.
* The main Konva group node for the module.
*/
group: Konva.Group;
/**
@@ -208,10 +187,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay.
*/
compositingRect: Konva.Rect;
/**
* A tween for pulsing the mask group's opacity.
*/
maskTween: Konva.Tween | null;
};
KONVA_CIRCLE_NAME = `${this.type}:circle`;
@@ -234,7 +209,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.konva = {
group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }),
maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME, opacity: 0.6 }),
maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }),
compositingRect: new Konva.Rect({
name: this.KONVA_COMPOSITING_RECT_NAME,
fill: rgbaColorToString(this.config.MASK_COLOR),
@@ -244,7 +219,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
perfectDrawEnabled: false,
visible: false,
}),
maskTween: null,
};
// Points should always be rendered above the mask group
@@ -276,12 +250,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
const id = getPrefixedId('sam_point');
const roundedCoord = roundCoord(coord);
const circle = new Konva.Circle({
name: this.KONVA_CIRCLE_NAME,
x: roundedCoord.x,
y: roundedCoord.y,
x: Math.round(coord.x),
y: Math.round(coord.y),
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
fill: rgbaColorToString(this.getSAMPointColor(label)),
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
@@ -298,18 +270,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
if (this.$isDraggingPoint.get()) {
return;
}
if (e.evt.button !== 0) {
return;
}
// This event should not bubble up to the parent, stage or any other nodes
e.cancelBubble = true;
circle.destroy();
const newPoints = this.$points.get().filter((point) => point.id !== id);
if (newPoints.length === 0) {
this.$points.set(this.$points.get().filter((point) => point.id !== id));
if (this.$points.get().length === 0) {
this.resetEphemeralState();
} else {
this.$points.set(newPoints);
this.$hasProcessed.set(false);
}
});
@@ -318,28 +286,25 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
});
circle.on('dragend', () => {
const roundedCoord = roundCoord(circle.position());
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Moved SAM point');
this.$isDraggingPoint.set(false);
const newPoints = this.$points.get().map((point) => {
if (point.id === id) {
return { ...point, coord: roundedCoord };
}
return point;
});
this.$points.set(newPoints);
// Point has changed!
this.$hasProcessed.set(false);
this.$points.notify();
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Moved SAM point'
);
});
this.konva.pointGroup.add(circle);
this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Created SAM point');
this.log.trace(
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
'Created SAM point'
);
return {
id,
coord: roundedCoord,
label,
konva: { circle },
};
@@ -362,14 +327,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
/**
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
*/
getSAMPoints = (): SAMPointWithId[] => {
const points: SAMPointWithId[] = [];
getSAMPoints = (): SAMPoint[] => {
const points: SAMPoint[] = [];
for (const { id, coord, label } of this.$points.get()) {
for (const { konva, label } of this.$points.get()) {
points.push({
id,
x: coord.x,
y: coord.y,
// Pull out and round the x and y values from Konva
x: Math.round(konva.circle.x()),
y: Math.round(konva.circle.y()),
label,
});
}
@@ -416,8 +381,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
// Create a SAM point at the normalized position
const point = this.createPoint(normalizedPoint, this.$pointType.get());
const newPoints = [...this.$points.get(), point];
this.$points.set(newPoints);
this.$points.set([...this.$points.get(), point]);
// Mark the module as having _not_ processed the points now that they have changed
this.$hasProcessed.set(false);
};
/**
@@ -454,20 +421,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
if (points.length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcess) {
this.process();
}
})
);
// When the invert flag changes, process if autoProcess is enabled
this.subscriptions.add(
this.$invert.listen(() => {
if (this.$points.get().length === 0) {
return;
}
if (this.manager.stateApi.getSettings().autoProcess) {
this.process();
}
@@ -480,7 +433,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
if (this.$points.get().length === 0) {
return;
}
if (autoProcess) {
if (autoProcess && !this.$hasProcessed.get()) {
this.process();
}
})
@@ -488,7 +441,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
};
/**
* Removes event listeners used while segmenting the entity.
* Removes event listeners added while segmenting the entity.
*/
unsubscribe = () => {
this.subscriptions.forEach((unsubscribe) => unsubscribe());
@@ -547,14 +500,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
return;
}
const invert = this.$invert.get();
const hash = stableHash({ points, invert });
if (hash === this.$lastProcessedHash.get()) {
this.log.trace('Already processed points');
return;
}
this.$isProcessing.set(true);
this.log.trace({ points }, 'Segmenting');
@@ -576,7 +521,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.abortController = controller;
// Build the graph for segmenting the image, using the rasterized image DTO
const { graph, outputNodeId } = CanvasSegmentAnythingModule.buildGraph(rasterizeResult.value, points, invert);
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
// Run the graph and get the segmented image output
const segmentResult = await withResultAsync(() =>
@@ -603,56 +548,38 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.log.trace({ imageDTO: segmentResult.value }, 'Segmented');
// Prepare the ephemeral image state
const imageState = imageDTOToImageObject(segmentResult.value);
this.$imageState.set(imageState);
this.imageState = imageDTOToImageObject(segmentResult.value);
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
if (this.maskedImage) {
this.maskedImage.destroy();
}
if (this.konva.maskTween) {
this.konva.maskTween.destroy();
this.konva.maskTween = null;
}
this.imageModule = new CanvasObjectImage(imageState, this);
this.maskedImage = new CanvasObjectImage(this.imageState, this);
// Force update the masked image - after awaiting, the image will be rendered (in memory)
await this.imageModule.update(imageState, true);
await this.maskedImage.update(this.imageState, true);
// Update the compositing rect to match the image size
this.konva.compositingRect.setAttrs({
width: imageState.image.width,
height: imageState.image.height,
width: this.imageState.image.width,
height: this.imageState.image.height,
visible: true,
});
// Now we can add the masked image to the mask group. It will be rendered above the compositing rect, but should be
// under it, so we will move the compositing rect to the top
this.konva.maskGroup.add(this.imageModule.konva.group);
this.konva.maskGroup.add(this.maskedImage.konva.group);
this.konva.compositingRect.moveToTop();
// Cache the group to ensure the mask is rendered correctly w/ opacity
this.konva.maskGroup.cache();
// Create a pulsing tween
this.konva.maskTween = new Konva.Tween({
node: this.konva.maskGroup,
duration: 1,
opacity: 0.4, // oscillate between this value and pre-tween opacity
yoyo: true,
repeat: Infinity,
easing: Konva.Easings.EaseOut,
});
// Start the pulsing effect
this.konva.maskTween.play();
this.$lastProcessedHash.set(hash);
// We are done processing (still segmenting though!)
this.$isProcessing.set(false);
// The current points have been processed
this.$hasProcessed.set(true);
// Clean up the abort controller as needed
if (!this.abortController.signal.aborted) {
this.abortController.abort();
@@ -666,17 +593,24 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
/**
* Applies the segmented image to the entity, replacing the entity's objects with the masked image.
* Applies the segmented image to the entity.
*/
apply = () => {
const imageState = this.$imageState.get();
if (!this.$hasProcessed.get()) {
this.log.error('Cannot apply unprocessed points');
return;
}
const imageState = this.imageState;
if (!imageState) {
this.log.error('No image state to apply');
return;
}
this.log.trace('Applying');
// Rasterize the entity, replacing the objects with the masked image
// Commit the buffer, which will move it from the layer's buffer renderer to its main renderer
this.parent.bufferRenderer.commitBuffer();
// Rasterize the entity, this time replacing the objects with the masked image
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.entityIdentifier,
@@ -693,59 +627,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
this.teardown();
};
/**
* Saves the segmented image as a new entity of the given type.
* @param type The type of entity to save the segmented image as.
*/
saveAs = (type: Exclude<CanvasEntityType, 'reference_image'>) => {
const imageState = this.$imageState.get();
if (!imageState) {
this.log.error('No image state to save as');
return;
}
this.log.trace(`Saving as ${type}`);
// Have the parent adopt the image module - this prevents a flash of the original layer content before the
// segmented image is rendered
if (this.imageModule) {
this.parent.renderer.adoptObjectRenderer(this.imageModule);
}
// Create the new entity with the masked image as its only object
const rect = this.parent.transformer.getRelativeRect();
const arg = {
overrides: {
objects: [imageState],
position: {
x: Math.round(rect.x),
y: Math.round(rect.y),
},
},
isSelected: true,
};
switch (type) {
case 'raster_layer':
this.manager.stateApi.addRasterLayer(arg);
break;
case 'control_layer':
this.manager.stateApi.addControlLayer(arg);
break;
case 'inpaint_mask':
this.manager.stateApi.addInpaintMask(arg);
break;
case 'regional_guidance':
this.manager.stateApi.addRegionalGuidance(arg);
break;
default:
assert<Equals<typeof type, never>>(false);
}
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
/**
* Resets the module (e.g. remove all points and the mask image).
*
@@ -802,39 +683,30 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
for (const point of this.$points.get()) {
point.konva.circle.destroy();
}
// If the image module exists, and is a child of the group, destroy it. It might not be a child of the group if
// the user has applied the segmented image and the image has been adopted by the parent entity.
if (this.imageModule && this.imageModule.konva.group.parent === this.konva.group) {
this.imageModule.destroy();
this.imageModule = null;
}
if (this.konva.maskTween) {
this.konva.maskTween.destroy();
this.konva.maskTween = null;
if (this.maskedImage) {
this.maskedImage.destroy();
}
// Empty internal module state
this.$points.set([]);
this.$imageState.set(null);
this.imageState = null;
this.$pointType.set(1);
this.$invert.set(false);
this.$lastProcessedHash.set('');
this.$hasProcessed.set(false);
this.$isProcessing.set(false);
// Reset non-ephemeral konva nodes
this.konva.compositingRect.visible(false);
this.konva.maskGroup.clearCache();
// Reset the parent module's buffer and forcibly sync the Konva cache
this.parent.bufferRenderer.clearBuffer();
this.parent.renderer.syncKonvaCache(true);
};
/**
* Builds a graph for segmenting an image with the given image DTO.
*/
static buildGraph = (
{ image_name }: ImageDTO,
points: SAMPointWithId[],
invert: boolean
): { graph: Graph; outputNodeId: string } => {
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
// TODO(psyche): When SAM2 is available in transformers, use it here
@@ -844,7 +716,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
type: 'segment_anything',
model: 'segment-anything-huge',
image: { image_name },
point_lists: [{ points: points.map(({ x, y, label }) => ({ x, y, label })) }],
point_lists: [{ points: this.getSAMPoints() }],
mask_filter: 'largest',
});
@@ -853,7 +725,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
id: getPrefixedId('apply_tensor_mask_to_image'),
type: 'apply_tensor_mask_to_image',
image: { image_name },
invert,
});
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');
@@ -888,11 +759,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
label,
circle: getKonvaNodeDebugAttrs(konva.circle),
})),
imageState: deepClone(this.$imageState.get()),
imageModule: this.imageModule?.repr(),
imageState: deepClone(this.imageState),
maskedImage: this.maskedImage?.repr(),
config: deepClone(this.config),
$isSegmenting: this.$isSegmenting.get(),
$lastProcessedHash: this.$lastProcessedHash.get(),
$hasProcessed: this.$hasProcessed.get(),
$isProcessing: this.$isProcessing.get(),
$pointType: this.$pointType.get(),
$pointTypeString: this.$pointTypeString.get(),
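
One side of the hunks above builds a pulsing mask preview by tweening the mask group's opacity. Isolated as a sketch (the tween config mirrors the code above; the caller must `destroy()` the tween on teardown, as the reset path does):

```ts
import Konva from 'konva';

// Pulse a mask group's opacity between its current value and 0.4.
// yoyo + repeat: Infinity keeps it oscillating until destroyed.
const startMaskPulse = (maskGroup: Konva.Group): Konva.Tween => {
  const tween = new Konva.Tween({
    node: maskGroup,
    duration: 1, // seconds per half-cycle
    opacity: 0.4,
    yoyo: true,
    repeat: Infinity,
    easing: Konva.Easings.EaseOut,
  });
  tween.play();
  return tween;
};
```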


@@ -51,16 +51,10 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
/**
* Sync the $isStaging flag with the redux state. $isStaging is used by the manager to determine the global busy
* state of the canvas.
*
* We also set the $shouldShowStagedImage flag when we enter staging mode, so that the staged images are shown,
* even if the user disabled this in the last staging session.
*/
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectIsStaging, (isStaging, oldIsStaging) => {
this.manager.stateApi.createStoreSubscription(selectIsStaging, (isStaging) => {
this.$isStaging.set(isStaging);
if (isStaging && !oldIsStaging) {
this.$shouldShowStagedImage.set(true);
}
})
);
}
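
`createStoreSubscription` hands the listener both the new and previous selected value, which is what lets one side of this hunk react only to the false-to-true transition (`isStaging && !oldIsStaging`) instead of every change. A minimal sketch of such a helper over a plain Redux store (names assumed):

```ts
import type { Store } from '@reduxjs/toolkit';

// Subscribe to a derived value; notify with (next, previous) only when
// the selected value actually changes.
const createStoreSubscription = <State, T>(
  store: Store<State>,
  selector: (state: State) => T,
  onChange: (value: T, oldValue: T) => void
): (() => void) => {
  let prev = selector(store.getState());
  return store.subscribe(() => {
    const next = selector(store.getState());
    if (next !== prev) {
      const old = prev;
      prev = next;
      onChange(next, old);
    }
  });
};

// e.g. show staged images again whenever staging is (re)entered:
// createStoreSubscription(store, selectIsStaging, (isStaging, wasStaging) => {
//   if (isStaging && !wasStaging) {
//     $shouldShowStagedImage.set(true);
//   }
// });
```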


@@ -17,16 +17,12 @@ import {
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
bboxChangedFromCanvas,
controlLayerAdded,
entityBrushLineAdded,
entityEraserLineAdded,
entityMoved,
entityRasterized,
entityRectAdded,
entityReset,
inpaintMaskAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasStagingAreaSlice } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
@@ -55,7 +51,6 @@ import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import type { Param0 } from 'tsafe';
import { assert } from 'tsafe';
import type { CanvasEntityAdapter } from './CanvasEntity/types';
@@ -165,34 +160,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
this.store.dispatch(entityRectAdded(arg));
};
/**
* Adds a raster layer to the canvas, pushing state to redux.
*/
addRasterLayer = (arg: Param0<typeof rasterLayerAdded>) => {
this.store.dispatch(rasterLayerAdded(arg));
};
/**
* Adds a control layer to the canvas, pushing state to redux.
*/
addControlLayer = (arg: Param0<typeof controlLayerAdded>) => {
this.store.dispatch(controlLayerAdded(arg));
};
/**
* Adds an inpaint mask to the canvas, pushing state to redux.
*/
addInpaintMask = (arg: Param0<typeof inpaintMaskAdded>) => {
this.store.dispatch(inpaintMaskAdded(arg));
};
/**
* Adds regional guidance to the canvas, pushing state to redux.
*/
addRegionalGuidance = (arg: Param0<typeof rgAdded>) => {
this.store.dispatch(rgAdded(arg));
};
/**
* Rasterizes an entity, pushing state to redux.
*/
@@ -293,8 +260,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
},
};
let didSucceed = false;
/**
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
* if the graph completes or errors before the timeout.
@@ -346,8 +311,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
return;
}
didSucceed = true;
// Ok!
resolve(getImageDTOResult.value);
};
@@ -438,10 +401,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
if (timeout) {
timeoutId = window.setTimeout(() => {
if (didSucceed) {
// If we already succeeded, we don't need to do anything
return;
}
this.log.trace('Graph canceled by timeout');
clearListeners();
cancelGraph();
@@ -451,10 +410,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
if (signal) {
signal.addEventListener('abort', () => {
if (didSucceed) {
// If we already succeeded, we don't need to do anything
return;
}
this.log.trace('Graph canceled by signal');
_clearTimeout();
clearListeners();
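
The `didSucceed` guard exists because the graph run can settle three ways - result, timeout, or external abort signal - and a late timeout or abort must become a no-op once the run has already succeeded. A generic sketch of that shape (a hypothetical helper, not this module's API):

```ts
// Run a cancellable job, guarded so that a late timeout or external
// abort cannot fire after the job has already settled.
const runWithCancellation = <T>(
  job: (signal: AbortSignal) => Promise<T>,
  options: { timeout?: number; signal?: AbortSignal } = {}
): Promise<T> => {
  const controller = new AbortController();
  let settled = false;
  let timeoutId: ReturnType<typeof setTimeout> | undefined;

  const cancel = (reason: string) => {
    if (settled) {
      return; // already finished - nothing to cancel
    }
    controller.abort(new Error(reason));
  };

  if (options.timeout !== undefined) {
    timeoutId = setTimeout(() => cancel('timeout'), options.timeout);
  }
  options.signal?.addEventListener('abort', () => cancel('external signal'), { once: true });

  return job(controller.signal).finally(() => {
    settled = true;
    clearTimeout(timeoutId);
  });
};
```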


@@ -216,14 +216,12 @@ export class CanvasEraserToolModule extends CanvasModuleBase {
*/
onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!cursorPos || !selectedEntity || !isPrimaryPointerDown) {
if (!cursorPos || !selectedEntity) {
/**
* Can't do anything without:
* - A cursor position: the cursor is not on the stage
* - A pressed primary pointer: the user is not drawing
* - A selected entity: there is no entity to draw on
*/
return;


@@ -160,16 +160,11 @@ export class CanvasToolModule extends CanvasModuleBase {
const stage = this.manager.stage;
const tool = this.$tool.get();
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
const transformingAdapter = this.manager.stateApi.$transformingAdapter.get();
if (this.manager.stage.getIsDragging()) {
this.tools.view.syncCursorStyle();
} else if (tool === 'view') {
if ((this.manager.stage.getIsDragging() || tool === 'view') && !segmentingAdapter) {
this.tools.view.syncCursorStyle();
} else if (segmentingAdapter) {
segmentingAdapter.segmentAnything.syncCursorStyle();
} else if (transformingAdapter) {
// The transformer handles cursor style via events
} else if (this.manager.stateApi.$isFiltering.get()) {
stage.setCursor('not-allowed');
} else if (this.manager.stagingArea.$isStaging.get()) {


@@ -126,13 +126,6 @@ export const floorCoord = (coord: Coordinate): Coordinate => {
};
};
export const roundCoord = (coord: Coordinate): Coordinate => {
return {
x: Math.round(coord.x),
y: Math.round(coord.y),
};
};
/**
* Snaps a position to the edge of the given rect if within a threshold of the edge
* @param pos The position to snap
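
The docstring above describes snapping a position to a rect edge when within a threshold. A minimal sketch of that behavior (the real helper's signature and threshold handling may differ):

```ts
type Coordinate = { x: number; y: number };
type Rect = { x: number; y: number; width: number; height: number };

// Snap pos to any rect edge it is within `threshold` px of; the axes
// are handled independently, so corners snap on both axes naturally.
const snapPosToRect = (pos: Coordinate, rect: Rect, threshold = 10): Coordinate => {
  const snapped = { ...pos };
  if (Math.abs(pos.x - rect.x) < threshold) {
    snapped.x = rect.x; // left edge
  } else if (Math.abs(pos.x - (rect.x + rect.width)) < threshold) {
    snapped.x = rect.x + rect.width; // right edge
  }
  if (Math.abs(pos.y - rect.y) < threshold) {
    snapped.y = rect.y; // top edge
  } else if (Math.abs(pos.y - (rect.y + rect.height)) < threshold) {
    snapped.y = rect.y + rect.height; // bottom edge
  }
  return snapped;
};
```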


@@ -29,7 +29,7 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import { merge, omit } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -57,13 +57,13 @@ import type {
} from './types';
import { getEntityIdentifier, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
getInpaintMaskState,
getRasterLayerState,
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialControlNet,
initialIPAdapter,
} from './util';
@@ -157,25 +157,28 @@ export const canvasSlice = createSlice({
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasControlLayerState>; replace?: boolean },
'raster_layer'
>
EntityIdentifierPayload<{ newId: string; overrides?: Partial<CanvasControlLayerState> }, 'raster_layer'>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const { entityIdentifier, newId, overrides } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the raster layer to control layer
const controlLayerState = converters.rasterLayer.toControlLayer(newId, layer, overrides);
const controlLayerState: CanvasControlLayerState = {
...deepClone(layer),
id: newId,
type: 'control_layer',
controlAdapter: deepClone(initialControlNet),
withTransparencyEffect: true,
};
if (replace) {
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
}
merge(controlLayerState, overrides);
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
// Add the converted control layer
state.controlLayers.entities.push(controlLayerState);
@@ -183,90 +186,11 @@ export const canvasSlice = createSlice({
state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasControlLayerState>; replace?: boolean } | undefined,
'raster_layer'
>
payload: EntityIdentifierPayload<{ overrides?: Partial<CanvasControlLayerState> } | undefined, 'raster_layer'>
) => ({
payload: { ...payload, newId: getPrefixedId('control_layer') },
}),
},
rasterLayerConvertedToInpaintMask: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
'raster_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the raster layer to inpaint mask
const inpaintMaskState = converters.rasterLayer.toInpaintMask(newId, layer, overrides);
if (replace) {
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
}
// Add the converted inpaint mask
state.inpaintMasks.entities.push(inpaintMaskState);
state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
'raster_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
}),
},
rasterLayerConvertedToRegionalGuidance: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
'raster_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the raster layer to regional guidance
const regionalGuidanceState = converters.rasterLayer.toRegionalGuidance(newId, layer, overrides);
if (replace) {
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
}
// Add the converted regional guidance
state.regionalGuidance.entities.push(regionalGuidanceState);
state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
'raster_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
}),
},
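
These conversion actions all use Redux Toolkit's `reducer`/`prepare` pair: the impure id generation (`getPrefixedId(...)`) runs in `prepare`, keeping the reducer itself pure and deterministic. A self-contained sketch of the pattern with a hypothetical slice, using RTK's `nanoid` in place of `getPrefixedId`:

```ts
import { createSlice, nanoid } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';

type Layer = { id: string; name: string };

const layersSlice = createSlice({
  name: 'layers',
  initialState: { entities: [] as Layer[] },
  reducers: {
    layerAdded: {
      // The reducer is pure: the id arrives pre-generated in the payload.
      reducer: (state, action: PayloadAction<{ newId: string; name: string }>) => {
        const { newId, name } = action.payload;
        state.entities.push({ id: newId, name });
      },
      // prepare runs outside the reducer, so impure id generation is fine.
      prepare: (payload: { name: string }) => ({
        payload: { ...payload, newId: nanoid() },
      }),
    },
  },
});

// usage: dispatch(layersSlice.actions.layerAdded({ name: 'my layer' }));
```
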
//#region Control layers
controlLayerAdded: {
reducer: (
@@ -293,125 +217,32 @@ export const canvasSlice = createSlice({
state.selectedEntityIdentifier = { type: 'control_layer', id: data.id };
},
controlLayerConvertedToRasterLayer: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRasterLayerState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
reducer: (state, action: PayloadAction<EntityIdentifierPayload<{ newId: string }, 'control_layer'>>) => {
const { entityIdentifier, newId } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the control layer to a raster layer
const rasterLayerState = converters.controlLayer.toRasterLayer(newId, layer, overrides);
const rasterLayerState: CanvasRasterLayerState = {
...omit(deepClone(layer), ['type', 'controlAdapter', 'withTransparencyEffect']),
id: newId,
type: 'raster_layer',
};
if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
// Add the new raster layer
state.rasterLayers.entities.push(rasterLayerState);
state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRasterLayerState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
prepare: (payload: EntityIdentifierPayload<void, 'control_layer'>) => ({
payload: { ...payload, newId: getPrefixedId('raster_layer') },
}),
},
controlLayerConvertedToInpaintMask: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the control layer to inpaint mask
const inpaintMaskState = converters.controlLayer.toInpaintMask(newId, layer, overrides);
if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}
// Add the new inpaint mask
state.inpaintMasks.entities.push(inpaintMaskState);
state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
}),
},
controlLayerConvertedToRegionalGuidance: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the control layer to regional guidance
const regionalGuidanceState = converters.controlLayer.toRegionalGuidance(newId, layer, overrides);
if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}
// Add the new regional guidance
state.regionalGuidance.entities.push(regionalGuidanceState);
state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
}),
},
controlLayerModelChanged: (
state,
action: PayloadAction<
@@ -616,46 +447,6 @@ export const canvasSlice = createSlice({
state.regionalGuidance.entities.push(data);
state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id };
},
rgConvertedToInpaintMask: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
'regional_guidance'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the regional guidance to inpaint mask
const inpaintMaskState = converters.regionalGuidance.toInpaintMask(newId, layer, overrides);
if (replace) {
// Remove the regional guidance
state.regionalGuidance.entities = state.regionalGuidance.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}
// Add the new inpaint mask
state.inpaintMasks.entities.push(inpaintMaskState);
state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
'regional_guidance'
>
) => ({
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
}),
},
rgPositivePromptChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ prompt: string | null }, 'regional_guidance'>>
@@ -853,44 +644,6 @@ export const canvasSlice = createSlice({
state.inpaintMasks.entities = [data];
state.selectedEntityIdentifier = { type: 'inpaint_mask', id: data.id };
},
inpaintMaskConvertedToRegionalGuidance: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
'inpaint_mask'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}
// Convert the inpaint mask to regional guidance
const regionalGuidanceState = converters.inpaintMask.toRegionalGuidance(newId, layer, overrides);
if (replace) {
// Remove the inpaint mask
state.inpaintMasks.entities = state.inpaintMasks.entities.filter((layer) => layer.id !== entityIdentifier.id);
}
// Add the new regional guidance
state.regionalGuidance.entities.push(regionalGuidanceState);
state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
'inpaint_mask'
>
) => ({
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
}),
},
//#region BBox
bboxScaledWidthChanged: (state, action: PayloadAction<number>) => {
const gridSize = getGridSize(state.bbox.modelBase);
@@ -1457,14 +1210,10 @@ export const {
rasterLayerAdded,
// rasterLayerRecalled,
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
// Control layers
controlLayerAdded,
// controlLayerRecalled,
controlLayerConvertedToRasterLayer,
controlLayerConvertedToInpaintMask,
controlLayerConvertedToRegionalGuidance,
controlLayerModelChanged,
controlLayerControlModeChanged,
controlLayerWeightChanged,
@@ -1482,7 +1231,6 @@ export const {
// Regions
rgAdded,
// rgRecalled,
rgConvertedToInpaintMask,
rgPositivePromptChanged,
rgNegativePromptChanged,
rgAutoNegativeToggled,
@@ -1496,7 +1244,6 @@ export const {
rgIPAdapterCLIPVisionModelChanged,
// Inpaint mask
inpaintMaskAdded,
inpaintMaskConvertedToRegionalGuidance,
// inpaintMaskRecalled,
} = canvasSlice.actions;
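Every action in this slice that mints a new entity follows the same Redux Toolkit `reducer`/`prepare` pairing: `prepare` generates the prefixed ID (an impure step) and the reducer stays pure. A minimal, self-contained sketch of the idiom, assuming a stand-in for this repo's `getPrefixedId`:

```ts
import { createSlice, nanoid, type PayloadAction } from '@reduxjs/toolkit';

// Stand-in for the repo's getPrefixedId helper (assumed behavior).
const getPrefixedId = (prefix: string): string => `${prefix}:${nanoid()}`;

type DemoLayer = { id: string; type: 'raster_layer' };

const demoSlice = createSlice({
  name: 'demo',
  initialState: { layers: [] as DemoLayer[] },
  reducers: {
    layerAdded: {
      // The reducer stays pure: the ID was already minted in prepare.
      reducer: (state, action: PayloadAction<{ newId: string }>) => {
        state.layers.push({ id: action.payload.newId, type: 'raster_layer' });
      },
      // prepare runs on dispatch, before the reducer, and may be impure.
      prepare: () => ({ payload: { newId: getPrefixedId('raster_layer') } }),
    },
  },
});

export const { layerAdded } = demoSlice.actions;
```

Generating IDs in `prepare` rather than in the reducer keeps reducers replayable: replaying the same action always produces the same state.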

View File

@@ -349,27 +349,6 @@ export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier)
);
};
/**
* Builds a selector that selects whether the entity has any content.
*
* A reference image has content when its IP adapter has an image set.
*
* Any other entity has content when it has at least one object.
*/
export const buildSelectHasObjects = (entityIdentifier: CanvasEntityIdentifier) => {
return createSelector(selectCanvasSlice, (canvas) => {
const entity = selectEntity(canvas, entityIdentifier);
if (!entity) {
return false;
}
if (entity.type === 'reference_image') {
return entity.ipAdapter.image !== null;
}
return entity.objects.length > 0;
});
};
export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);
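The deleted `buildSelectHasObjects` is a selector factory: it closes over an entity identifier and returns a memoized `createSelector`, one per entity. A small sketch of that pattern under simplified, assumed state shapes (not the repo's real canvas types):

```ts
import { createSelector } from '@reduxjs/toolkit';

// Simplified shapes for illustration only.
type Entity = { id: string; objects: unknown[] };
type RootState = { canvas: { entities: Entity[] } };

const selectCanvasSlice = (state: RootState) => state.canvas;

// One memoized selector per entity id: the derived value is cached
// until the canvas slice reference changes.
const buildSelectHasObjects = (entityId: string) =>
  createSelector(selectCanvasSlice, (canvas) => {
    const entity = canvas.entities.find((e) => e.id === entityId);
    return entity ? entity.objects.length > 0 : false;
  });

// Usage: const selectHasObjects = buildSelectHasObjects('raster_layer:abc');
```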

View File

@@ -131,8 +131,7 @@ const zSAMPoint = z.object({
y: z.number().int().gte(0),
label: zSAMPointLabel,
});
- type SAMPoint = z.infer<typeof zSAMPoint>;
- export type SAMPointWithId = SAMPoint & { id: string };
+ export type SAMPoint = z.infer<typeof zSAMPoint>;
const zRect = z.object({
x: z.number(),
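`zSAMPoint` pairs a runtime validator with a compile-time type via `z.infer`, so exporting the inferred type is enough once `SAMPointWithId` has no consumers. A runnable sketch; the `zSAMPointLabel` values below are an assumption (the usual SAM convention), not copied from this file:

```ts
import { z } from 'zod';

// Assumed labels: 1 = foreground, 0 = background, -1 = neutral.
const zSAMPointLabel = z.union([z.literal(-1), z.literal(0), z.literal(1)]);

const zSAMPoint = z.object({
  x: z.number().int().gte(0),
  y: z.number().int().gte(0),
  label: zSAMPointLabel,
});
export type SAMPoint = z.infer<typeof zSAMPoint>;

// safeParse validates without throwing; success narrows the result type.
const parsed = zSAMPoint.safeParse({ x: 10, y: 20, label: 1 });
if (parsed.success) {
  const point: SAMPoint = parsed.data;
  console.log(point.x, point.y, point.label);
}
```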

View File

@@ -184,153 +184,3 @@ export const getInpaintMaskState = (
merge(entityState, overrides);
return entityState;
};
const convertRasterLayerToControlLayer = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasControlLayerState>
): CanvasControlLayerState => {
const { name, objects, position } = rasterLayerState;
const controlLayerState = getControlLayerState(newId, {
name,
objects,
position,
});
merge(controlLayerState, overrides);
return controlLayerState;
};
const convertRasterLayerToInpaintMask = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = rasterLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};
const convertRasterLayerToRegionalGuidance = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = rasterLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};
const convertControlLayerToRasterLayer = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRasterLayerState>
): CanvasRasterLayerState => {
const { name, objects, position } = controlLayerState;
const rasterLayerState = getRasterLayerState(newId, {
name,
objects,
position,
});
merge(rasterLayerState, overrides);
return rasterLayerState;
};
const convertControlLayerToInpaintMask = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = controlLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};
const convertControlLayerToRegionalGuidance = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = controlLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};
const convertInpaintMaskToRegionalGuidance = (
newId: string,
inpaintMaskState: CanvasInpaintMaskState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = inpaintMaskState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};
const convertRegionalGuidanceToInpaintMask = (
newId: string,
regionalGuidanceState: CanvasRegionalGuidanceState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = regionalGuidanceState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};
/**
* Supported conversions:
* - Raster Layer -> Control Layer
* - Raster Layer -> Inpaint Mask
* - Raster Layer -> Regional Guidance
* - Control Layer -> Raster Layer
* - Control Layer -> Inpaint Mask
* - Control Layer -> Regional Guidance
* - Inpaint Mask -> Regional Guidance
* - Regional Guidance -> Inpaint Mask
*/
export const converters = {
rasterLayer: {
toControlLayer: convertRasterLayerToControlLayer,
toInpaintMask: convertRasterLayerToInpaintMask,
toRegionalGuidance: convertRasterLayerToRegionalGuidance,
},
controlLayer: {
toRasterLayer: convertControlLayerToRasterLayer,
toInpaintMask: convertControlLayerToInpaintMask,
toRegionalGuidance: convertControlLayerToRegionalGuidance,
},
inpaintMask: {
toRegionalGuidance: convertInpaintMaskToRegionalGuidance,
},
regionalGuidance: {
toInpaintMask: convertRegionalGuidanceToInpaintMask,
},
};
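Every converter deleted above had the same shape: copy `{ name, objects, position }` off the source entity, seed a fresh default state for the target type, then overlay overrides with lodash's deep `merge`. One generic helper could have expressed all eight; a sketch under assumed minimal types (not the repo's real state types):

```ts
import { merge } from 'lodash-es';

// Assumed minimal shape shared by all convertible canvas entities.
type ConvertibleBase = {
  id: string;
  name: string | null;
  objects: unknown[];
  position: { x: number; y: number };
};

// makeDefault stands in for one of the getXState factories above,
// e.g. getInpaintMaskState or getRegionalGuidanceState.
const convertEntity = <T extends ConvertibleBase>(
  makeDefault: (id: string, seed: Partial<T>) => T,
  newId: string,
  source: ConvertibleBase,
  overrides?: Partial<T>
): T => {
  const { name, objects, position } = source;
  const converted = makeDefault(newId, { name, objects, position } as Partial<T>);
  merge(converted, overrides); // lodash merge mutates `converted` in place
  return converted;
};
```

Deleting the file outright is consistent with the slice changes above: the conversion actions that consumed these helpers are gone, and the one surviving conversion now clones and re-types the layer inline.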

View File

@@ -42,14 +42,6 @@ export type AddControlLayerFromImageDropData = BaseDropData & {
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE';
};
type AddInpaintMaskFromImageDropData = BaseDropData & {
actionType: 'ADD_INPAINT_MASK_FROM_IMAGE';
};
type AddRegionalGuidanceFromImageDropData = BaseDropData & {
actionType: 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE';
};
export type AddRegionalReferenceImageFromImageDropData = BaseDropData & {
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE';
};
@@ -61,7 +53,7 @@ export type AddGlobalReferenceImageFromImageDropData = BaseDropData & {
export type ReplaceLayerImageDropData = BaseDropData & {
actionType: 'REPLACE_LAYER_WITH_IMAGE';
context: {
- entityIdentifier: CanvasEntityIdentifier<'control_layer' | 'raster_layer' | 'inpaint_mask' | 'regional_guidance'>;
+ entityIdentifier: CanvasEntityIdentifier<'control_layer' | 'raster_layer'>;
};
};
@@ -106,9 +98,7 @@ export type TypesafeDroppableData =
| AddControlLayerFromImageDropData
| ReplaceLayerImageDropData
| AddRegionalReferenceImageFromImageDropData
- | AddGlobalReferenceImageFromImageDropData
- | AddInpaintMaskFromImageDropData
- | AddRegionalGuidanceFromImageDropData;
+ | AddGlobalReferenceImageFromImageDropData;
type BaseDragData = {
id: string;
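`TypesafeDroppableData` is a discriminated union keyed on `actionType`; shrinking the union is what forces the corresponding cases out of `isValidDrop` in the next file, since TypeScript narrows on the discriminant. A self-contained sketch with two made-up members:

```ts
type BaseDropData = { id: string };
type AddRasterLayerFromImageDropData = BaseDropData & { actionType: 'ADD_RASTER_LAYER_FROM_IMAGE' };
type AddControlLayerFromImageDropData = BaseDropData & { actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE' };
type DroppableData = AddRasterLayerFromImageDropData | AddControlLayerFromImageDropData;

// Each case narrows `data` to the matching member of the union.
const describeDrop = (data: DroppableData): string => {
  switch (data.actionType) {
    case 'ADD_RASTER_LAYER_FROM_IMAGE':
      return `raster layer drop ${data.id}`;
    case 'ADD_CONTROL_LAYER_FROM_IMAGE':
      return `control layer drop ${data.id}`;
  }
};
```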

View File

@@ -17,8 +17,6 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
case 'SET_RG_IP_ADAPTER_IMAGE':
case 'ADD_RASTER_LAYER_FROM_IMAGE':
case 'ADD_CONTROL_LAYER_FROM_IMAGE':
case 'ADD_INPAINT_MASK_FROM_IMAGE':
case 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE':
case 'SET_UPSCALE_INITIAL_IMAGE':
case 'SET_NODES_IMAGE':
case 'SELECT_FOR_COMPARE':

View File

@@ -1,4 +1,4 @@
- import { Link } from '@invoke-ai/ui-library';
+ import { Flex, Link, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { memo } from 'react';
@@ -9,13 +9,15 @@ export const GalleryHeader = memo(() => {
if (projectName && projectUrl) {
return (
- <Link fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" href={projectUrl}>
- {projectName}
- </Link>
+ <Flex gap={2} alignItems="center" justifyContent="space-evenly" pe={2} w="50%">
+ <Text fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" w="full" textAlign="center">
+ <Link href={projectUrl}>{projectName}</Link>
+ </Text>
+ </Flex>
);
}
- return null;
+ return <Spacer />;
});
GalleryHeader.displayName = 'GalleryHeader';
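`GalleryHeader` reads `$projectName` and `$projectUrl` through nanostores rather than Redux. A minimal sketch of that pattern; the atom shapes below are assumptions, and only `useStore` from `@nanostores/react` matches the imports above:

```tsx
import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';

// Illustrative atoms; the real ones live in app/store/nanostores/projectId.
export const $projectName = atom<string | undefined>(undefined);
export const $projectUrl = atom<string | undefined>(undefined);

export const ProjectLink = () => {
  // useStore subscribes this component; any $projectName.set(...) re-renders it.
  const projectName = useStore($projectName);
  const projectUrl = useStore($projectUrl);
  if (!projectName || !projectUrl) {
    return null;
  }
  return <a href={projectUrl}>{projectName}</a>;
};
```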

View File

@@ -51,8 +51,8 @@ const GalleryPanelContent = () => {
return (
<Flex ref={galleryPanelFocusRef} position="relative" flexDirection="column" h="full" w="full" tabIndex={-1}>
- <Flex alignItems="center" justifyContent="space-between" w="full">
- <Flex flexGrow={1} flexBasis={0}>
+ <Flex alignItems="center" w="full">
+ <Flex w="25%">
<Button
size="sm"
variant="ghost"
@@ -62,10 +62,8 @@ const GalleryPanelContent = () => {
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
</Button>
</Flex>
- <Flex>
- <GalleryHeader />
- </Flex>
- <Flex flexGrow={1} flexBasis={0} justifyContent="flex-end">
+ <GalleryHeader />
+ <Flex h="full" w="25%" justifyContent="flex-end">
<BoardsSettingsPopover />
<IconButton
size="sm"
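The header rework trades `flexGrow`/`flexBasis` balancing for fixed bands: 25% for the boards toggle, 50% for `GalleryHeader` (which now renders its own half-width flex, per the previous file), and 25% for the settings buttons, so the title stays optically centered however wide the button groups grow. A stripped-down sketch of the resulting structure (contents are placeholders):

```tsx
import { Flex } from '@invoke-ai/ui-library';

// Fixed 25/50/25 bands keep the center band centered even when the
// side bands' contents change width.
export const PanelHeader = () => (
  <Flex alignItems="center" w="full">
    <Flex w="25%">{/* boards toggle button */}</Flex>
    <Flex w="50%" justifyContent="center">{/* gallery header / project link */}</Flex>
    <Flex w="25%" justifyContent="flex-end">{/* settings and pin buttons */}</Flex>
  </Flex>
);
```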

View File

@@ -1,11 +1,9 @@
- import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
- import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
+ import { MenuItem } from '@invoke-ai/ui-library';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import {
- PiArrowBendUpLeftBold,
PiArrowsCounterClockwiseBold,
PiAsteriskBold,
PiPaintBrushBold,
@@ -16,36 +14,28 @@ import {
export const ImageMenuItemMetadataRecallActions = memo(() => {
const { t } = useTranslation();
const imageDTO = useImageDTOContext();
- const subMenu = useSubMenu();
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, createAsPreset } =
useImageActions(imageDTO);
return (
- <MenuItem {...subMenu.parentMenuItemProps} icon={<PiArrowBendUpLeftBold />}>
- <Menu {...subMenu.menuProps}>
- <MenuButton {...subMenu.menuButtonProps}>
- <SubMenuButtonContent label="Recall Metadata" />
- </MenuButton>
- <MenuList {...subMenu.menuListProps}>
- <MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={remix} isDisabled={!hasMetadata}>
- {t('parameters.remixImage')}
- </MenuItem>
- <MenuItem icon={<PiQuotesBold />} onClick={recallPrompts} isDisabled={!hasPrompts}>
- {t('parameters.usePrompt')}
- </MenuItem>
- <MenuItem icon={<PiPlantBold />} onClick={recallSeed} isDisabled={!hasSeed}>
- {t('parameters.useSeed')}
- </MenuItem>
- <MenuItem icon={<PiAsteriskBold />} onClick={recallAll} isDisabled={!hasMetadata}>
- {t('parameters.useAll')}
- </MenuItem>
- <MenuItem icon={<PiPaintBrushBold />} onClick={createAsPreset} isDisabled={!hasPrompts}>
- {t('stylePresets.useForTemplate')}
- </MenuItem>
- </MenuList>
- </Menu>
- </MenuItem>
+ <>
+ <MenuItem icon={<PiArrowsCounterClockwiseBold />} onClickCapture={remix} isDisabled={!hasMetadata}>
+ {t('parameters.remixImage')}
+ </MenuItem>
+ <MenuItem icon={<PiQuotesBold />} onClickCapture={recallPrompts} isDisabled={!hasPrompts}>
+ {t('parameters.usePrompt')}
+ </MenuItem>
+ <MenuItem icon={<PiPlantBold />} onClickCapture={recallSeed} isDisabled={!hasSeed}>
+ {t('parameters.useSeed')}
+ </MenuItem>
+ <MenuItem icon={<PiAsteriskBold />} onClickCapture={recallAll} isDisabled={!hasMetadata}>
+ {t('parameters.useAll')}
+ </MenuItem>
+ <MenuItem icon={<PiPaintBrushBold />} onClickCapture={createAsPreset} isDisabled={!hasPrompts}>
+ {t('stylePresets.useForTemplate')}
+ </MenuItem>
+ </>
);
});
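One subtle change in the flattened menu: every handler moved from `onClick` to `onClickCapture`. A plausible reason, not stated in the diff, is that capture-phase handlers fire before the menu's own bubbling-phase click handling closes the menu, so the action cannot be lost if the item unmounts mid-click. A sketch of one item (props are illustrative):

```tsx
import { MenuItem } from '@invoke-ai/ui-library';

type Props = { recallSeed: () => void; hasSeed: boolean };

// onClickCapture runs in the capture phase, before any bubbling-phase
// handler (including the menu's close-on-click logic) sees the event.
export const RecallSeedItem = ({ recallSeed, hasSeed }: Props) => (
  <MenuItem onClickCapture={recallSeed} isDisabled={!hasSeed}>
    Use Seed
  </MenuItem>
);
```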

Some files were not shown because too many files have changed in this diff.