Compare commits

..

28 Commits

Author SHA1 Message Date
Ryan Dick
2ed86c082a Assume LoRA alpha=8 for FLUX diffusers PEFT LoRAs. 2024-09-12 22:04:20 +00:00
Ryan Dick
41ff0a5af6 Consolidate all LoRA patching logic in the LoRAPatcher. 2024-09-12 21:59:17 +00:00
Ryan Dick
c2ec65d582 lora_layer_from_state_dict(...) -> any_lora_layer_from_state_dict(...) 2024-09-12 21:59:17 +00:00
Ryan Dick
a0bede02f4 Rename peft/ -> lora/ 2024-09-12 21:59:17 +00:00
Ryan Dick
8f3c09348d Genera cleanup/documentation. 2024-09-12 21:59:17 +00:00
Ryan Dick
940269e60a Add a check that all keys are handled in the FLUX Diffusers LoRA loading code. 2024-09-12 21:59:17 +00:00
Ryan Dick
4fe42e2e48 Add model probe support for FLUX LoRA models in Diffusers format. 2024-09-12 21:59:17 +00:00
Ryan Dick
cbba28bdec Add utility test function for creating a dummy state_dict. 2024-09-12 21:59:17 +00:00
Ryan Dick
2e8effe83f Add is_state_dict_likely_in_flux_diffusers_format(...) function with unit test. 2024-09-12 21:59:17 +00:00
Ryan Dick
1b406e6d6a Add unit test for lora_model_from_flux_diffusers_state_dict(...). 2024-09-12 21:59:17 +00:00
Ryan Dick
757251266d First draft of lora_model_from_flux_diffusers_state_dict(...). 2024-09-12 21:59:17 +00:00
Ryan Dick
b91a9ec54c (minor) Rename test file. 2024-09-12 21:59:17 +00:00
Ryan Dick
0745d7ecfa Add ConcatenateLoRALayer class. 2024-09-12 21:59:17 +00:00
Ryan Dick
690bf4eb9d WIP on supporting diffusers format FLUX LoRAs. 2024-09-12 21:59:17 +00:00
Ryan Dick
c238b60db9 Rename flux_kohya_lora_conversion_utils.py 2024-09-12 21:59:17 +00:00
Ryan Dick
de5e9f33fa Fixup FLUX LoRA unit tests. 2024-09-12 21:59:17 +00:00
Ryan Dick
91ada8fc4c WIP 2024-09-12 21:59:17 +00:00
Ryan Dick
26be5ea030 WIP - add invocations to support FLUX LORAs. 2024-09-12 21:59:17 +00:00
Ryan Dick
d038d635f1 Get probing of FLUX LoRA kohya models working. 2024-09-12 21:59:17 +00:00
Ryan Dick
5225c77908 Add utility function for detecting whether a state_dict is in the FLUX kohya LoRA format. 2024-09-12 21:59:17 +00:00
Ryan Dick
33761066f1 Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test. 2024-09-12 21:59:17 +00:00
Ryan Dick
931942754d Move the responsibilities of 1) state_dict loading from file, and 2) SDXL lora key conversions, out of LoRAModelRaw and into LoRALoader. 2024-09-12 21:59:17 +00:00
Ryan Dick
04cdf00702 Remove unused LoRAModelRaw.name attribute. 2024-09-12 21:59:17 +00:00
Ryan Dick
141de5755c Fix type errors in sdxl_lora_conversion_utils.py 2024-09-12 21:59:17 +00:00
Ryan Dick
9b52e9c116 Start moving SDXL-specific LoRA conversions out of the general-purpose LoRAModelRaw class. 2024-09-12 21:59:17 +00:00
Ryan Dick
d0e8ce9056 Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests. 2024-09-12 21:59:17 +00:00
Ryan Dick
1bad546626 WIP - Initial logic for kohya FLUX LoRA conversion. 2024-09-12 21:59:17 +00:00
Ryan Dick
fa78ad5c5a Add state_dict keys for two FLUX LoRA formats to be used in unit tests. 2024-09-12 21:59:17 +00:00
231 changed files with 4898 additions and 4110 deletions

View File

@@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -81,9 +82,10 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
LoRAPatcher.apply_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -176,9 +178,9 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
LoRAPatcher.apply_lora_patches(
text_encoder,
loras=_lora_loader(),
patches=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),

View File

@@ -28,10 +28,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(
description="Mask for denoise model run. Values of 0.0 represent the regions to be fully denoised, and 1.0 "
+ "represent the regions to be preserved."
)
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)

View File

@@ -37,6 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
LoRAPatcher.apply_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
cached_weights=cached_weights,
),
):

View File

@@ -181,7 +181,7 @@ class FieldDescriptions:
)
num_1 = "The first number"
num_2 = "The second number"
denoise_mask = "A mask of the region to apply the denoising process to. Values of 0.0 represent the regions to be fully denoised, and 1.0 represent the regions to be preserved."
denoise_mask = "A mask of the region to apply the denoising process to."
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"

View File

@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
@@ -29,6 +29,8 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -39,7 +41,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="2.0.0",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -187,7 +189,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with transformer_info as transformer:
with (
transformer_info.model_on_device() as (cached_weights, transformer),
# Apply the LoRA after transformer has been moved to its target device for faster patching.
LoRAPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
cached_weights=cached_weights,
),
):
assert isinstance(transformer, Flux)
x = denoise(
@@ -220,19 +231,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device, and dtype for the inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
torch.Tensor | None: Inpaint mask.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
@@ -247,6 +252,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()

View File

@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -129,18 +129,7 @@ class MergeMetadataInvocation(BaseInvocation):
GENERATION_MODES = Literal[
"txt2img",
"img2img",
"inpaint",
"outpaint",
"sdxl_txt2img",
"sdxl_img2img",
"sdxl_inpaint",
"sdxl_outpaint",
"flux_txt2img",
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"txt2img", "img2img", "inpaint", "outpaint", "sdxl_txt2img", "sdxl_img2img", "sdxl_inpaint", "sdxl_outpaint"
]

View File

@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),

View File

@@ -23,7 +23,7 @@ from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:

View File

@@ -41,7 +41,6 @@ def denoise(
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
step_callback(
PipelineIntermediateState(

View File

@@ -28,19 +28,8 @@ class InpaintExtension:
This function should be called after each denoising step.
"""
timestep_cutoff = 0.5
if timestep > timestep_cutoff:
# Early in the denoising process, use the smaller mask.
# I.e. treat gradient values as 0.0.
mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0)
else:
# After the cut-off, use the larger mask.
# I.e. treat gradient values as 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0)
# mask = (self._inpaint_mask > (0.0 + 1e-5)).float()
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)

View File

@@ -31,24 +31,10 @@ def get_noise(
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
"""Shift the timestep schedule.
This is a simmilar idea to the beta schedule introduced in https://arxiv.org/abs/2305.08898. But, the function for
remapping timesteps in [0, 1] is different.
Properties of this function:
- Recommended sigma values: 1.0 <= sigma <= 3.0.
- When sigma=1.0 and mu=0.0, the conversion is the identity function.
- Increasing sigma results in an increasingly steep logistic function.
- Adjusting mu shifts the midpoint of the logistic function.
"""
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
"""Return a linear function that maps x to y given the coordsinates of two points on the line (x1, y1) and
(x2, y2).
"""
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
@@ -66,17 +52,9 @@ def get_schedule(
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# Select mu based on linear interpolation between two points.
# Point 1: (image_seq_len=256, mu=0.5)
# Point 2: (image_seq_len=4096, mu=1.15)
# This has the effect of increasing mu as the image size increases. image_seq_len=4096 corresponds to an image
# size of 1024x1024.
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
# Shift the timesteps based on mu. Higher values of mu mean that there will be more timesteps early in the
# denoising process (i.e. many small steps in the timestep range 1.0-0.9, and fewer large steps in the timestep
# range 0.1-0.0).
timesteps = time_shift(mu=mu, sigma=1.0, t=timesteps)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
@@ -116,10 +94,6 @@ def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoi
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
# clipped_timesteps = torch.tensor(timesteps)
# clipped_timesteps = clipped_timesteps * (t_start_val - t_end_val) + t_end_val
# clipped_timesteps = clipped_timesteps.tolist()
return clipped_timesteps

View File

@@ -214,14 +214,8 @@ class LineartEdgeDetector:
line = line.cpu().numpy()
line = (line * 255.0).clip(0, 255).astype(np.uint8)
detected_map = 255 - line
detected_map = line
# The lineart model often outputs a lot of almost-black noise. SD1.5 ControlNets seem to be OK with this, but
# SDXL ControlNets are not - they need a cleaner map. 12 was experimentally determined to be a good threshold,
# eliminating all the noise while keeping the actual edges. Other approaches to thresholding may be better,
# for example stretching the contrast or removing noise.
detected_map[detected_map < 12] = 0
detected_map = 255 - detected_map
output = np_to_pil(detected_map)
return output
return np_to_pil(detected_map)

View File

@@ -260,14 +260,8 @@ class LineartAnimeEdgeDetector:
line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
line = line.clip(0, 255).astype(np.uint8)
detected_map = 255 - line
# The lineart model often outputs a lot of almost-black noise. SD1.5 ControlNets seem to be OK with this, but
# SDXL ControlNets are not - they need a cleaner map. 12 was experimentally determined to be a good threshold,
# eliminating all the noise while keeping the actual edges. Other approaches to thresholding may be better,
# for example stretching the contrast or removing noise.
detected_map[detected_map < 12] = 0
detected_map = line
detected_map = 255 - detected_map
output = np_to_pil(detected_map)
return output

View File

@@ -0,0 +1,211 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
# Next, check that this is likely a FLUX model by spot-checking a few keys.
expected_keys = [
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
]
all_expected_keys_present = all(k in state_dict for k in expected_keys)
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
layers: dict[str, AnyLoRALayer] = {}
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
layers[dst_key] = LoRALayer(
dst_key,
{
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayerBase] = []
for src_layer_dict in src_layer_dicts:
sub_layers.append(
LoRALayer(
layer_key="",
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
)
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(layer_key=dst_qkv_key, lora_layers=sub_layers, concat_axis=0)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
# time_text_embed.text_embedder -> vector_in.
add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
# time_text_embed.guidance_embedder -> guidance_in.
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
# context_embedder -> txt_in.
add_lora_layer_if_present("context_embedder", "txt_in")
# x_embedder -> img_in.
add_lora_layer_if_present("x_embedder", "img_in")
# Double transformer blocks.
for i in range(num_double_layers):
# norms.
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
# Q, K, V
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.to_q",
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
[
f"transformer_blocks.{i}.attn.add_q_proj",
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
f"double_blocks.{i}.txt_attn.qkv",
)
# ff img_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.0.proj",
f"double_blocks.{i}.img_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff.net.2",
f"double_blocks.{i}.img_mlp.2",
)
# ff txt_mlp
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.0.proj",
f"double_blocks.{i}.txt_mlp.0",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.ff_context.net.2",
f"double_blocks.{i}.txt_mlp.2",
)
# output projections.
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_out.0",
f"double_blocks.{i}.img_attn.proj",
)
add_lora_layer_if_present(
f"transformer_blocks.{i}.attn.to_add_out",
f"double_blocks.{i}.txt_attn.proj",
)
# Single transformer blocks.
for i in range(num_single_layers):
# norms
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.norm.linear",
f"single_blocks.{i}.modulation.lin",
)
# Q, K, V, mlp
add_qkv_lora_layer_if_present(
[
f"single_transformer_blocks.{i}.attn.to_q",
f"single_transformer_blocks.{i}.attn.to_k",
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
f"single_blocks.{i}.linear1",
)
# Output projections.
add_lora_layer_if_present(
f"single_transformer_blocks.{i}.proj_out",
f"single_blocks.{i}.linear2",
)
# Final layer.
add_lora_layer_if_present("proj_out", "final_layer.linear")
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
return LoRAModelRaw(layers=layers)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Groups the keys in the state dict by layer."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
for key in state_dict:
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])
if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]
return layer_dict

View File

@@ -0,0 +1,81 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layer = any_lora_layer_from_state_dict(layer_key, layer_state_dict)
layers[layer_key] = layer
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -0,0 +1,30 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layer = any_lora_layer_from_state_dict(layer_key, values)
layers[layer_key] = layer
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}

View File

@@ -1,5 +1,6 @@
from typing import Union
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
@@ -7,4 +8,4 @@ from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]

View File

@@ -0,0 +1,46 @@
from typing import List, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, layer_key: str, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
# Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
super().__init__(layer_key, values={})
self._lora_layers = lora_layers
self._concat_axis = concat_axis
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original weight tensor here.
layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self._concat_axis)
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
if any(layer_bias_is_none):
assert all(layer_bias_is_none)
return None
# Ignore the type error, because we have just verified that all layer biases are non-None.
return torch.cat(layer_biases, dim=self._concat_axis)
def calc_size(self) -> int:
return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,33 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer(layer_key, state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer(layer_key, state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer(layer_key, state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer(layer_key, state_dict)
elif "on_input" in state_dict:
return IA3Layer(layer_key, state_dict)
elif "w_norm" in state_dict:
return NormLayer(layer_key, state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -1,43 +1,17 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional
import torch
from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
def __init__(self, layers: Dict[str, AnyLoRALayer]):
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
@@ -46,234 +20,3 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@@ -0,0 +1,148 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.
:param model: The model to patch.
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
that the LoRA patches do not need to be loaded into memory all at once.
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoRAPatcher.apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@staticmethod
@torch.no_grad()
def apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply a single LoRA patch to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.
:param model: The model to search.
:param layer_key: The layer key to search for.
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
searching, but some legacy code still uses flattened keys.
:return: A tuple containing the module key and the submodule.
"""
if not layer_key_is_flattened:
return layer_key, model.get_submodule(layer_key)
# Handle flattened keys.
assert "." not in layer_key
module = model
module_key = ""
key_parts = layer_key.split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return module_key, module

View File

@@ -5,8 +5,18 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -45,14 +55,39 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format == ModelFormat.Diffusers:
# HACK(ryand): We assume alpha=8 for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=8, because this is the default value in the PEFT library:
# https://github.com/huggingface/peft/blob/7868d0372b86a6b9ac5f365b8f0eef2f2f5dedce/src/peft/tuners/lora/config.py#L169
# Other reasonable defaults for alpha could be 1.0 or the rank of the LoRA. If our assumption is wrong,
# the user will need to adjust the weight accordingly to account for the difference.
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=8)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
model.to(dtype=self._torch_dtype)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base

View File

@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@@ -244,7 +248,9 @@ class ModelProbe(object):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
@@ -554,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
# TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
# ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
# file, but the weight keys are in the diffusers format.
return ModelFormat.Diffusers
return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
self.checkpoint
):
return BaseModelType.Flux
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:

View File

@@ -5,32 +5,18 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
class ModelPatcher:
@@ -54,95 +40,6 @@ class ModelPatcher:
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
cached_weights=cached_weights,
):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
yield
finally:
with torch.no_grad():
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager
def apply_ti(
@@ -282,26 +179,6 @@ class ModelPatcher:
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
self.patch_model(
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,
prefix="lora_unet_",
lora=lora_model,
lora_weight=self._weight,
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if lora_weight == 0:
return
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@@ -173,11 +173,102 @@
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled",
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"disabled": "Disabled"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
"controlAdapter_other": "Control Adapters",
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
"addControlNet": "Add $t(common.controlNet)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addT2IAdapter": "Add $t(common.t2iAdapter)",
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
"base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"beginEndStepPercentShort": "Begin/End %",
"bgth": "bg_th",
"canny": "Canny",
"cannyDescription": "Canny edge detection",
"colorMap": "Color",
"colorMapDescription": "Generates a color map from the image",
"coarse": "Coarse",
"contentShuffle": "Content Shuffle",
"contentShuffleDescription": "Shuffles the content in an image",
"control": "Control",
"controlMode": "Control Mode",
"crop": "Crop",
"delete": "Delete",
"depthAnything": "Depth Anything",
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
"depthAnythingSmallV2": "Small V2",
"depthMidas": "Depth (Midas)",
"depthMidasDescription": "Depth map generation using Midas",
"depthZoe": "Depth (Zoe)",
"depthZoeDescription": "Depth map generation using Zoe",
"detectResolution": "Detect Resolution",
"duplicate": "Duplicate",
"f": "F",
"fill": "Fill",
"h": "H",
"face": "Face",
"body": "Body",
"hands": "Hands",
"hed": "HED",
"hedDescription": "Holistically-Nested Edge Detection",
"hideAdvanced": "Hide Advanced",
"highThreshold": "High Threshold",
"imageResolution": "Image Resolution",
"colorMapTileSize": "Tile Size",
"importImageFromCanvas": "Import Image From Canvas",
"importMaskFromCanvas": "Import Mask From Canvas",
"large": "Large",
"lineart": "Lineart",
"lineartAnime": "Lineart Anime",
"lineartAnimeDescription": "Anime-style lineart processing",
"lineartDescription": "Converts image to lineart",
"lowThreshold": "Low Threshold",
"maxFaces": "Max Faces",
"mediapipeFace": "Mediapipe Face",
"mediapipeFaceDescription": "Face detection using Mediapipe",
"megaControl": "Mega Control",
"minConfidence": "Min Confidence",
"mlsd": "M-LSD",
"mlsdDescription": "Minimalist Line Segment Detector",
"modelSize": "Model Size",
"none": "None",
"new": "New"
"noneDescription": "No processing applied",
"normalBae": "Normal BAE",
"normalBaeDescription": "Normal BAE processing",
"dwOpenpose": "DW Openpose",
"dwOpenposeDescription": "Human pose estimation using DW Openpose",
"pidi": "PIDI",
"pidiDescription": "PIDI image processing",
"processor": "Processor",
"prompt": "Prompt",
"resetControlImage": "Reset Control Image",
"resize": "Resize",
"resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe",
"saveControlImage": "Save Control Image",
"scribble": "Scribble",
"selectModel": "Select a model",
"selectCLIPVisionModel": "Select a CLIP Vision model",
"setControlImageDimensions": "Copy size to W/H (optimize for model)",
"setControlImageDimensionsForce": "Copy size to W/H (ignore model)",
"showAdvanced": "Show Advanced",
"small": "Small",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -907,7 +998,6 @@
"downloadImage": "Download Image",
"general": "General",
"globalSettings": "Global Settings",
"guidance": "Guidance",
"height": "Height",
"imageFit": "Fit Initial Image To Output Size",
"images": "Images",
@@ -930,14 +1020,8 @@
"noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
"canvasManagerNotLoaded": "Canvas Manager not loaded",
"canvasIsFiltering": "Canvas is filtering",
"canvasIsTransforming": "Canvas is transforming",
"canvasIsRasterizing": "Canvas is rasterizing",
"canvasIsCompositing": "Canvas is compositing",
"canvasBusy": "Canvas is busy",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
@@ -1300,13 +1384,6 @@
"High CFG Scale values can result in over-saturation and distorted generation results. "
]
},
"paramGuidance": {
"heading": "Guidance",
"paragraphs": [
"Controls how much the prompt influences the generation process.",
"High guidance values can result in over-saturation and high or low guidance may result in distorted generation results. Guidance only applies to FLUX DEV models."
]
},
"paramCFGRescaleMultiplier": {
"heading": "CFG Rescale Multiplier",
"paragraphs": [
@@ -1584,36 +1661,21 @@
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"regional": "Regional",
"global": "Global",
"canvas": "Canvas",
"bookmark": "Bookmark for Quick Switch",
"fitBboxToLayers": "Fit Bbox To Layers",
"removeBookmark": "Remove Bookmark",
"saveCanvasToGallery": "Save Canvas to Gallery",
"saveBboxToGallery": "Save Bbox to Gallery",
"newControlLayerFromBbox": "New Control Layer from Bbox",
"newRasterLayerFromBbox": "New Raster Layer from Bbox",
"sendBboxToRegionalIPAdapter": "Send Bbox to Regional IP Adapter",
"sendBboxToGlobalIPAdapter": "Send Bbox to Global IP Adapter",
"sendBboxToControlLayer": "Send Bbox to Control Layer",
"sendBboxToRasterLayer": "Send Bbox to Raster Layer",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
"newGlobalReferenceImageOk": "Created Global Reference Image",
"newGlobalReferenceImageError": "Problem Creating Global Reference Image",
"newRegionalReferenceImageOk": "Created Regional Reference Image",
"newRegionalReferenceImageError": "Problem Creating Regional Reference Image",
"newControlLayerOk": "Created Control Layer",
"newControlLayerError": "Problem Creating Control Layer",
"newRasterLayerOk": "Created Raster Layer",
"newRasterLayerError": "Problem Creating Raster Layer",
"pullBboxIntoLayerOk": "Bbox Pulled Into Layer",
"pullBboxIntoLayerError": "Problem Pulling BBox Into Layer",
"pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage",
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
"regionIsEmpty": "Selected region is empty",
"mergeVisible": "Merge Visible",
"mergeVisibleOk": "Merged visible layers",
"mergeVisibleError": "Error merging visible layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"generateMode": "Generate",
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
"composeMode": "Compose",
@@ -1624,7 +1686,7 @@
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Masked Regions",
"compositeMaskedRegions": "Composite Masked Regions",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1641,29 +1703,25 @@
"enableAutoNegative": "Enable Auto Negative",
"disableAutoNegative": "Disable Auto Negative",
"deletePrompt": "Delete Prompt",
"deleteReferenceImage": "Delete Reference Image",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"showHUD": "Show HUD",
"rectangle": "Rectangle",
"maskFill": "Mask Fill",
"addPositivePrompt": "Add $t(controlLayers.prompt)",
"addNegativePrompt": "Add $t(controlLayers.negativePrompt)",
"addReferenceImage": "Add $t(controlLayers.referenceImage)",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
"addControlLayer": "Add $t(controlLayers.controlLayer)",
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
"addGlobalReferenceImage": "Add $t(controlLayers.globalReferenceImage)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"raster": "Raster",
"rasterLayer": "Raster Layer",
"controlLayer": "Control Layer",
"inpaintMask": "Inpaint Mask",
"regionalGuidance": "Regional Guidance",
"referenceImage": "Reference Image",
"regionalReferenceImage": "Regional Reference Image",
"globalReferenceImage": "Global Reference Image",
"ipAdapter": "IP Adapter",
"sendingToCanvas": "Sending to Canvas",
"sendingToGallery": "Sending to Gallery",
"sendToGallery": "Send To Gallery",
@@ -1676,23 +1734,29 @@
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
"globalReferenceImage_withCount_one": "$t(controlLayers.globalReferenceImage)",
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
"rasterLayer_withCount_other": "Raster Layers",
"controlLayer_withCount_other": "Control Layers",
"inpaintMask_withCount_other": "Inpaint Masks",
"regionalGuidance_withCount_other": "Regional Guidance",
"globalReferenceImage_withCount_other": "Global Reference Images",
"ipAdapter_withCount_other": "IP Adapters",
"opacity": "Opacity",
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
"globalReferenceImages_withCount_hidden": "Global Reference Images ({{count}} hidden)",
"globalIPAdapters_withCount_hidden": "Global IP Adapters ({{count}} hidden)",
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
"controlLayers_withCount_visible": "Control Layers ({{count}})",
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
"globalReferenceImages_withCount_visible": "Global Reference Images ({{count}})",
"globalIPAdapters_withCount_visible": "Global IP Adapters ({{count}})",
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Global $t(common.ipAdapter)",
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"globalInitialImage": "Global Initial Image",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
"layer": "Layer",
"opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor",
@@ -1723,27 +1787,7 @@
"stagingOnCanvas": "Staging images on",
"replaceLayer": "Replace Layer",
"pullBboxIntoLayer": "Pull Bbox into Layer",
"pullBboxIntoReferenceImage": "Pull Bbox into Reference Image",
"showProgressOnCanvas": "Show Progress on Canvas",
"prompt": "Prompt",
"negativePrompt": "Negative Prompt",
"beginEndStepPercentShort": "Begin/End %",
"weight": "Weight",
"controlMode": {
"controlMode": "Control Mode",
"balanced": "Balanced",
"prompt": "Prompt",
"control": "Control",
"megaControl": "Mega Control"
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only"
},
"useSizeOptimizeForModel": "Copy size to W/H (optimize for model)",
"useSizeIgnoreModel": "Copy size to W/H (ignore model)",
"pullBboxIntoIPAdapter": "Pull Bbox into IP Adapter",
"fill": {
"fillColor": "Fill Color",
"fillStyle": "Fill Style",
@@ -1861,10 +1905,6 @@
"label": "Snap to Grid",
"on": "On",
"off": "Off"
},
"preserveMask": {
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
}
},
"HUD": {
@@ -1872,23 +1912,15 @@
"scaledBbox": "Scaled Bbox",
"autoSave": "Auto Save",
"entityStatus": {
"isFiltering": "{{title}} is filtering",
"isTransforming": "{{title}} is transforming",
"isLocked": "{{title}} is locked",
"isHidden": "{{title}} is hidden",
"isDisabled": "{{title}} is disabled",
"isEmpty": "{{title}} is empty"
"selectedEntity": "Selected Entity",
"selectedEntityIs": "Selected Entity is",
"isFiltering": "is filtering",
"isTransforming": "is transforming",
"isLocked": "is locked",
"isHidden": "is hidden",
"isDisabled": "is disabled",
"enabled": "Enabled"
}
},
"canvasContextMenu": {
"saveToGalleryGroup": "Save To Gallery",
"saveCanvasToGallery": "Save Canvas To Gallery",
"saveBboxToGallery": "Save Bbox To Gallery",
"bboxGroup": "Create From Bbox",
"newGlobalReferenceImage": "New Global Reference Image",
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newRasterLayer": "New Raster Layer"
}
},
"upscaling": {

View File

@@ -19,6 +19,7 @@ import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/Cl
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { AppContent } from 'features/ui/components/AppContent';
@@ -137,6 +138,7 @@ const App = ({
<StylePresetModal />
<ClearQueueConfirmationsAlertDialog />
<PreselectedImage selectedImage={selectedImage} />
<SettingsModal />
<RefreshAfterResetModal />
<DeleteBoardModal />
</ErrorBoundary>

View File

@@ -5,7 +5,7 @@ import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canv
import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';

View File

@@ -3,7 +3,7 @@ import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { isErr, withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
import {
selectIsStaging,
@@ -11,7 +11,6 @@ import {
stagingAreaStartedStaging,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -25,7 +24,7 @@ const log = logger('generation');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
enqueueRequested.match(action) && action.payload.tabName === 'generation',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const model = state.params.model;
@@ -48,11 +47,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
};
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
},
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
Error
>;
@@ -67,14 +62,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (buildGraphResult.isErr()) {
if (isErr(buildGraphResult)) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
@@ -88,7 +80,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
);
if (prepareBatchResult.isErr()) {
if (isErr(prepareBatchResult)) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
@@ -103,7 +95,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const enqueueResult = await withResultAsync(() => req.unwrap());
if (enqueueResult.isErr()) {
if (isErr(enqueueResult)) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
@@ -53,9 +53,9 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
// };
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
}
});
};

View File

@@ -1,27 +1,17 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import {
controlLayerAdded,
entityRasterized,
entitySelected,
ipaImageChanged,
rasterLayerAdded,
referenceImageAdded,
referenceImageIPAdapterImageChanged,
rgAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type {
CanvasControlLayerState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
@@ -65,10 +55,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
) {
const { id } = overData.context;
dispatch(
referenceImageIPAdapterImageChanged({
entityIdentifier: { id, type: 'reference_image' },
imageDTO: activeData.payload.imageDTO,
})
ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO: activeData.payload.imageDTO })
);
return;
}
@@ -81,11 +68,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id, referenceImageId } = overData.context;
const { id, ipAdapterId } = overData.context;
dispatch(
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
referenceImageId,
ipAdapterId,
imageDTO: activeData.payload.imageDTO,
})
);
@@ -131,36 +118,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
return;
}
if (
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
};
dispatch(rgAdded({ overrides, isSelected: true }));
return;
}
if (
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
const overrides: Partial<CanvasReferenceImageState> = {
ipAdapter,
};
dispatch(referenceImageAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Raster layer
*/
@@ -170,7 +127,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}

View File

@@ -1,13 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
entityRasterized,
entitySelected,
referenceImageIPAdapterImageChanged,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
@@ -101,15 +94,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
if (postUploadAction?.type === 'SET_IPA_IMAGE') {
const { id } = postUploadAction;
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
dispatch(ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, referenceImageId } = postUploadAction;
const { id, ipAdapterId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, ipAdapterId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
@@ -121,17 +114,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
return;
}
if (postUploadAction?.type === 'REPLACE_LAYER_WITH_IMAGE') {
const { entityIdentifier } = postUploadAction;
const state = getState();
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
dispatch(entitySelected({ entityIdentifier }));
return;
}
},
});

View File

@@ -6,18 +6,11 @@ import {
bboxHeightChanged,
bboxWidthChanged,
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
ipaModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
fluxVAESelected,
modelChanged,
refinerModelChanged,
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
@@ -28,16 +21,13 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonFluxVAEModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -60,9 +50,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
handleFLUXVAEModels(models, state, dispatch, log);
},
});
};
@@ -144,7 +131,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isNonFluxVAEModelConfig);
const vaeModels = models.filter(isVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
@@ -194,22 +181,22 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
const isModelAvailable = ipaModels.some((m) => m.key === ipAdapter.model?.key);
selectCanvasSlice(state).regions.entities.forEach((entity) => {
entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
if (isModelAvailable) {
return;
}
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), ipAdapterId, modelConfig: null })
);
});
});
@@ -236,45 +223,3 @@ const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch,
dispatch(postProcessingModelChanged(firstModel));
}
};
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, _log) => {
const { t5EncoderModel: currentT5EncoderModel } = state.params;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const firstModel = t5EncoderModels[0] || null;
const isCurrentT5EncoderModelAvailable = currentT5EncoderModel
? t5EncoderModels.some((m) => m.key === currentT5EncoderModel.key)
: false;
if (!isCurrentT5EncoderModelAvailable) {
dispatch(t5EncoderModelSelected(firstModel));
}
};
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, _log) => {
const { clipEmbedModel: currentCLIPEmbedModel } = state.params;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const firstModel = CLIPEmbedModels[0] || null;
const isCurrentCLIPEmbedModelAvailable = currentCLIPEmbedModel
? CLIPEmbedModels.some((m) => m.key === currentCLIPEmbedModel.key)
: false;
if (!isCurrentCLIPEmbedModelAvailable) {
dispatch(clipEmbedModelSelected(firstModel));
}
};
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, _log) => {
const { fluxVAE: currentFLUXVAEModel } = state.params;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const firstModel = fluxVAEModels[0] || null;
const isCurrentFLUXVAEModelAvailable = currentFLUXVAEModel
? fluxVAEModels.some((m) => m.key === currentFLUXVAEModel.key)
: false;
if (!isCurrentFLUXVAEModelAvailable) {
dispatch(fluxVAESelected(firstModel));
}
};

View File

@@ -1,13 +0,0 @@
import type { ReadableAtom } from 'nanostores';
import { atom } from 'nanostores';
/**
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
* in a hook or component.
*/
export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
*/
export const $true: ReadableAtom<boolean> = atom(true);

View File

@@ -114,9 +114,6 @@ export type AppConfig = {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: NumericalParameterConfig;
};
};
export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;

View File

@@ -155,19 +155,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
return styles;
}, [isUploadDisabled, minSize]);
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
return;
}
if (e.button !== 1) {
return;
}
window.open(imageDTO.image_url, '_blank');
},
[imageDTO]
);
return (
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
@@ -225,12 +212,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
<IAIDraggable data={draggableData} disabled={isDragDisabled || !imageDTO} onClick={onClick} />
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}

View File

@@ -10,9 +10,7 @@ const sx: SystemStyleObject = {
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: `drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-900))
drop-shadow(0px 0px 0.3rem var(--invoke-colors-base-900))
drop-shadow(0px 0px 0.3rem var(--invoke-colors-base-900))`,
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};
@@ -29,6 +27,7 @@ const IAIDndImageIcon = (props: Props) => {
onClick={onClick}
aria-label={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}

View File

@@ -1,4 +1,4 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion';
import type { ReactNode } from 'react';
@@ -28,13 +28,11 @@ const IAIDropOverlay = (props: Props) => {
const motionId = useRef(uuidv4());
return (
<motion.div key={motionId.current} initial={initial} animate={animate} exit={exit}>
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
<Flex position="absolute" top={0} insetInlineStart={0} w="full" h="full">
<Flex
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
insetInlineStart={0}
w="full"
h="full"
bg="base.900"
@@ -49,30 +47,29 @@ const IAIDropOverlay = (props: Props) => {
<Flex
position="absolute"
top={0.5}
right={0.5}
insetInlineStart={0.5}
insetInlineEnd={0.5}
bottom={0.5}
left={0.5}
opacity={1}
borderWidth={1.5}
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
borderWidth={2}
borderColor={isOver ? 'base.300' : 'base.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
p={4}
>
<Text
fontSize="xl"
<Box
fontSize="2xl"
fontWeight="semibold"
color={isOver ? 'invokeYellow.300' : 'base.500'}
transform={isOver ? 'scale(1.1)' : 'scale(1)'}
color={isOver ? 'base.50' : 'base.300'}
transitionProperty="common"
transitionDuration="0.1s"
textAlign="center"
>
{label}
</Text>
</Box>
</Flex>
</Flex>
</motion.div>

View File

@@ -30,9 +30,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
ref={setNodeRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
insetInlineStart={0}
w="full"
h="full"
pointerEvents={active ? 'auto' : 'none'}

View File

@@ -30,7 +30,6 @@ export type Feature =
| 'noiseUseCPU'
| 'paramAspect'
| 'paramCFGScale'
| 'paramGuidance'
| 'paramCFGRescaleMultiplier'
| 'paramDenoisingStrength'
| 'paramHeight'

View File

@@ -68,7 +68,7 @@ export const useGlobalHotkeys = () => {
useHotkeys(
'1',
() => {
dispatch(setActiveTab('canvas'));
dispatch(setActiveTab('generation'));
addScope('canvas');
removeScope('workflows');
},

View File

@@ -1,7 +1,6 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $true } from 'app/store/nanostores/util';
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -19,25 +18,19 @@ import { selectSystemSlice } from 'features/system/store/systemSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { atom } from 'nanostores';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY = {
reference_image: 'controlLayers.referenceImage',
ip_adapter: 'controlLayers.ipAdapter',
inpaint_mask: 'controlLayers.inpaintMask',
regional_guidance: 'controlLayers.regionalGuidance',
raster_layer: 'controlLayers.rasterLayer',
control_layer: 'controlLayers.controlLayer',
raster_layer: 'controlLayers.raster',
control_layer: 'controlLayers.globalControlAdapter',
} as const;
const createSelector = (
templates: Templates,
isConnected: boolean,
canvasIsFiltering: boolean,
canvasIsTransforming: boolean,
canvasIsRasterizing: boolean,
canvasIsCompositing: boolean
) =>
const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy: boolean) =>
createMemoizedSelector(
[
selectSystemSlice,
@@ -126,17 +119,8 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
} else {
if (canvasIsFiltering) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsFiltering') });
}
if (canvasIsTransforming) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsTransforming') });
}
if (canvasIsRasterizing) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsRasterizing') });
}
if (canvasIsCompositing) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsCompositing') });
if (canvasIsBusy) {
reasons.push({ content: i18n.t('parameters.invoke.canvasBusy') });
}
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
@@ -147,18 +131,6 @@ const createSelector = (
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (model?.base === 'flux') {
if (!params.t5EncoderModel) {
reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
}
if (!params.clipEmbedModel) {
reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
}
if (!params.fluxVAE) {
reasons.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
}
}
canvas.controlLayers.entities
.filter((controlLayer) => controlLayer.isEnabled)
.forEach((controlLayer, i) => {
@@ -189,7 +161,7 @@ const createSelector = (
}
});
canvas.referenceImages.entities
canvas.ipAdapters.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
@@ -217,7 +189,7 @@ const createSelector = (
}
});
canvas.regionalGuidance.entities
canvas.regions.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
@@ -230,14 +202,10 @@ const createSelector = (
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (
entity.positivePrompt === null &&
entity.negativePrompt === null &&
entity.referenceImages.length === 0
) {
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
entity.referenceImages.forEach(({ ipAdapter }) => {
entity.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
@@ -278,25 +246,16 @@ const createSelector = (
}
);
const dummyAtom = atom(true);
export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const isConnected = useStore($isConnected);
const canvasManager = useCanvasManagerSafe();
const canvasIsFiltering = useStore(canvasManager?.stateApi.$isFiltering ?? $true);
const canvasIsTransforming = useStore(canvasManager?.stateApi.$isTransforming ?? $true);
const canvasIsRasterizing = useStore(canvasManager?.stateApi.$isRasterizing ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const canvasIsBusy = useStore(canvasManager?.$isBusy ?? dummyAtom);
const selector = useMemo(
() =>
createSelector(
templates,
isConnected,
canvasIsFiltering,
canvasIsTransforming,
canvasIsRasterizing,
canvasIsCompositing
),
[templates, isConnected, canvasIsFiltering, canvasIsTransforming, canvasIsRasterizing, canvasIsCompositing]
() => createSelector(templates, isConnected, canvasIsBusy),
[templates, isConnected, canvasIsBusy]
);
const value = useAppSelector(selector);
return value;

View File

@@ -1,14 +0,0 @@
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
export const useNanoid = (prefix?: string) => {
const id = useMemo(() => {
if (prefix) {
return getPrefixedId(prefix);
} else {
return nanoid();
}
}, [prefix]);
return id;
};

View File

@@ -2,7 +2,8 @@ import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, expect, it } from 'vitest';
import { Err, ErrResult, Ok, OkResult, withResult, withResultAsync } from './result';
import type { ErrResult, OkResult } from './result';
import { Err, isErr, isOk, Ok, withResult, withResultAsync } from './result'; // Adjust import as needed
const promiseify = <T>(fn: () => T): (() => Promise<T>) => {
return () =>
@@ -12,30 +13,28 @@ const promiseify = <T>(fn: () => T): (() => Promise<T>) => {
};
describe('Result Utility Functions', () => {
it('OkResult() should create an Ok result', () => {
const result = OkResult(42);
expect(result).toBeInstanceOf(Ok);
expect(result.isOk()).toBe(true);
expect(result.isErr()).toBe(false);
expect(result.value).toBe(42);
assert<Equals<Ok<number>, typeof result>>(result);
it('Ok() should create an OkResult', () => {
const result = Ok(42);
expect(result).toEqual({ type: 'Ok', value: 42 });
expect(isOk(result)).toBe(true);
expect(isErr(result)).toBe(false);
assert<Equals<OkResult<number>, typeof result>>(result);
});
it('ErrResult() should create an Err result', () => {
it('Err() should create an ErrResult', () => {
const error = new Error('Something went wrong');
const result = ErrResult(error);
expect(result).toBeInstanceOf(Err);
expect(result.isOk()).toBe(false);
expect(result.isErr()).toBe(true);
expect(result.error).toBe(error);
assert<Equals<Err<Error>, typeof result>>(result);
const result = Err(error);
expect(result).toEqual({ type: 'Err', error });
expect(isOk(result)).toBe(false);
expect(isErr(result)).toBe(true);
assert<Equals<ErrResult<Error>, typeof result>>(result);
});
it('withResult() should return Ok on success', () => {
const fn = () => 42;
const result = withResult(fn);
expect(result.isOk()).toBe(true);
if (result.isOk()) {
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
@@ -45,8 +44,8 @@ describe('Result Utility Functions', () => {
throw new Error('Failure');
};
const result = withResult(fn);
expect(result.isErr()).toBe(true);
if (result.isErr()) {
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Failure');
}
});
@@ -54,8 +53,8 @@ describe('Result Utility Functions', () => {
it('withResultAsync() should return Ok on success', async () => {
const fn = promiseify(() => 42);
const result = await withResultAsync(fn);
expect(result.isOk()).toBe(true);
if (result.isOk()) {
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
@@ -65,8 +64,8 @@ describe('Result Utility Functions', () => {
throw new Error('Async failure');
});
const result = await withResultAsync(fn);
expect(result.isErr()).toBe(true);
if (result.isErr()) {
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Async failure');
}
});

View File

@@ -2,81 +2,39 @@
* Represents a successful result.
* @template T The type of the value.
*/
export class Ok<T> {
readonly value: T;
constructor(value: T) {
this.value = value;
}
/**
* Type guard to check if this result is an `Ok` result.
* @returns {this is Ok<T>} `true` if the result is an `Ok` result, otherwise `false`.
*/
isOk(): this is Ok<T> {
return true;
}
/**
* Type guard to check if this result is an `Err` result.
* @returns {this is Err<never>} `true` if the result is an `Err` result, otherwise `false`.
*/
isErr(): this is Err<never> {
return false;
}
}
export type OkResult<T> = { type: 'Ok'; value: T };
/**
* Represents a failed result.
* @template E The type of the error.
*/
export class Err<E> {
readonly error: E;
constructor(error: E) {
this.error = error;
}
/**
* Type guard to check if this result is an `Ok` result.
* @returns {this is Ok<never>} `true` if the result is an `Ok` result, otherwise `false`.
*/
isOk(): this is Ok<never> {
return false;
}
/**
* Type guard to check if this result is an `Err` result.
* @returns {this is Err<E>} `true` if the result is an `Err` result, otherwise `false`.
*/
isErr(): this is Err<E> {
return true;
}
}
export type ErrResult<E> = { type: 'Err'; error: E };
/**
* A union type that represents either a successful result (`Ok`) or a failed result (`Err`).
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
export type Result<T, E = Error> = Ok<T> | Err<E>;
export type Result<T, E = Error> = OkResult<T> | ErrResult<E>;
/**
* Creates a successful result.
* @template T The type of the value.
* @param {T} value The value to wrap in an `Ok` result.
* @returns {Ok<T>} The `Ok` result containing the value.
* @returns {OkResult<T>} The `Ok` result containing the value.
*/
export function OkResult<T>(value: T): Ok<T> {
return new Ok(value);
export function Ok<T>(value: T): OkResult<T> {
return { type: 'Ok', value };
}
/**
* Creates a failed result.
* @template E The type of the error.
* @param {E} error The error to wrap in an `Err` result.
* @returns {Err<E>} The `Err` result containing the error.
* @returns {ErrResult<E>} The `Err` result containing the error.
*/
export function ErrResult<E>(error: E): Err<E> {
return new Err(error);
export function Err<E>(error: E): ErrResult<E> {
return { type: 'Err', error };
}
/**
@@ -87,9 +45,9 @@ export function ErrResult<E>(error: E): Err<E> {
*/
export function withResult<T>(fn: () => T): Result<T> {
try {
return new Ok(fn());
return Ok(fn());
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return Err(error instanceof Error ? error : new Error(String(error)));
}
}
@@ -102,8 +60,30 @@ export function withResult<T>(fn: () => T): Result<T> {
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T>> {
try {
const result = await fn();
return new Ok(result);
return Ok(result);
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return Err(error instanceof Error ? error : new Error(String(error)));
}
}
/**
* Type guard to check if a `Result` is an `Ok` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is OkResult<T>} `true` if the result is an `Ok` result, otherwise `false`.
*/
export function isOk<T, E>(result: Result<T, E>): result is OkResult<T> {
return result.type === 'Ok';
}
/**
* Type guard to check if a `Result` is an `Err` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is ErrResult<E>} `true` if the result is an `Err` result, otherwise `false`.
*/
export function isErr<T, E>(result: Result<T, E>): result is ErrResult<E> {
return result.type === 'Err';
}

View File

@@ -1,14 +1,11 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -19,82 +16,27 @@ export const CanvasAddEntityButtons = memo(() => {
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const addIPAdapter = useAddIPAdapter();
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
<Flex position="relative" flexDir="column" gap={4} top="20%">
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.global')}</Heading>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" gap={2}>
<Heading size="xs">{t('controlLayers.regional')}</Heading>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
>
{t('controlLayers.inpaintMask')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
>
{t('controlLayers.regionalGuidance')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.layer_other')}</Heading>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRasterLayer}
>
{t('controlLayers.rasterLayer')}
</Button>
</Flex>
</Flex>
<Flex flexDir="column" w="full" h="full" alignItems="center">
<ButtonGroup position="relative" orientation="vertical" isAttached={false} top="20%">
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRegionalGuidance}>
{t('controlLayers.regionalGuidance')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.globalIPAdapter')}
</Button>
</ButtonGroup>
</Flex>
);
});

View File

@@ -1,23 +0,0 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectPreserveMask } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsPreserveMask = memo(() => {
const { t } = useTranslation();
const preserveMask = useAppSelector(selectPreserveMask);
if (!preserveMask) {
return null;
}
return (
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.preserveMask.alert')}</AlertTitle>
</Alert>
);
});
CanvasAlertsPreserveMask.displayName = 'CanvasAlertsPreserveMask';

View File

@@ -1,53 +0,0 @@
import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
import {
useNewControlLayerFromBbox,
useNewGlobalReferenceImageFromBbox,
useNewRasterLayerFromBbox,
useNewRegionalReferenceImageFromBbox,
useSaveBboxToGallery,
useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold, PiStackPlusFill } from 'react-icons/pi';
export const CanvasContextMenuGlobalMenuItems = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
const newRegionalReferenceImageFromBbox = useNewRegionalReferenceImageFromBbox();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
const newRasterLayerFromBbox = useNewRasterLayerFromBbox();
const newControlLayerFromBbox = useNewControlLayerFromBbox();
return (
<>
<MenuGroup title={t('controlLayers.canvasContextMenu.saveToGalleryGroup')}>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
{t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
</MenuItem>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
{t('controlLayers.canvasContextMenu.saveBboxToGallery')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.canvasContextMenu.bboxGroup')}>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newControlLayer')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newRasterLayer')}
</MenuItem>
</MenuGroup>
</>
);
});
CanvasContextMenuGlobalMenuItems.displayName = 'CanvasContextMenuGlobalMenuItems';

View File

@@ -0,0 +1,49 @@
import { MenuItem } from '@invoke-ai/ui-library';
import {
useIsSavingCanvas,
useSaveBboxAsControlLayer,
useSaveBboxAsGlobalIPAdapter,
useSaveBboxAsRasterLayer,
useSaveBboxAsRegionalGuidanceIPAdapter,
useSaveBboxToGallery,
useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold, PiShareFatBold } from 'react-icons/pi';
export const CanvasContextMenuItems = memo(() => {
const { t } = useTranslation();
const isSaving = useIsSavingCanvas();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
const saveBboxAsRegionalGuidanceIPAdapter = useSaveBboxAsRegionalGuidanceIPAdapter();
const saveBboxAsIPAdapter = useSaveBboxAsGlobalIPAdapter();
const saveBboxAsRasterLayer = useSaveBboxAsRasterLayer();
const saveBboxAsControlLayer = useSaveBboxAsControlLayer();
return (
<>
<MenuItem icon={<PiFloppyDiskBold />} isLoading={isSaving.isTrue} onClick={saveCanvasToGallery}>
{t('controlLayers.saveCanvasToGallery')}
</MenuItem>
<MenuItem icon={<PiFloppyDiskBold />} isLoading={isSaving.isTrue} onClick={saveBboxToGallery}>
{t('controlLayers.saveBboxToGallery')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsIPAdapter}>
{t('controlLayers.sendBboxToGlobalIPAdapter')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsRegionalGuidanceIPAdapter}>
{t('controlLayers.sendBboxToRegionalIPAdapter')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsControlLayer}>
{t('controlLayers.sendBboxToControlLayer')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsRasterLayer}>
{t('controlLayers.sendBboxToRasterLayer')}
</MenuItem>
</>
);
});
CanvasContextMenuItems.displayName = 'CanvasContextMenuItems';

View File

@@ -1,43 +0,0 @@
import { MenuGroup } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import {
EntityIdentifierContext,
useEntityIdentifierContext,
} from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityTitle } from 'features/controlLayers/hooks/useEntityTitle';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { isFilterableEntityIdentifier, isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const title = useEntityTitle(entityIdentifier);
return (
<MenuGroup title={title}>
{isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
<CanvasEntityMenuItemsDelete />
</MenuGroup>
);
});
CanvasContextMenuSelectedEntityMenuItemsContent.displayName = 'CanvasContextMenuSelectedEntityMenuItemsContent';
export const CanvasContextMenuSelectedEntityMenuItems = memo(() => {
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
if (!selectedEntityIdentifier) {
return null;
}
return (
<EntityIdentifierContext.Provider value={selectedEntityIdentifier}>
<CanvasContextMenuSelectedEntityMenuItemsContent />
</EntityIdentifierContext.Provider>
);
});
CanvasContextMenuSelectedEntityMenuItems.displayName = 'CanvasContextMenuSelectedEntityMenuItems';

View File

@@ -1,11 +1,6 @@
import { Grid, GridItem } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type {
AddControlLayerFromImageDropData,
AddGlobalReferenceImageFromImageDropData,
AddRasterLayerFromImageDropData,
AddRegionalReferenceImageFromImageDropData,
} from 'features/dnd/types';
import type { AddControlLayerFromImageDropData, AddRasterLayerFromImageDropData } from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
@@ -19,16 +14,6 @@ const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
};
const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
};
const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
};
export const CanvasDropArea = memo(() => {
const imageViewer = useImageViewer();
@@ -38,29 +23,12 @@ export const CanvasDropArea = memo(() => {
return (
<>
<Grid
gridTemplateRows="1fr 1fr"
gridTemplateColumns="1fr 1fr"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
pointerEvents="none"
>
<GridItem position="relative">
<IAIDroppable dropLabel="New Raster Layer" data={addRasterLayerFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Control Layer" data={addControlLayerFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Regional Reference Image" data={addRegionalReferenceImageFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Global Reference Image" data={addGlobalReferenceImageFromImageDropData} />
</GridItem>
</Grid>
<Flex position="absolute" top={0} right={0} bottom="50%" left={0} gap={2} pointerEvents="none">
<IAIDroppable dropLabel="Create Raster Layer" data={addRasterLayerFromImageDropData} />
</Flex>
<Flex position="absolute" top="50%" right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable dropLabel="Create Control Layer" data={addControlLayerFromImageDropData} />
</Flex>
</>
);
});

View File

@@ -11,9 +11,9 @@ export const CanvasEntityList = memo(() => {
return (
<ScrollableContent>
<Flex flexDir="column" gap={2} data-testid="control-layers-layer-list" w="full" h="full">
<IPAdapterList />
<InpaintMaskList />
<RegionalGuidanceEntityList />
<IPAdapterList />
<ControlLayerEntityList />
<RasterLayerEntityList />
</Flex>

View File

@@ -1,27 +1,22 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const { t } = useTranslation();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const addIPAdapter = useAddIPAdapter();
return (
<Menu>
@@ -36,30 +31,21 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
data-testid="control-layers-add-layer-menu-button"
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isFLUX}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</MenuItem>
</MenuGroup>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.globalIPAdapter')}
</MenuItem>
</MenuList>
</Menu>
);

View File

@@ -1,6 +1,5 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { entityDuplicated } from 'features/controlLayers/store/canvasSlice';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { memo, useCallback } from 'react';
@@ -10,7 +9,6 @@ import { PiCopyFill } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const onClick = useCallback(() => {
if (!selectedEntityIdentifier) {
@@ -22,7 +20,7 @@ export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
return (
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
isDisabled={!selectedEntityIdentifier}
size="sm"
variant="link"
alignSelf="stretch"

View File

@@ -50,14 +50,7 @@ export const EntityListSelectedEntityActionBarFill = memo(() => {
<Flex role="button" aria-label={t('controlLayers.maskFill')} tabIndex={-1} w={8} h={8}>
<Tooltip label={t('controlLayers.maskFill')}>
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Box
borderRadius="full"
borderColor="base.300"
w={6}
h={6}
borderWidth={1}
bg={rgbColorToString(fill.color)}
/>
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbColorToString(fill.color)} />
</Flex>
</Tooltip>
</Flex>

View File

@@ -137,7 +137,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
<FormControl
w="min-content"
gap={2}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
>
<FormLabel m={0}>{t('controlLayers.opacity')}</FormLabel>
<PopoverAnchor>
@@ -167,7 +167,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
position="absolute"
insetInlineEnd={0}
h="full"
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
/>
</PopoverTrigger>
</NumberInput>
@@ -185,7 +185,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
marks={marks}
formatValue={formatSliderValue}
alwaysShowMarks
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
/>
</PopoverBody>
</PopoverContent>

View File

@@ -1,20 +1,19 @@
import { ContextMenu, Flex, MenuList } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useScopeOnFocus } from 'common/hooks/interactionScopes';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasAlertsSendingToGallery } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasContextMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { CanvasSelectedEntityStatusAlert } from 'features/controlLayers/components/HUD/CanvasSelectedEntityStatusAlert';
import { SendingToGalleryAlert } from 'features/controlLayers/components/HUD/CanvasSendingToGalleryAlert';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { TRANSPARENCY_CHECKERBOARD_PATTERN_DATAURL } from 'features/controlLayers/konva/patterns/transparency-checkerboard-pattern';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { memo, useCallback, useRef } from 'react';
@@ -28,8 +27,7 @@ export const CanvasMainPanelContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuGlobalMenuItems />
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
@@ -50,9 +48,7 @@ export const CanvasMainPanelContent = memo(() => {
alignItems="center"
justifyContent="center"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<CanvasToolbar />
<ContextMenu<HTMLDivElement> renderMenu={renderMenu}>
{(ref) => (
<Flex
@@ -63,6 +59,18 @@ export const CanvasMainPanelContent = memo(() => {
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
>
{!dynamicGrid && (
<Flex
position="absolute"
borderRadius="base"
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DATAURL}
top={0}
right={0}
bottom={0}
left={0}
opacity={0.1}
/>
)}
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
{showHUD && (
@@ -71,9 +79,8 @@ export const CanvasMainPanelContent = memo(() => {
</Flex>
)}
<Flex flexDir="column" position="absolute" top={1} insetInlineEnd={1} pointerEvents="none" gap={2}>
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasSelectedEntityStatusAlert />
<SendingToGalleryAlert />
</Flex>
</CanvasManagerProviderGate>
</Flex>

View File

@@ -1,14 +1,12 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useIsSavingCanvas, usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
import {
controlLayerBeginEndStepPctChanged,
@@ -20,8 +18,8 @@ import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/s
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
import { PiBoundingBoxBold, PiShootingStarBold } from 'react-icons/pi';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
@@ -72,12 +70,7 @@ export const ControlLayerControlAdapter = memo(() => {
);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const isSaving = useIsSavingCanvas();
return (
<Flex flexDir="column" gap={3} position="relative" w="full">
@@ -86,34 +79,19 @@ export const ControlLayerControlAdapter = memo(() => {
<IconButton
onClick={filter.start}
isDisabled={filter.isDisabled}
size="sm"
alignSelf="stretch"
variant="link"
variant="ghost"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarBold />}
/>
<IconButton
onClick={pullBboxIntoLayer}
isDisabled={isBusy}
size="sm"
alignSelf="stretch"
variant="link"
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoLayer')}
tooltip={t('controlLayers.pullBboxIntoLayer')}
icon={<PiBoundingBoxBold />}
/>
<IconButton
isDisabled={isBusy}
size="sm"
alignSelf="stretch"
variant="link"
aria-label={t('accessibility.uploadImage')}
tooltip={t('accessibility.uploadImage')}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />

View File

@@ -16,10 +16,10 @@ export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChan
const { t } = useTranslation();
const CONTROL_MODE_DATA = useMemo(
() => [
{ label: t('controlLayers.controlMode.balanced'), value: 'balanced' },
{ label: t('controlLayers.controlMode.prompt'), value: 'more_prompt' },
{ label: t('controlLayers.controlMode.control'), value: 'more_control' },
{ label: t('controlLayers.controlMode.megaControl'), value: 'unbalanced' },
{ label: t('controlnet.balanced'), value: 'balanced' },
{ label: t('controlnet.prompt'), value: 'more_prompt' },
{ label: t('controlnet.control'), value: 'more_control' },
{ label: t('controlnet.megaControl'), value: 'unbalanced' },
],
[t]
);
@@ -44,7 +44,7 @@ export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChan
return (
<FormControl>
<InformationalPopover feature="controlNetControlMode">
<FormLabel m={0}>{t('controlLayers.controlMode.controlMode')}</FormLabel>
<FormLabel m={0}>{t('controlnet.control')}</FormLabel>
</InformationalPopover>
<Combobox
value={value}

View File

@@ -51,7 +51,7 @@ export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onCha
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -4,7 +4,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
import { ControlLayerMenuItemsControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsControlToRaster';
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
import { memo } from 'react';
@@ -13,7 +13,7 @@ export const ControlLayerMenuItems = memo(() => {
<>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<ControlLayerMenuItemsConvertControlToRaster />
<ControlLayerMenuItemsControlToRaster />
<ControlLayerMenuItemsTransparencyEffect />
<MenuDivider />
<CanvasEntityMenuItemsArrange />

View File

@@ -1,27 +1,27 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { controlLayerConvertedToRasterLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';
export const ControlLayerMenuItemsConvertControlToRaster = memo(() => {
export const ControlLayerMenuItemsControlToRaster = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const convertControlLayerToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={!isInteractable}>
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={isBusy}>
{t('controlLayers.convertToRasterLayer')}
</MenuItem>
);
});
ControlLayerMenuItemsConvertControlToRaster.displayName = 'ControlLayerMenuItemsConvertControlToRaster';
ControlLayerMenuItemsControlToRaster.displayName = 'ControlLayerMenuItemsControlToRaster';

View File

@@ -2,7 +2,6 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { controlLayerWithTransparencyEffectToggled } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react';
@@ -13,7 +12,6 @@ export const ControlLayerMenuItemsTransparencyEffect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);
const selectWithTransparencyEffect = useMemo(
() =>
createSelector(selectCanvasSlice, (canvas) => {
@@ -28,7 +26,7 @@ export const ControlLayerMenuItemsTransparencyEffect = memo(() => {
}, [dispatch, entityIdentifier]);
return (
<MenuItem onClick={onToggle} icon={<PiDropHalfBold />} isDisabled={!isInteractable}>
<MenuItem onClick={onToggle} icon={<PiDropHalfBold />}>
{withTransparencyEffect
? t('controlLayers.disableTransparencyEffect')
: t('controlLayers.enableTransparencyEffect')}

View File

@@ -75,7 +75,7 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
<Button
variant="ghost"
leftIcon={<PiShootingStarBold />}
onClick={adapter.filterer.processImmediate}
onClick={adapter.filterer.process}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.process')}
isDisabled={!isValid || autoProcessFilter}
@@ -108,6 +108,7 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
onClick={adapter.filterer.cancel}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.cancel')}
isDisabled={!isValid}
>
{t('controlLayers.filter.cancel')}
</Button>

View File

@@ -70,9 +70,8 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
);
useEffect(() => {
const firstModel = options[0];
if (!config.model && firstModel) {
onChangeModel(firstModel);
if (!config.model) {
onChangeModel(options[0] ?? null);
}
}, [config.model, onChangeModel, options]);
@@ -81,14 +80,14 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
<FormControl w="full" orientation="vertical">
<Flex w="full" alignItems="center">
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.filter.spandrel_filter.autoScale')}
{t('controlLayers.filter.spandrel.paramAutoScale')}
</FormLabel>
<Switch size="sm" isChecked={config.autoScale} onChange={onAutoscaleChanged} />
</Flex>
<FormHelperText>{t('controlLayers.filter.spandrel_filter.autoScaleDesc')}</FormHelperText>
<FormHelperText>{t('controlLayers.filter.spandrel.paramAutoScaleDesc')}</FormHelperText>
</FormControl>
<FormControl isDisabled={!config.autoScale}>
<FormLabel m={0}>{t('controlLayers.filter.spandrel_filter.scale')}</FormLabel>
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramScale')}</FormLabel>
<CompositeSlider
value={config.scale}
onChange={onScaleChanged}
@@ -105,7 +104,7 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.spandrel_filter.model')}</FormLabel>
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramModel')}</FormLabel>
<Tooltip label={tooltipLabel}>
<Box w="full">
<Combobox

View File

@@ -1,5 +1,5 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
@@ -27,9 +27,10 @@ const $isFilteringFallback = atom(false);
type AlertData = {
status: AlertStatus;
title: string;
description: string;
};
const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapter }: ContentProps) => {
const CanvasSelectedEntityStatusAlertContent = memo(({ entityIdentifier, adapter }: ContentProps) => {
const { t } = useTranslation();
const title = useEntityTitle(entityIdentifier);
const selectIsEnabled = useMemo(
@@ -45,69 +46,67 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
const isHidden = useEntityTypeIsHidden(entityIdentifier.type);
const isFiltering = useStore(adapter.filterer?.$isFiltering ?? $isFilteringFallback);
const isTransforming = useStore(adapter.transformer.$isTransforming);
const isEmpty = useStore(adapter.$isEmpty);
const alert = useMemo<AlertData | null>(() => {
if (isFiltering) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isFiltering'),
};
}
if (isTransforming) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
};
}
if (isEmpty) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isTransforming'),
};
}
if (isHidden) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isHidden', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isHidden'),
};
}
if (isLocked) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isLocked', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isLocked'),
};
}
if (!isEnabled) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isDisabled', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isDisabled'),
};
}
return null;
}, [isFiltering, isTransforming, isEmpty, isHidden, isLocked, isEnabled, title, t]);
}, [isFiltering, isTransforming, isHidden, isLocked, isEnabled, title, t]);
if (!alert) {
return null;
}
return (
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md">
<AlertIcon />
<AlertTitle>{alert.title}</AlertTitle>
<AlertDescription>{alert.description}.</AlertDescription>
</Alert>
);
});
CanvasAlertsSelectedEntityStatusContent.displayName = 'CanvasAlertsSelectedEntityStatusContent';
CanvasSelectedEntityStatusAlertContent.displayName = 'CanvasSelectedEntityStatusAlertContent';
export const CanvasAlertsSelectedEntityStatus = memo(() => {
export const CanvasSelectedEntityStatusAlert = memo(() => {
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const adapter = useEntityAdapterSafe(selectedEntityIdentifier);
@@ -115,7 +114,7 @@ export const CanvasAlertsSelectedEntityStatus = memo(() => {
return null;
}
return <CanvasAlertsSelectedEntityStatusContent entityIdentifier={selectedEntityIdentifier} adapter={adapter} />;
return <CanvasSelectedEntityStatusAlertContent entityIdentifier={selectedEntityIdentifier} adapter={adapter} />;
});
CanvasAlertsSelectedEntityStatus.displayName = 'CanvasAlertsSelectedEntityStatus';
CanvasSelectedEntityStatusAlert.displayName = 'CanvasSelectedEntityStatusAlert';

View File

@@ -38,7 +38,7 @@ const ActivateImageViewerButton = (props: PropsWithChildren) => {
);
};
export const CanvasAlertsSendingToGallery = () => {
export const SendingToGalleryAlert = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const isVisible = useMemo(() => {
@@ -68,7 +68,7 @@ const ActivateCanvasButton = (props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const onClick = useCallback(() => {
dispatch(setActiveTab('canvas'));
dispatch(setActiveTab('generation'));
setRightPanelTabToLayers();
imageViewer.close();
}, [dispatch, imageViewer]);
@@ -79,7 +79,7 @@ const ActivateCanvasButton = (props: PropsWithChildren) => {
);
};
export const CanvasAlertsSendingToCanvas = () => {
export const SendingToCanvasAlert = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const isVisible = useMemo(() => {
@@ -136,16 +136,7 @@ const AlertWrapper = ({
onMouseEnter={isHovered.setTrue}
onMouseLeave={isHovered.setFalse}
>
<Alert
status="warning"
flexDir="column"
pointerEvents="auto"
borderRadius="base"
fontSize="sm"
shadow="md"
w="fit-content"
alignSelf="flex-end"
>
<Alert status="warning" flexDir="column" pointerEvents="auto" borderRadius="base" fontSize="sm" shadow="md">
<Flex w="full" alignItems="center">
<AlertIcon />
<AlertTitle>{title}</AlertTitle>

View File

@@ -13,7 +13,7 @@ type Props = {
};
export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'reference_image' }), [id]);
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>

View File

@@ -5,7 +5,6 @@ import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { useNanoid } from 'common/hooks/useNanoid';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
import type { ImageWithDims } from 'features/controlLayers/store/types';
@@ -20,86 +19,90 @@ import type { ImageDTO, PostUploadAction } from 'services/api/types';
type Props = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
ipAdapterId: string; // required for the dnd/upload interactions
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
};
export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();
const dndId = useNanoid('ip_adapter_image_preview');
export const IPAdapterImagePreview = memo(
({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
const options = { updateAspectRatio: true, clamp: true };
if (shift) {
const { width, height } = controlImage;
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
}
}, [controlImage, dispatch, optimalDimension, shift]);
const options = { updateAspectRatio: true, clamp: true };
if (shift) {
const { width, height } = controlImage;
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
}
}, [controlImage, dispatch, optimalDimension, shift]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: dndId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, dndId]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: ipAdapterId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, ipAdapterId]);
useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);
useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w="full" h="full" alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>
return (
<Flex position="relative" w={36} h={36} alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>
{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={shift ? t('controlLayers.useSizeIgnoreModel') : t('controlLayers.useSizeOptimizeForModel')}
/>
</Flex>
)}
</Flex>
);
});
{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';

View File

@@ -9,10 +9,10 @@ import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/cont
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(mapId).reverse();
return canvas.ipAdapters.entities.map(mapId).reverse();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
return selectedEntityIdentifier?.type === 'ip_adapter';
});
export const IPAdapterList = memo(() => {
@@ -25,7 +25,7 @@ export const IPAdapterList = memo(() => {
if (ipaIds.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
<CanvasEntityGroupList type="ip_adapter" isSelected={isSelected}>
{ipaIds.map((id) => (
<IPAdapter key={id} id={id} />
))}

View File

@@ -16,9 +16,9 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
const { t } = useTranslation();
const options: { label: string; value: IPMethodV2 }[] = useMemo(
() => [
{ label: t('controlLayers.ipAdapterMethod.full'), value: 'full' },
{ label: `${t('controlLayers.ipAdapterMethod.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlLayers.ipAdapterMethod.composition')} (${t('common.beta')})`, value: 'composition' },
{ label: t('controlnet.full'), value: 'full' },
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
],
[t]
);
@@ -34,7 +34,7 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
return (
<FormControl>
<InformationalPopover feature="ipAdapterMethod">
<FormLabel>{t('controlLayers.ipAdapterMethod.ipAdapterMethod')}</FormLabel>
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} onChange={_onChange} />
</FormControl>

View File

@@ -70,12 +70,12 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
);
return (
<Flex gap={2}>
<Flex gap={4}>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
@@ -86,7 +86,7 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
<Combobox
options={CLIP_VISION_OPTIONS}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>

View File

@@ -6,15 +6,14 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useIsSavingCanvas, usePullBboxIntoIPAdapter } from 'features/controlLayers/hooks/saveCanvasHooks';
import {
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
referenceImageIPAdapterImageChanged,
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterWeightChanged,
ipaBeginEndStepPctChanged,
ipaCLIPVisionModelChanged,
ipaImageChanged,
ipaMethodChanged,
ipaModelChanged,
ipaWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
@@ -30,7 +29,7 @@ import { IPAdapterModel } from './IPAdapterModel';
export const IPAdapterSettings = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const entityIdentifier = useEntityIdentifierContext('ip_adapter');
const selectIPAdapter = useMemo(
() => createSelector(selectCanvasSlice, (s) => selectEntityOrThrow(s, entityIdentifier).ipAdapter),
[entityIdentifier]
@@ -39,42 +38,42 @@ export const IPAdapterSettings = memo(() => {
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(referenceImageIPAdapterBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
dispatch(ipaBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
},
[dispatch, entityIdentifier]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(referenceImageIPAdapterWeightChanged({ entityIdentifier, weight }));
dispatch(ipaWeightChanged({ entityIdentifier, weight }));
},
[dispatch, entityIdentifier]
);
const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(referenceImageIPAdapterMethodChanged({ entityIdentifier, method }));
dispatch(ipaMethodChanged({ entityIdentifier, method }));
},
[dispatch, entityIdentifier]
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
dispatch(ipaModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
);
const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(referenceImageIPAdapterCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
dispatch(ipaCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
},
[dispatch, entityIdentifier]
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
dispatch(ipaImageChanged({ entityIdentifier, imageDTO }));
},
[dispatch, entityIdentifier]
);
@@ -87,13 +86,13 @@ export const IPAdapterSettings = memo(() => {
() => ({ type: 'SET_IPA_IMAGE', id: entityIdentifier.id }),
[entityIdentifier.id]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
const pullBboxIntoIPAdapter = usePullBboxIntoIPAdapter(entityIdentifier);
const isSaving = useIsSavingCanvas();
return (
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
@@ -104,23 +103,24 @@ export const IPAdapterSettings = memo(() => {
</Box>
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
aria-label={t('controlLayers.pullBboxIntoIPAdapter')}
tooltip={t('controlLayers.pullBboxIntoIPAdapter')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full" alignItems="center">
<Flex flexDir="column" gap={2} w="full">
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={entityIdentifier.id}
droppableData={droppableData}
postUploadAction={postUploadAction}
/>

View File

@@ -4,7 +4,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
import { RasterLayerMenuItemsRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsRasterToControl';
import { memo } from 'react';
export const RasterLayerMenuItems = memo(() => {
@@ -12,7 +12,7 @@ export const RasterLayerMenuItems = memo(() => {
<>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<RasterLayerMenuItemsConvertRasterToControl />
<RasterLayerMenuItemsRasterToControl />
<MenuDivider />
<CanvasEntityMenuItemsArrange />
<MenuDivider />

View File

@@ -2,20 +2,20 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rasterLayerConvertedToControlLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';
export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
export const RasterLayerMenuItemsRasterToControl = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
const convertRasterLayerToControlLayer = useCallback(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
@@ -27,10 +27,10 @@ export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
}, [defaultControlAdapter, dispatch, entityIdentifier]);
return (
<MenuItem onClick={onClick} icon={<PiLightningBold />} isDisabled={!isInteractable}>
<MenuItem onClick={convertRasterLayerToControlLayer} icon={<PiLightningBold />} isDisabled={isBusy}>
{t('controlLayers.convertToControlLayer')}
</MenuItem>
);
});
RasterLayerMenuItemsConvertRasterToControl.displayName = 'RasterLayerMenuItemsConvertRasterToControl';
RasterLayerMenuItemsRasterToControl.displayName = 'RasterLayerMenuItemsRasterToControl';

View File

@@ -33,7 +33,7 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
onClick={addRegionalGuidancePositivePrompt}
isDisabled={!validActions.canAddPositivePrompt}
>
{t('controlLayers.prompt')}
{t('common.positivePrompt')}
</Button>
<Button
size="sm"
@@ -42,10 +42,10 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
onClick={addRegionalGuidanceNegativePrompt}
isDisabled={!validActions.canAddNegativePrompt}
>
{t('controlLayers.negativePrompt')}
{t('common.negativePrompt')}
</Button>
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addRegionalGuidanceIPAdapter}>
{t('controlLayers.referenceImage')}
{t('common.ipAdapter')}
</Button>
</Flex>
);

View File

@@ -1,7 +1,7 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { PiTrashSimpleBold } from 'react-icons/pi';
type Props = {
onDelete: () => void;
@@ -14,12 +14,11 @@ export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) =>
<IconButton
variant="link"
aria-label={t('controlLayers.deletePrompt')}
icon={<PiTrashSimpleFill />}
icon={<PiTrashSimpleBold />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
/>
</Tooltip>
);

View File

@@ -8,7 +8,7 @@ import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/cont
import { memo } from 'react';
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.regionalGuidance.entities.map(mapId).reverse();
return canvas.regions.entities.map(mapId).reverse();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'regional_guidance';

View File

@@ -7,8 +7,10 @@ import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapt
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
useIsSavingCanvas,
usePullBboxIntoRegionalGuidanceIPAdapter,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import {
rgIPAdapterBeginEndStepPctChanged,
rgIPAdapterCLIPVisionModelChanged,
@@ -18,114 +20,111 @@ import {
rgIPAdapterModelChanged,
rgIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectRegionalGuidanceIPAdapter } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { RGIPAdapterImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiTrashSimpleFill } from 'react-icons/pi';
import { PiBoundingBoxBold, PiTrashSimpleBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig, RGIPAdapterImagePostUploadAction } from 'services/api/types';
import { assert } from 'tsafe';
type Props = {
referenceImageId: string;
ipAdapterId: string;
ipAdapterNumber: number;
};
export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Props) => {
export const RegionalGuidanceIPAdapterSettings = memo(({ ipAdapterId, ipAdapterNumber }: Props) => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
dispatch(rgIPAdapterDeleted({ entityIdentifier, ipAdapterId }));
}, [dispatch, entityIdentifier, ipAdapterId]);
const selectIPAdapter = useMemo(
() =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectRegionalGuidanceReferenceImage(canvas, entityIdentifier, referenceImageId);
assert(referenceImage, `Regional Guidance IP Adapter with id ${referenceImageId} not found`);
return referenceImage.ipAdapter;
const ipAdapter = selectRegionalGuidanceIPAdapter(canvas, entityIdentifier, ipAdapterId);
assert(ipAdapter, `Regional GuidanceIP Adapter with id ${ipAdapterId} not found`);
return ipAdapter;
}),
[entityIdentifier, referenceImageId]
[entityIdentifier, ipAdapterId]
);
const ipAdapter = useAppSelector(selectIPAdapter);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(rgIPAdapterBeginEndStepPctChanged({ entityIdentifier, referenceImageId, beginEndStepPct }));
dispatch(rgIPAdapterBeginEndStepPctChanged({ entityIdentifier, ipAdapterId, beginEndStepPct }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(rgIPAdapterWeightChanged({ entityIdentifier, referenceImageId, weight }));
dispatch(rgIPAdapterWeightChanged({ entityIdentifier, ipAdapterId, weight }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(rgIPAdapterMethodChanged({ entityIdentifier, referenceImageId, method }));
dispatch(rgIPAdapterMethodChanged({ entityIdentifier, ipAdapterId, method }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
dispatch(rgIPAdapterModelChanged({ entityIdentifier, ipAdapterId, modelConfig }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(rgIPAdapterCLIPVisionModelChanged({ entityIdentifier, referenceImageId, clipVisionModel }));
dispatch(rgIPAdapterCLIPVisionModelChanged({ entityIdentifier, ipAdapterId, clipVisionModel }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(rgIPAdapterImageChanged({ entityIdentifier, referenceImageId, imageDTO }));
dispatch(rgIPAdapterImageChanged({ entityIdentifier, ipAdapterId, imageDTO }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);
const droppableData = useMemo<RGIPAdapterImageDropData>(
() => ({
actionType: 'SET_RG_IP_ADAPTER_IMAGE',
context: { id: entityIdentifier.id, referenceImageId: referenceImageId },
context: { id: entityIdentifier.id, ipAdapterId },
id: entityIdentifier.id,
}),
[entityIdentifier.id, referenceImageId]
[entityIdentifier.id, ipAdapterId]
);
const postUploadAction = useMemo<RGIPAdapterImagePostUploadAction>(
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, referenceImageId: referenceImageId }),
[entityIdentifier.id, referenceImageId]
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, ipAdapterId }),
[entityIdentifier.id, ipAdapterId]
);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
const isBusy = useCanvasIsBusy();
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceIPAdapter(entityIdentifier, ipAdapterId);
const isSaving = useIsSavingCanvas();
return (
<Flex flexDir="column" gap={2}>
<Flex alignItems="center" gap={2}>
<Text fontWeight="semibold" color="base.400">
{t('controlLayers.referenceImage')}
</Text>
<Flex flexDir="column" gap={3}>
<Flex alignItems="center" gap={3}>
<Text fontWeight="semibold" color="base.400">{`IP Adapter ${ipAdapterNumber}`}</Text>
<Spacer />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
icon={<PiTrashSimpleFill />}
tooltip={t('controlLayers.deleteReferenceImage')}
aria-label={t('controlLayers.deleteReferenceImage')}
icon={<PiTrashSimpleBold />}
aria-label="Delete IP Adapter"
onClick={onDeleteIPAdapter}
variant="ghost"
colorScheme="error"
/>
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
@@ -136,23 +135,24 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
</Box>
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
aria-label={t('controlLayers.pullBboxIntoIPAdapter')}
tooltip={t('controlLayers.pullBboxIntoIPAdapter')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full">
<Flex flexDir="column" gap={2} w="full">
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
droppableData={droppableData}
postUploadAction={postUploadAction}
/>

View File

@@ -13,7 +13,7 @@ export const RegionalGuidanceIPAdapters = memo(() => {
const selectIPAdapterIds = useMemo(
() =>
createMemoizedSelector(selectCanvasSlice, (canvas) => {
const ipAdapterIds = selectEntityOrThrow(canvas, entityIdentifier).referenceImages.map(({ id }) => id);
const ipAdapterIds = selectEntityOrThrow(canvas, entityIdentifier).ipAdapters.map(({ id }) => id);
if (ipAdapterIds.length === 0) {
return EMPTY_ARRAY;
}
@@ -33,7 +33,7 @@ export const RegionalGuidanceIPAdapters = memo(() => {
{ipAdapterIds.map((ipAdapterId, index) => (
<Fragment key={ipAdapterId}>
{index > 0 && <Divider />}
<RegionalGuidanceIPAdapterSettings referenceImageId={ipAdapterId} />
<RegionalGuidanceIPAdapterSettings ipAdapterId={ipAdapterId} ipAdapterNumber={index + 1} />
</Fragment>
))}
</>

View File

@@ -33,7 +33,7 @@ export const RegionalGuidanceMenuItemsAddPromptsAndIPAdapter = memo(() => {
{t('controlLayers.addNegativePrompt')}
</MenuItem>
<MenuItem onClick={addRegionalGuidanceIPAdapter} isDisabled={isBusy}>
{t('controlLayers.addReferenceImage')}
{t('controlLayers.addIPAdapter')}
</MenuItem>
</>
);

View File

@@ -20,7 +20,7 @@ export const RegionalGuidanceSettings = memo(() => {
return {
hasPositivePrompt: entity.positivePrompt !== null,
hasNegativePrompt: entity.negativePrompt !== null,
hasIPAdapters: entity.referenceImages.length > 0,
hasIPAdapters: entity.ipAdapters.length > 0,
};
}),
[entityIdentifier]

View File

@@ -1,25 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectBboxOverlay, settingsBboxOverlayToggled } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsBboxOverlaySwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const bboxOverlay = useAppSelector(selectBboxOverlay);
const onChange = useCallback(() => {
dispatch(settingsBboxOverlayToggled());
}, [dispatch]);
return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.bboxOverlay')}
</FormLabel>
<Switch size="sm" isChecked={bboxOverlay} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsBboxOverlaySwitch.displayName = 'CanvasSettingsBboxOverlaySwitch';

View File

@@ -0,0 +1,33 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectCanvasSettingsSlice,
settingsCompositeMaskedRegionsChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectCompositeMaskedRegions = createSelector(
selectCanvasSettingsSlice,
(canvasSettings) => canvasSettings.compositeMaskedRegions
);
export const CanvasSettingsCompositeMaskedRegionsCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const compositeMaskedRegions = useAppSelector(selectCompositeMaskedRegions);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(settingsCompositeMaskedRegionsChanged(e.target.checked)),
[dispatch]
);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.compositeMaskedRegions')}</FormLabel>
<Checkbox isChecked={compositeMaskedRegions} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsCompositeMaskedRegionsCheckbox.displayName = 'CanvasSettingsCompositeMaskedRegionsCheckbox';

View File

@@ -1,25 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectOutputOnlyMaskedRegions,
settingsOutputOnlyMaskedRegionsToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsOutputOnlyMaskedRegionsCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const outputOnlyMaskedRegions = useAppSelector(selectOutputOnlyMaskedRegions);
const onChange = useCallback(() => {
dispatch(settingsOutputOnlyMaskedRegionsToggled());
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.outputOnlyMaskedRegions')}</FormLabel>
<Checkbox isChecked={outputOnlyMaskedRegions} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsOutputOnlyMaskedRegionsCheckbox.displayName = 'CanvasSettingsOutputOnlyMaskedRegionsCheckbox';

View File

@@ -10,29 +10,27 @@ import {
useShiftModifier,
} from '@invoke-ai/ui-library';
import { CanvasSettingsAutoSaveCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsAutoSaveCheckbox';
import { CanvasSettingsBboxOverlaySwitch } from 'features/controlLayers/components/Settings/CanvasSettingsBboxOverlaySwitch';
import { CanvasSettingsClearCachesButton } from 'features/controlLayers/components/Settings/CanvasSettingsClearCachesButton';
import { CanvasSettingsClearHistoryButton } from 'features/controlLayers/components/Settings/CanvasSettingsClearHistoryButton';
import { CanvasSettingsClipToBboxCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsClipToBboxCheckbox';
import { CanvasSettingsCompositeMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsCompositeMaskedRegionsCheckbox';
import { CanvasSettingsDynamicGridSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsDynamicGridSwitch';
import { CanvasSettingsSnapToGridCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsGridSize';
import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsInvertScrollCheckbox';
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
import { CanvasSettingsOutputOnlyMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsOutputOnlyMaskedRegionsCheckbox';
import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPreserveMaskCheckbox';
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
import { CanvasSettingsResetButton } from 'features/controlLayers/components/Settings/CanvasSettingsResetButton';
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
import { RiSettings4Fill } from 'react-icons/ri';
export const CanvasSettingsPopover = memo(() => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton aria-label={t('common.settingsLabel')} icon={<PiGearSixFill />} variant="ghost" />
<IconButton aria-label={t('common.settingsLabel')} icon={<RiSettings4Fill />} variant="ghost" />
</PopoverTrigger>
<PopoverContent>
<PopoverArrow />
@@ -40,14 +38,12 @@ export const CanvasSettingsPopover = memo(() => {
<Flex direction="column" gap={2}>
<CanvasSettingsAutoSaveCheckbox />
<CanvasSettingsInvertScrollCheckbox />
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsCompositeMaskedRegionsCheckbox />
<CanvasSettingsSnapToGridCheckbox />
<CanvasSettingsShowProgressOnCanvas />
<CanvasSettingsDynamicGridSwitch />
<CanvasSettingsBboxOverlaySwitch />
<CanvasSettingsShowHUDSwitch />
<CanvasSettingsResetButton />
<DebugSettings />
</Flex>
</PopoverBody>

View File

@@ -1,20 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectPreserveMask, settingsPreserveMaskToggled } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsPreserveMaskCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const preserveMask = useAppSelector(selectPreserveMask);
const onChange = useCallback(() => dispatch(settingsPreserveMaskToggled()), [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.settings.preserveMask.label')}</FormLabel>
<Checkbox isChecked={preserveMask} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsPreserveMaskCheckbox.displayName = 'CanvasSettingsPreserveMaskCheckbox';

View File

@@ -1,12 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { canvasReset } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashBold } from 'react-icons/pi';
export const CanvasToolbarResetCanvasButton = memo(() => {
export const CanvasSettingsResetButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const canvasManager = useCanvasManager();
@@ -15,15 +14,10 @@ export const CanvasToolbarResetCanvasButton = memo(() => {
canvasManager.stage.fitLayersToStage();
}, [canvasManager.stage, dispatch]);
return (
<IconButton
aria-label={t('controlLayers.resetCanvas')}
tooltip={t('controlLayers.resetCanvas')}
onClick={onClick}
colorScheme="error"
icon={<PiTrashBold />}
variant="ghost"
/>
<Button onClick={onClick} colorScheme="error" size="sm">
{t('controlLayers.resetCanvas')}
</Button>
);
});
CanvasToolbarResetCanvasButton.displayName = 'CanvasToolbarResetCanvasButton';
CanvasSettingsResetButton.displayName = 'CanvasSettingsResetButton';

View File

@@ -1,28 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectShowProgressOnCanvas,
settingsShowProgressOnCanvasToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsShowProgressOnCanvas = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const showProgressOnCanvas = useAppSelector(selectShowProgressOnCanvas);
const onChange = useCallback(() => {
dispatch(settingsShowProgressOnCanvasToggled());
}, [dispatch]);
return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.showProgressOnCanvas')}
</FormLabel>
<Switch size="sm" isChecked={showProgressOnCanvas} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsShowProgressOnCanvas.displayName = 'CanvasSettingsShowProgressOnCanvas';

View File

@@ -26,14 +26,7 @@ export const ToolColorPicker = memo(() => {
<Flex role="button" aria-label={t('controlLayers.fill.fillColor')} tabIndex={-1} w={8} h={8}>
<Tooltip label={t('controlLayers.fill.fillColor')}>
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Box
borderRadius="full"
borderColor="base.300"
w={6}
h={6}
borderWidth={1}
bg={rgbaColorToString(fill)}
/>
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbaColorToString(fill)} />
</Flex>
</Tooltip>
</Flex>

View File

@@ -5,10 +5,10 @@ import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser'
import { ToolColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker';
import { ToolSettings } from 'features/controlLayers/components/Tool/ToolSettings';
import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton';
import { CanvasToolbarResetCanvasButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetCanvasButton';
import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton';
import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton';
import { CanvasToolbarScale } from 'features/controlLayers/components/Toolbar/CanvasToolbarScale';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasDeleteLayerHotkey } from 'features/controlLayers/hooks/useCanvasDeleteLayerHotkey';
import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/useCanvasEntityQuickSwitchHotkey';
import { useCanvasResetLayerHotkey } from 'features/controlLayers/hooks/useCanvasResetLayerHotkey';
@@ -24,20 +24,21 @@ export const CanvasToolbar = memo(() => {
useNextPrevEntityHotkeys();
return (
<Flex w="full" gap={2} alignItems="center">
<ToolChooser />
<Spacer />
<ToolSettings />
<Spacer />
<CanvasToolbarScale />
<CanvasToolbarResetViewButton />
<Spacer />
<ToolColorPicker />
<CanvasToolbarFitBboxToLayersButton />
<CanvasToolbarSaveToGalleryButton />
<CanvasToolbarResetCanvasButton />
<CanvasSettingsPopover />
</Flex>
<CanvasManagerProviderGate>
<Flex w="full" gap={2} alignItems="center">
<ToolChooser />
<Spacer />
<ToolSettings />
<Spacer />
<CanvasToolbarScale />
<CanvasToolbarResetViewButton />
<Spacer />
<ToolColorPicker />
<CanvasToolbarFitBboxToLayersButton />
<CanvasToolbarSaveToGalleryButton />
<CanvasSettingsPopover />
</Flex>
</CanvasManagerProviderGate>
);
});

View File

@@ -1,6 +1,5 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsOut } from 'react-icons/pi';
@@ -8,7 +7,6 @@ import { PiArrowsOut } from 'react-icons/pi';
export const CanvasToolbarFitBboxToLayersButton = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
canvasManager.bbox.fitToLayers();
}, [canvasManager.bbox]);
@@ -20,7 +18,6 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
aria-label={t('controlLayers.fitBboxToLayers')}
tooltip={t('controlLayers.fitBboxToLayers')}
icon={<PiArrowsOut />}
isDisabled={isBusy}
/>
);
});

View File

@@ -1,6 +1,9 @@
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
import { useSaveBboxToGallery, useSaveCanvasToGallery } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
useIsSavingCanvas,
useSaveBboxToGallery,
useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
@@ -8,7 +11,7 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
export const CanvasToolbarSaveToGalleryButton = memo(() => {
const { t } = useTranslation();
const shift = useShiftModifier();
const isBusy = useCanvasIsBusy();
const isSaving = useIsSavingCanvas();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
@@ -17,9 +20,9 @@ export const CanvasToolbarSaveToGalleryButton = memo(() => {
variant="ghost"
onClick={shift ? saveBboxToGallery : saveCanvasToGallery}
icon={<PiFloppyDiskBold />}
isLoading={isSaving.isTrue}
aria-label={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
tooltip={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
isDisabled={isBusy}
/>
);
});

View File

@@ -31,7 +31,7 @@ const TransformBox = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
leftIcon={<PiArrowsOutBold />}
onClick={adapter.transformer.fitProxyRectToBbox}
isLoading={isProcessing}
loadingText={t('controlLayers.transform.fitToBbox')}
loadingText={t('controlLayers.transform.reset')}
variant="ghost"
>
{t('controlLayers.transform.fitToBbox')}

View File

@@ -18,9 +18,9 @@ export const BeginEndStepPct = memo(({ beginEndStepPct, onChange }: Props) => {
}, [onChange]);
return (
<FormControl orientation="horizontal" pe={2}>
<FormControl orientation="horizontal" pe={1}>
<InformationalPopover feature="controlNetBeginEnd">
<FormLabel m={0}>{t('controlLayers.beginEndStepPercentShort')}</FormLabel>
<FormLabel m={0}>{t('controlnet.beginEndStepPercentShort')}</FormLabel>
</InformationalPopover>
<CompositeRangeSlider
aria-label={ariaLabel}

View File

@@ -1,12 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -18,12 +17,11 @@ type Props = {
export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addIPAdapter = useAddIPAdapter();
const onClick = useCallback(() => {
switch (type) {
@@ -39,11 +37,11 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
case 'control_layer':
addControlLayer();
break;
case 'reference_image':
addGlobalReferenceImage();
case 'ip_adapter':
addIPAdapter();
break;
}
}, [addControlLayer, addGlobalReferenceImage, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
}, [addControlLayer, addIPAdapter, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
const label = useMemo(() => {
switch (type) {
@@ -55,8 +53,8 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
return t('controlLayers.addRasterLayer');
case 'control_layer':
return t('controlLayers.addControlLayer');
case 'reference_image':
return t('controlLayers.addGlobalReferenceImage');
case 'ip_adapter':
return t('controlLayers.addIPAdapter');
}
}, [type, t]);
@@ -69,7 +67,6 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
icon={<PiPlusBold />}
onClick={onClick}
alignSelf="stretch"
isDisabled={isBusy}
/>
);
});

View File

@@ -1,7 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -11,7 +10,6 @@ export const CanvasEntityDeleteButton = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
@@ -26,7 +24,6 @@ export const CanvasEntityDeleteButton = memo(() => {
icon={<PiTrashSimpleFill />}
onClick={onClick}
colorScheme="error"
isDisabled={isBusy}
/>
);
});

Some files were not shown because too many files have changed in this diff Show More