mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00

Compare commits: ryan/flux-... ... ryan/flux-...

28 Commits

| SHA1 |
|---|
| 2ed86c082a |
| 41ff0a5af6 |
| c2ec65d582 |
| a0bede02f4 |
| 8f3c09348d |
| 940269e60a |
| 4fe42e2e48 |
| cbba28bdec |
| 2e8effe83f |
| 1b406e6d6a |
| 757251266d |
| b91a9ec54c |
| 0745d7ecfa |
| 690bf4eb9d |
| c238b60db9 |
| de5e9f33fa |
| 91ada8fc4c |
| 26be5ea030 |
| d038d635f1 |
| 5225c77908 |
| 33761066f1 |
| 931942754d |
| 04cdf00702 |
| 141de5755c |
| 9b52e9c116 |
| d0e8ce9056 |
| 1bad546626 |
| fa78ad5c5a |
@@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -81,9 +82,10 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=text_encoder,
+                patches=_lora_loader(),
+                prefix="lora_te_",
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -176,9 +178,9 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
+            LoRAPatcher.apply_lora_patches(
                 text_encoder,
-                loras=_lora_loader(),
+                patches=_lora_loader(),
                 prefix=lora_prefix,
                 cached_weights=cached_weights,
             ),
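For reference, a minimal sketch of the new patching pattern introduced here (the `text_encoder` and `lora_iterator` names are illustrative; the real invocations obtain them from the model manager):

```python
# Minimal sketch, assuming a loaded text_encoder module and an iterator of
# (LoRAModelRaw, weight) tuples such as the one produced by _lora_loader().
from invokeai.backend.lora.lora_patcher import LoRAPatcher

with LoRAPatcher.apply_lora_patches(
    model=text_encoder,     # module whose weights are patched in place
    patches=lora_iterator,  # lazily yields (LoRAModelRaw, float) pairs
    prefix="lora_te_",      # only layer keys with this prefix are applied
):
    # Run inference with the patched weights; on exit, the original weights
    # are restored from the OriginalWeightsStorage kept by the patcher.
    ...
```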
@@ -28,10 +28,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
 class GradientMaskOutput(BaseInvocationOutput):
     """Outputs a denoise mask and an image representing the total gradient of the mask."""
 
-    denoise_mask: DenoiseMaskField = OutputField(
-        description="Mask for denoise model run. Values of 0.0 represent the regions to be fully denoised, and 1.0 "
-        + "represent the regions to be preserved."
-    )
+    denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
     expanded_mask_area: ImageField = OutputField(
         description="Image representing the total gradient area of the mask. For paste-back purposes."
     )
@@ -37,6 +37,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(
-                unet,
-                loras=_lora_loader(),
+            LoRAPatcher.apply_lora_patches(
+                model=unet,
+                patches=_lora_loader(),
+                prefix="lora_unet_",
                 cached_weights=cached_weights,
             ),
         ):
@@ -181,7 +181,7 @@ class FieldDescriptions:
     )
     num_1 = "The first number"
     num_2 = "The second number"
-    denoise_mask = "A mask of the region to apply the denoising process to. Values of 0.0 represent the regions to be fully denoised, and 1.0 represent the regions to be preserved."
+    denoise_mask = "A mask of the region to apply the denoising process to."
     board = "The board to save the image to"
     image = "The image to process"
     tile_size = "Tile size"
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, Iterator, Optional, Tuple
 
 import torch
 import torchvision.transforms as tv_transforms
@@ -29,6 +29,8 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -39,7 +41,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="2.0.0",
+    version="1.0.0",
     classification=Classification.Prototype,
 )
 class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -187,7 +189,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             noise=noise,
         )
 
-        with transformer_info as transformer:
+        with (
+            transformer_info.model_on_device() as (cached_weights, transformer),
+            # Apply the LoRA after transformer has been moved to its target device for faster patching.
+            LoRAPatcher.apply_lora_patches(
+                model=transformer,
+                patches=self._lora_iterator(context),
+                prefix="",
+                cached_weights=cached_weights,
+            ),
+        ):
             assert isinstance(transformer, Flux)
 
             x = denoise(
@@ -220,19 +231,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         device, and dtype for the inpaint mask.
 
         Returns:
-            torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
-                represent the regions to be preserved.
+            torch.Tensor | None: Inpaint mask.
         """
         if self.denoise_mask is None:
             return None
 
         mask = context.tensors.load(self.denoise_mask.mask_name)
 
-        # The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
-        # 1.0 represents the regions to be preserved.
-        # We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
-        mask = 1.0 - mask
-
         _, _, latent_height, latent_width = latents.shape
         mask = tv_resize(
             img=mask,
@@ -247,6 +252,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # `latents`.
         return mask.expand_as(latents)
 
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
+
     def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
         def step_callback(state: PipelineIntermediateState) -> None:
             state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
53  invokeai/app/invocations/flux_lora_loader.py  Normal file
@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
    """FLUX LoRA Loader Output"""

    transformer: TransformerField = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_lora_loader",
    title="FLUX LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )

    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
            raise Exception(f'LoRA "{lora_key}" already applied to transformer.')

        transformer = self.transformer.model_copy(deep=True)
        transformer.loras.append(
            LoRAField(
                lora=self.lora,
                weight=self.weight,
            )
        )

        return FluxLoRALoaderOutput(transformer=transformer)
@@ -129,18 +129,7 @@ class MergeMetadataInvocation(BaseInvocation):
 
 
 GENERATION_MODES = Literal[
-    "txt2img",
-    "img2img",
-    "inpaint",
-    "outpaint",
-    "sdxl_txt2img",
-    "sdxl_img2img",
-    "sdxl_inpaint",
-    "sdxl_outpaint",
-    "flux_txt2img",
-    "flux_img2img",
-    "flux_inpaint",
-    "flux_outpaint",
+    "txt2img", "img2img", "inpaint", "outpaint", "sdxl_txt2img", "sdxl_img2img", "sdxl_inpaint", "sdxl_outpaint"
 ]
 
 
@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
 
 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
 
 
 class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         assert isinstance(transformer_config, CheckpointConfigBase)
 
         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer),
+            transformer=TransformerField(transformer=transformer, loras=[]),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
@@ -23,7 +23,7 @@ from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         # Load the UNet model.
         unet_info = context.models.load(self.unet.unet)
 
-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
+        with (
+            ExitStack() as exit_stack,
+            unet_info as unet,
+            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+        ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
@@ -41,7 +41,6 @@ def denoise(
 
         if inpaint_extension is not None:
             img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
-            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
 
         step_callback(
             PipelineIntermediateState(
@@ -28,19 +28,8 @@ class InpaintExtension:
 
         This function should be called after each denoising step.
         """
-        timestep_cutoff = 0.5
-        if timestep > timestep_cutoff:
-            # Early in the denoising process, use the smaller mask.
-            # I.e. treat gradient values as 0.0.
-            mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0)
-        else:
-            # After the cut-off, use the larger mask.
-            # I.e. treat gradient values as 1.0.
-            mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0)
-        # mask = (self._inpaint_mask > (0.0 + 1e-5)).float()
-
         # Noise the init latents for the current timestep.
         noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
 
         # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
-        return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
+        return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
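To make the merge arithmetic concrete, a small worked example with purely illustrative numbers (one latent element, mirroring merge_intermediate_latents_with_init_latents):

```python
# Illustrative scalars only; the real code operates on tensors.
timestep = 0.8
init_latent = 0.5
noise = -1.2
intermediate = 2.0
inpaint_mask = 0.3  # per-element gradient mask value in [0, 1]

noised_init_latent = noise * timestep + (1.0 - timestep) * init_latent  # -0.86
merged = intermediate * inpaint_mask + noised_init_latent * (1.0 - inpaint_mask)
print(merged)  # 0.3 * 2.0 + 0.7 * (-0.86) = -0.002
```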
@@ -31,24 +31,10 @@ def get_noise(
 
 
 def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
-    """Shift the timestep schedule.
-
-    This is a simmilar idea to the beta schedule introduced in https://arxiv.org/abs/2305.08898. But, the function for
-    remapping timesteps in [0, 1] is different.
-
-    Properties of this function:
-    - Recommended sigma values: 1.0 <= sigma <= 3.0.
-    - When sigma=1.0 and mu=0.0, the conversion is the identity function.
-    - Increasing sigma results in an increasingly steep logistic function.
-    - Adjusting mu shifts the midpoint of the logistic function.
-    """
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
 
 def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
-    """Return a linear function that maps x to y given the coordsinates of two points on the line (x1, y1) and
-    (x2, y2).
-    """
     m = (y2 - y1) / (x2 - x1)
     b = y1 - m * x1
     return lambda x: m * x + b
@@ -66,17 +52,9 @@ def get_schedule(
 
     # shifting the schedule to favor high timesteps for higher signal images
     if shift:
-        # Select mu based on linear interpolation between two points.
-        # Point 1: (image_seq_len=256, mu=0.5)
-        # Point 2: (image_seq_len=4096, mu=1.15)
-        # This has the effect of increasing mu as the image size increases. image_seq_len=4096 corresponds to an image
-        # size of 1024x1024.
+        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
-
-        # Shift the timesteps based on mu. Higher values of mu mean that there will be more timesteps early in the
-        # denoising process (i.e. many small steps in the timestep range 1.0-0.9, and fewer large steps in the timestep
-        # range 0.1-0.0).
-        timesteps = time_shift(mu=mu, sigma=1.0, t=timesteps)
+        timesteps = time_shift(mu, 1.0, timesteps)
 
     return timesteps.tolist()
 
@@ -116,10 +94,6 @@ def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoi
 
     clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
 
-    # clipped_timesteps = torch.tensor(timesteps)
-    # clipped_timesteps = clipped_timesteps * (t_start_val - t_end_val) + t_end_val
-    # clipped_timesteps = clipped_timesteps.tolist()
-
    return clipped_timesteps
 
 
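A short, self-contained sketch of the schedule shift described above, using the default interpolation points (image_seq_len=256 -> mu=0.5, image_seq_len=4096 -> mu=1.15); the numbers are illustrative, and the real functions operate on torch tensors:

```python
import math

def get_lin_function(x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
    # Linear interpolation through the two points (x1, y1) and (x2, y2).
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift(mu, sigma, t):
    # Logistic remap of a timestep t in (0, 1]; identity when sigma=1.0 and mu=0.0.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

image_seq_len = 4096  # roughly a 1024x1024 image
mu = get_lin_function()(image_seq_len)  # -> 1.15
print(time_shift(mu, 1.0, 0.5))  # ~0.76: mid-schedule timesteps are pushed toward 1.0
```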
@@ -214,14 +214,8 @@ class LineartEdgeDetector:
         line = line.cpu().numpy()
         line = (line * 255.0).clip(0, 255).astype(np.uint8)
 
-        detected_map = 255 - line
+        detected_map = line
 
-        # The lineart model often outputs a lot of almost-black noise. SD1.5 ControlNets seem to be OK with this, but
-        # SDXL ControlNets are not - they need a cleaner map. 12 was experimentally determined to be a good threshold,
-        # eliminating all the noise while keeping the actual edges. Other approaches to thresholding may be better,
-        # for example stretching the contrast or removing noise.
-        detected_map[detected_map < 12] = 0
+        detected_map = 255 - detected_map
 
-        output = np_to_pil(detected_map)
-
-        return output
+        return np_to_pil(detected_map)
@@ -260,14 +260,8 @@ class LineartAnimeEdgeDetector:
         line = cv2.resize(line, (width, height), interpolation=cv2.INTER_CUBIC)
         line = line.clip(0, 255).astype(np.uint8)
 
-        detected_map = 255 - line
-
-        # The lineart model often outputs a lot of almost-black noise. SD1.5 ControlNets seem to be OK with this, but
-        # SDXL ControlNets are not - they need a cleaner map. 12 was experimentally determined to be a good threshold,
-        # eliminating all the noise while keeping the actual edges. Other approaches to thresholding may be better,
-        # for example stretching the contrast or removing noise.
-        detected_map[detected_map < 12] = 0
-
+        detected_map = line
+        detected_map = 255 - detected_map
         output = np_to_pil(detected_map)
 
         return output
0  invokeai/backend/lora/conversions/__init__.py  Normal file

@@ -0,0 +1,211 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
    all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())

    # Next, check that this is likely a FLUX model by spot-checking a few keys.
    expected_keys = [
        "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
        "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
    ]
    all_expected_keys_present = all(k in state_dict for k in expected_keys)

    return all_keys_in_peft_format and all_expected_keys_present


def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw:  # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
    """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.

    This function is based on:
    https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
    """
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)

    # Remove the "transformer." prefix from all keys.
    grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}

    # Constants for FLUX.1
    num_double_layers = 19
    num_single_layers = 38
    # inner_dim = 3072
    # mlp_ratio = 4.0

    layers: dict[str, AnyLoRALayer] = {}

    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
        if src_key in grouped_state_dict:
            src_layer_dict = grouped_state_dict.pop(src_key)
            layers[dst_key] = LoRALayer(
                dst_key,
                {
                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                    "alpha": torch.tensor(alpha),
                },
            )
            assert len(src_layer_dict) == 0

    def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
        """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
        stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
        """
        # We expect that either all src keys are present or none of them are. Verify this.
        keys_present = [key in grouped_state_dict for key in src_keys]
        assert all(keys_present) or not any(keys_present)

        # If none of the keys are present, return early.
        if not any(keys_present):
            return

        src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
        sub_layers: list[LoRALayerBase] = []
        for src_layer_dict in src_layer_dicts:
            sub_layers.append(
                LoRALayer(
                    layer_key="",
                    values={
                        "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                        "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                        "alpha": torch.tensor(alpha),
                    },
                )
            )
            assert len(src_layer_dict) == 0
        layers[dst_qkv_key] = ConcatenatedLoRALayer(layer_key=dst_qkv_key, lora_layers=sub_layers, concat_axis=0)

    # time_text_embed.timestep_embedder -> time_in.
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")

    # time_text_embed.text_embedder -> vector_in.
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

    # time_text_embed.guidance_embedder -> guidance_in.
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")

    # context_embedder -> txt_in.
    add_lora_layer_if_present("context_embedder", "txt_in")

    # x_embedder -> img_in.
    add_lora_layer_if_present("x_embedder", "img_in")

    # Double transformer blocks.
    for i in range(num_double_layers):
        # norms.
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")

        # Q, K, V
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.to_q",
                f"transformer_blocks.{i}.attn.to_k",
                f"transformer_blocks.{i}.attn.to_v",
            ],
            f"double_blocks.{i}.img_attn.qkv",
        )
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.add_q_proj",
                f"transformer_blocks.{i}.attn.add_k_proj",
                f"transformer_blocks.{i}.attn.add_v_proj",
            ],
            f"double_blocks.{i}.txt_attn.qkv",
        )

        # ff img_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.0.proj",
            f"double_blocks.{i}.img_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.2",
            f"double_blocks.{i}.img_mlp.2",
        )

        # ff txt_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.0.proj",
            f"double_blocks.{i}.txt_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.2",
            f"double_blocks.{i}.txt_mlp.2",
        )

        # output projections.
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_out.0",
            f"double_blocks.{i}.img_attn.proj",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_add_out",
            f"double_blocks.{i}.txt_attn.proj",
        )

    # Single transformer blocks.
    for i in range(num_single_layers):
        # norms
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.norm.linear",
            f"single_blocks.{i}.modulation.lin",
        )

        # Q, K, V, mlp
        add_qkv_lora_layer_if_present(
            [
                f"single_transformer_blocks.{i}.attn.to_q",
                f"single_transformer_blocks.{i}.attn.to_k",
                f"single_transformer_blocks.{i}.attn.to_v",
                f"single_transformer_blocks.{i}.proj_mlp",
            ],
            f"single_blocks.{i}.linear1",
        )

        # Output projections.
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.proj_out",
            f"single_blocks.{i}.linear2",
        )

    # Final layer.
    add_lora_layer_if_present("proj_out", "final_layer.linear")

    # Assert that all keys were processed.
    assert len(grouped_state_dict) == 0

    return LoRAModelRaw(layers=layers)


def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Groups the keys in the state dict by layer."""
    layer_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key in state_dict:
        # Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
        parts = key.rsplit(".", maxsplit=2)
        layer_name = parts[0]
        key_name = ".".join(parts[1:])
        if layer_name not in layer_dict:
            layer_dict[layer_name] = {}
        layer_dict[layer_name][key_name] = state_dict[key]
    return layer_dict
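A hypothetical usage sketch of the two functions above (the file path and alpha value are illustrative, not part of the diff):

```python
from safetensors.torch import load_file

state_dict = load_file("/path/to/flux_diffusers_lora.safetensors")  # illustrative path
if is_state_dict_likely_in_flux_diffusers_format(state_dict):
    lora_model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
    # Layer keys now use the BFL naming scheme, e.g. "double_blocks.0.img_attn.qkv".
```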
@@ -0,0 +1,81 @@
import re
from typing import Any, Dict, TypeVar

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
#   lora_unet_double_blocks_0_img_attn_proj.alpha
#   lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
#   lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)


def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())


def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        layer_name, param_name = key.split(".", 1)
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Convert the state dict to the InvokeAI format.
    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        layer = any_lora_layer_from_state_dict(layer_key, layer_state_dict)
        layers[layer_key] = layer

    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)


T = TypeVar("T")


def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.

    Example key conversions:
    "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
    "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
    """

    def replace_func(match: re.Match[str]) -> str:
        s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        if match.group(4):
            s += f".{match.group(4)}"
        return s

    converted_dict: dict[str, T] = {}
    for k, v in state_dict.items():
        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
        if match:
            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
            converted_dict[new_key] = v
        else:
            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

    return converted_dict
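A minimal illustration of the regex-based key rewrite (the example key is illustrative):

```python
import re

FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)

def replace_func(match: re.Match) -> str:
    # Re-join the captured block name, block index, sub-module and remainder with dots.
    s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
    if match.group(4):
        s += f".{match.group(4)}"
    return s

print(re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, "lora_unet_double_blocks_0_img_attn_proj"))
# -> double_blocks.0.img_attn.proj
```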
@@ -0,0 +1,30 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, values in grouped_state_dict.items():
        layer = any_lora_layer_from_state_dict(layer_key, values)
        layers[layer_key] = layer

    return LoRAModelRaw(layers=layers)


def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
    state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        stem, leaf = key.split(".", 1)
        if stem not in state_dict_groupped:
            state_dict_groupped[stem] = {}
        state_dict_groupped[stem][leaf] = value

    return state_dict_groupped
154  invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py  Normal file
@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
    """Convert the keys of an SDXL LoRA state_dict to diffusers format.

    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
    diffusers format, then this function will have no effect.

    This function is adapted from:
    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

    Args:
        state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

    Raises:
        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

    Returns:
        Dict[str, Tensor]: The diffusers-format state_dict.
    """
    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
    not_converted_count = 0  # The number of keys that were not converted.

    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
    # `input_blocks_4_1_proj_in`.
    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
    stability_unet_keys.sort()

    new_state_dict: dict[str, T] = {}
    for full_key, value in state_dict.items():
        if full_key.startswith("lora_unet_"):
            search_key = full_key.replace("lora_unet_", "")
            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
            position = bisect.bisect_right(stability_unet_keys, search_key)
            map_key = stability_unet_keys[position - 1]
            # Now, check if the map_key *actually* matches the search_key.
            if search_key.startswith(map_key):
                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                new_state_dict[new_key] = value
                converted_count += 1
            else:
                new_state_dict[full_key] = value
                not_converted_count += 1
        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
            new_state_dict[full_key] = value
            continue
        else:
            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

    if converted_count > 0 and not_converted_count > 0:
        raise ValueError(
            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
            f" not_converted={not_converted_count}"
        )

    return new_state_dict


# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
    unet_conversion_map_layer: list[tuple[str, str]] = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map: list[tuple[str, str]] = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}
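The bisect-based prefix lookup above works because the candidate prefixes are sorted; a small self-contained sketch with toy keys (not real SDXL keys):

```python
import bisect

# With a sorted prefix list, bisect_right places the query just after the
# longest candidate that could be its prefix; one startswith() check confirms it.
prefixes = sorted(["input_blocks_4_1", "input_blocks_5_0", "middle_block_1"])
query = "input_blocks_4_1_proj_in"

position = bisect.bisect_right(prefixes, query)
candidate = prefixes[position - 1]
assert query.startswith(candidate)  # "input_blocks_4_1" is the matching prefix
```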
@@ -1,5 +1,6 @@
 from typing import Union
 
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.full_layer import FullLayer
 from invokeai.backend.lora.layers.ia3_layer import IA3Layer
 from invokeai.backend.lora.layers.loha_layer import LoHALayer
@@ -7,4 +8,4 @@ from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer
 
-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
46  invokeai/backend/lora/layers/concatenated_lora_layer.py  Normal file
@@ -0,0 +1,46 @@
from typing import List, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class ConcatenatedLoRALayer(LoRALayerBase):
    """A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.

    This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
    Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
    stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
    """

    def __init__(self, layer_key: str, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
        # Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
        super().__init__(layer_key, values={})

        self._lora_layers = lora_layers
        self._concat_axis = concat_axis

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original weight tensor here.
        layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        return torch.cat(layer_weights, dim=self._concat_axis)

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original bias tensor here.
        layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
        if any(layer_bias_is_none):
            assert all(layer_bias_is_none)
            return None

        # Ignore the type error, because we have just verified that all layer biases are non-None.
        return torch.cat(layer_biases, dim=self._concat_axis)

    def calc_size(self) -> int:
        return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for lora_layer in self._lora_layers:
            lora_layer.to(device=device, dtype=dtype)
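A rough sketch of how the Q/K/V special case comes together during conversion (the rank and dimension values below are illustrative, not the real FLUX sizes, and the zero tensors are stand-ins for actual LoRA deltas):

```python
import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer

rank, dim = 4, 64  # illustrative sizes


def make_sub_layer() -> LoRALayer:
    # An ordinary rank-`rank` LoRA delta targeting a dim x dim projection.
    return LoRALayer(
        "",
        {
            "lora_down.weight": torch.zeros(rank, dim),
            "lora_up.weight": torch.zeros(dim, rank),
            "alpha": torch.tensor(4.0),
        },
    )


# Q, K and V sub-layers stacked along dim 0, matching a fused qkv weight of shape (3 * dim, dim).
qkv = ConcatenatedLoRALayer(
    layer_key="double_blocks.0.img_attn.qkv",
    lora_layers=[make_sub_layer() for _ in range(3)],
    concat_axis=0,
)
print(qkv.get_weight(None).shape)  # expected: torch.Size([192, 64])
```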
33  invokeai/backend/lora/layers/utils.py  Normal file
@@ -0,0 +1,33 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer


def any_lora_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
    # Detect layers according to LyCORIS detection logic(`weight_list_det`)
    # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

    if "lora_up.weight" in state_dict:
        # LoRA a.k.a LoCon
        return LoRALayer(layer_key, state_dict)
    elif "hada_w1_a" in state_dict:
        return LoHALayer(layer_key, state_dict)
    elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
        return LoKRLayer(layer_key, state_dict)
    elif "diff" in state_dict:
        # Full a.k.a Diff
        return FullLayer(layer_key, state_dict)
    elif "on_input" in state_dict:
        return IA3Layer(layer_key, state_dict)
    elif "w_norm" in state_dict:
        return NormLayer(layer_key, state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
@@ -1,43 +1,17 @@
 # Copyright (c) 2024 The InvokeAI Development team
 """LoRA model support."""
 
-import bisect
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional
 
 import torch
-from safetensors.torch import load_file
-from typing_extensions import Self
 
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.full_layer import FullLayer
-from invokeai.backend.lora.layers.ia3_layer import IA3Layer
-from invokeai.backend.lora.layers.loha_layer import LoHALayer
-from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.norm_layer import NormLayer
-from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
 
 
 class LoRAModelRaw(RawModel):  # (torch.nn.Module):
-    _name: str
-    layers: Dict[str, AnyLoRALayer]
-
-    def __init__(
-        self,
-        name: str,
-        layers: Dict[str, AnyLoRALayer],
-    ):
-        self._name = name
+    def __init__(self, layers: Dict[str, AnyLoRALayer]):
         self.layers = layers
 
-    @property
-    def name(self) -> str:
-        return self._name
-
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
@@ -46,234 +20,3 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         for _, layer in self.layers.items():
             model_size += layer.calc_size()
         return model_size
-
-    @classmethod
-    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """Convert the keys of an SDXL LoRA state_dict to diffusers format.
-
-        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
-        diffusers format, then this function will have no effect.
-
-        This function is adapted from:
-        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
-
-        Args:
-            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
-
-        Raises:
-            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
-
-        Returns:
-            Dict[str, Tensor]: The diffusers-format state_dict.
-        """
-        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
-        not_converted_count = 0  # The number of keys that were not converted.
-
-        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
-        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
-        # `input_blocks_4_1_proj_in`.
-        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
-        stability_unet_keys.sort()
-
-        new_state_dict = {}
-        for full_key, value in state_dict.items():
-            if full_key.startswith("lora_unet_"):
-                search_key = full_key.replace("lora_unet_", "")
-                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
-                position = bisect.bisect_right(stability_unet_keys, search_key)
-                map_key = stability_unet_keys[position - 1]
-                # Now, check if the map_key *actually* matches the search_key.
-                if search_key.startswith(map_key):
-                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
-                    new_state_dict[new_key] = value
-                    converted_count += 1
-                else:
-                    new_state_dict[full_key] = value
-                    not_converted_count += 1
-            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
-                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
-                new_state_dict[full_key] = value
-                continue
-            else:
-                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
-
-        if converted_count > 0 and not_converted_count > 0:
-            raise ValueError(
-                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
-                f" not_converted={not_converted_count}"
-            )
-
-        return new_state_dict
-
-    @classmethod
-    def from_checkpoint(
-        cls,
-        file_path: Union[str, Path],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        base_model: Optional[BaseModelType] = None,
-    ) -> Self:
-        device = device or torch.device("cpu")
-        dtype = dtype or torch.float32
-
-        if isinstance(file_path, str):
-            file_path = Path(file_path)
-
-        model = cls(
-            name=file_path.stem,
-            layers={},
-        )
-
-        if file_path.suffix == ".safetensors":
-            sd = load_file(file_path.absolute().as_posix(), device="cpu")
-        else:
-            sd = torch.load(file_path, map_location="cpu")
-
-        state_dict = cls._group_state(sd)
-
-        if base_model == BaseModelType.StableDiffusionXL:
-            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
-
-        for layer_key, values in state_dict.items():
-            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
-            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
-
-            # lora and locon
-            if "lora_up.weight" in values:
-                layer: AnyLoRALayer = LoRALayer(layer_key, values)
-
-            # loha
-            elif "hada_w1_a" in values:
-                layer = LoHALayer(layer_key, values)
-
-            # lokr
-            elif "lokr_w1" in values or "lokr_w1_a" in values:
-                layer = LoKRLayer(layer_key, values)
-
-            # diff
-            elif "diff" in values:
-                layer = FullLayer(layer_key, values)
-
-            # ia3
-            elif "on_input" in values:
-                layer = IA3Layer(layer_key, values)
-
-            # norms
-            elif "w_norm" in values:
-                layer = NormLayer(layer_key, values)
-
-            else:
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
-                raise Exception("Unknown lora format!")
-
-            # lower memory consumption by removing already parsed layer values
-            state_dict[layer_key].clear()
-
-            layer.to(device=device, dtype=dtype)
-            model.layers[layer_key] = layer
-
-        return model
-
-    @staticmethod
-    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
-
-        for key, value in state_dict.items():
-            stem, leaf = key.split(".", 1)
-            if stem not in state_dict_groupped:
-                state_dict_groupped[stem] = {}
-            state_dict_groupped[stem][leaf] = value
-
-        return state_dict_groupped
-
-
-# code from
-# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
-def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
-    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
-    unet_conversion_map_layer = []
-
-    for i in range(3):  # num_blocks is 3 in sdxl
-        # loop over downblocks/upblocks
-        for j in range(2):
-            # loop over resnets/attentions for downblocks
-            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
-            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
-            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
-            if i < 3:
-                # no attention layers in down_blocks.3
-                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
-                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
-                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
-        for j in range(3):
-            # loop over resnets/attentions for upblocks
-            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
-            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
-            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
-            # if i > 0: commentout for sdxl
-            # no attention layers in up_blocks.0
-            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
-            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
-            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
-        if i < 3:
-            # no downsample in down_blocks.3
-            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
-            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
-            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
-            # no upsample in up_blocks.3
-            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
-            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
-            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
-    hf_mid_atn_prefix = "mid_block.attentions.0."
-    sd_mid_atn_prefix = "middle_block.1."
-    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
-    for j in range(2):
-        hf_mid_res_prefix = f"mid_block.resnets.{j}."
-        sd_mid_res_prefix = f"middle_block.{2*j}."
-        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
-    unet_conversion_map_resnet = [
-        # (stable-diffusion, HF Diffusers)
-        ("in_layers.0.", "norm1."),
-        ("in_layers.2.", "conv1."),
-        ("out_layers.0.", "norm2."),
-        ("out_layers.3.", "conv2."),
-        ("emb_layers.1.", "time_emb_proj."),
-        ("skip_connection.", "conv_shortcut."),
-    ]
-
-    unet_conversion_map = []
-    for sd, hf in unet_conversion_map_layer:
-        if "resnets" in hf:
-            for sd_res, hf_res in unet_conversion_map_resnet:
-                unet_conversion_map.append((sd + sd_res, hf + hf_res))
-        else:
-            unet_conversion_map.append((sd, hf))
-
-    for j in range(2):
-        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
-        sd_time_embed_prefix = f"time_embed.{j*2}."
-        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
-
-    for j in range(2):
-        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
-        sd_label_embed_prefix = f"label_emb.0.{j*2}."
-        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
-
-    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
-    unet_conversion_map.append(("out.0.", "conv_norm_out."))
-    unet_conversion_map.append(("out.2.", "conv_out."))
-
-    return unet_conversion_map
-
-
-SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
-    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
-}
148
invokeai/backend/lora/lora_patcher.py
Normal file
148
invokeai/backend/lora/lora_patcher.py
Normal file
@@ -0,0 +1,148 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoRAPatcher:
    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_lora_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[LoRAModelRaw, float]],
        prefix: str,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Apply one or more LoRA patches to a model within a context manager.

        :param model: The model to patch.
        :param patches: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
            that the LoRA patches do not need to be loaded into memory all at once.
        :param prefix: The keys in the patches will be filtered to only include weights with this prefix.
        :param cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
        """
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for patch, patch_weight in patches:
                LoRAPatcher.apply_lora_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                )
                del patch

            yield
        finally:
            for param_key, weight in original_weights.get_changed_weights():
                model.get_parameter(param_key).copy_(weight)

    @staticmethod
    @torch.no_grad()
    def apply_lora_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: LoRAModelRaw,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """
        Apply a single LoRA patch to a model.
        :param model: The model to patch.
        :param patch: LoRA model to patch in.
        :param patch_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
        :param original_weights: Storage for the original values of the weights that the LoRA overwrites, used for
            unpatching.
        """

        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LoRAPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # Save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= patch_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.

        :param model: The model to search.
        :param layer_key: The layer key to search for.
        :param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
            with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
            searching, but some legacy code still uses flattened keys.
        :return: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key

        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return module_key, module
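Note: a minimal usage sketch of the new LoRAPatcher context manager, assuming a loaded UNet and two LoRAModelRaw instances (unet, lora_a, lora_b and run_denoise are hypothetical names, not part of this diff):

    from invokeai.backend.lora.lora_patcher import LoRAPatcher

    def _lora_loader():
        # Yielding (LoRAModelRaw, weight) tuples keeps only one LoRA in memory at a time.
        for lora_model, weight in [(lora_a, 0.7), (lora_b, 0.4)]:
            yield lora_model, weight

    with LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"):
        run_denoise(unet)  # patched weights are in effect here
    # On exit, every changed parameter is copied back from OriginalWeightsStorage.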
@@ -5,8 +5,18 @@ from logging import Logger
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import load_file

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
@@ -45,14 +55,39 @@ class LoRALoader(ModelLoader):
            raise ValueError("There are no submodels in a LoRA model.")
        model_path = Path(config.path)
        assert self._model_base is not None
        model = LoRAModelRaw.from_checkpoint(
            file_path=model_path,
            dtype=self._torch_dtype,
            base_model=self._model_base,
        )

        # Load the state dict from the model file.
        if model_path.suffix == ".safetensors":
            state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
        else:
            state_dict = torch.load(model_path, map_location="cpu")

        # Apply state_dict key conversions, if necessary.
        if self._model_base == BaseModelType.StableDiffusionXL:
            state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
            model = lora_model_from_sd_state_dict(state_dict=state_dict)
        elif self._model_base == BaseModelType.Flux:
            if config.format == ModelFormat.Diffusers:
                # HACK(ryand): We assume alpha=8 for diffusers PEFT format models. These models are typically
                # distributed as a single file without the associated metadata containing the alpha value. We chose
                # alpha=8, because this is the default value in the PEFT library:
                # https://github.com/huggingface/peft/blob/7868d0372b86a6b9ac5f365b8f0eef2f2f5dedce/src/peft/tuners/lora/config.py#L169
                # Other reasonable defaults for alpha could be 1.0 or the rank of the LoRA. If our assumption is wrong,
                # the user will need to adjust the weight accordingly to account for the difference.
                model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=8)
            elif config.format == ModelFormat.LyCORIS:
                model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
            else:
                raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
        elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
            # Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
            model = lora_model_from_sd_state_dict(state_dict=state_dict)
        else:
            raise ValueError(f"Unsupported LoRA base model: {self._model_base}")

        model.to(dtype=self._torch_dtype)
        return model

    # override
    def _get_model_path(self, config: AnyModelConfig) -> Path:
        # cheating a little - we remember this variable for use in the subsequent call to _load_model()
        self._model_base = config.base
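Note on the alpha=8 assumption: LoRAPatcher scales each layer by patch_weight * (alpha / rank), so a wrong assumed alpha introduces a constant factor that the user can cancel by adjusting the LoRA weight. A hedged arithmetic sketch with hypothetical numbers:

    # Hypothetical example: a FLUX diffusers LoRA with rank 16 that was actually trained with alpha == rank.
    rank = 16
    assumed_scale = 8 / rank                          # 0.5 under the alpha=8 assumption
    true_scale = rank / rank                          # 1.0 if alpha really equalled the rank
    compensating_weight = true_scale / assumed_scale  # the user would set the LoRA weight to 2.0 to compensate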
@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
    is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
@@ -244,7 +248,9 @@ class ModelProbe(object):
            return ModelType.VAE
        elif key.startswith(("lora_te_", "lora_unet_")):
            return ModelType.LoRA
        elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
        # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
        # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
        elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
            return ModelType.LoRA
        elif key.startswith(("controlnet", "control_model", "input_blocks")):
            return ModelType.ControlNet
@@ -554,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
    """Class for LoRA checkpoints."""

    def get_format(self) -> ModelFormat:
        return ModelFormat("lycoris")
        if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
            # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
            # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
            # file, but the weight keys are in the diffusers format.
            return ModelFormat.Diffusers
        return ModelFormat.LyCORIS

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)
        if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
            self.checkpoint
        ):
            return BaseModelType.Flux

        # If we've gotten here, we assume that the model is a Stable Diffusion model.
        token_vector_length = lora_token_vector_length(self.checkpoint)
        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
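Note: the probe distinguishes LoRA flavours purely from state-dict keys: kohya/LyCORIS checkpoints use flattened keys prefixed with lora_te_/lora_unet_, while diffusers/PEFT checkpoints use dotted module paths ending in lora_A.weight/lora_B.weight. A rough sketch of that kind of heuristic (an illustration only, not the actual is_state_dict_likely_in_flux_* implementations):

    def looks_like_kohya_lora(state_dict: dict) -> bool:
        # Kohya/LyCORIS-style checkpoints use flattened, prefixed keys.
        return any(k.startswith(("lora_te_", "lora_unet_")) for k in state_dict)

    def looks_like_peft_lora(state_dict: dict) -> bool:
        # Diffusers/PEFT-style checkpoints use dotted module paths with lora_A/lora_B leaves.
        return any(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict)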
@@ -5,32 +5,18 @@ from __future__ import annotations

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

"""
loras = [
    (lora_model1, 0.7),
    (lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
    # unet with applied loras
# unmodified unet

"""


class ModelPatcher:
@@ -54,95 +40,6 @@ class ModelPatcher:
        finally:
            unet.set_attn_processor(unet_orig_processors)

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        with cls.apply_lora(
            unet,
            loras=loras,
            prefix="lora_unet_",
            cached_weights=cached_weights,
        ):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: AnyModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        prefix: str,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Generator[None, None, None]:
        """
        Apply one or more LoRAs to a model.

        :param model: The model to patch.
        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
        """
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for lora_model, lora_weight in loras:
                LoRAExt.patch_model(
                    model=model,
                    prefix=prefix,
                    lora=lora_model,
                    lora_weight=lora_weight,
                    original_weights=original_weights,
                )
                del lora_model

            yield

        finally:
            with torch.no_grad():
                for param_key, weight in original_weights.get_changed_weights():
                    model.get_parameter(param_key).copy_(weight)

    @classmethod
    @contextmanager
    def apply_ti(
@@ -282,26 +179,6 @@ class ModelPatcher:


class ONNXModelPatcher:
    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
    ) -> None:
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ) -> None:
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
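Note: OriginalWeightsStorage itself is not part of this diff. Both the removed ModelPatcher.apply_lora above and the new LoRAPatcher rely on the same contract: save(key, tensor) records the first-seen original value (preferring the cached CPU copy when one was supplied), and get_changed_weights() yields everything that must be copied back on exit. A minimal sketch of that contract, offered as an assumption rather than the real class:

    import torch

    class OriginalWeightsSketch:
        def __init__(self, cached_weights: dict[str, torch.Tensor] | None = None):
            self._cached = dict(cached_weights or {})  # read-only CPU copies, if provided
            self._changed: dict[str, torch.Tensor] = {}

        def save(self, key: str, tensor: torch.Tensor) -> None:
            if key in self._changed:
                return  # keep the first (original) value
            # Prefer the cached CPU copy to avoid a GPU-to-CPU transfer.
            self._changed[key] = self._cached.get(key, tensor.detach().to("cpu", copy=True))

        def get_changed_weights(self):
            return self._changed.items()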
@@ -1,14 +1,13 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelIdentifierField
@@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        lora_model = self._node_context.models.load(self._model_id).model
        self.patch_model(
        assert isinstance(lora_model, LoRAModelRaw)
        LoRAPatcher.apply_lora_patch(
            model=unet,
            prefix="lora_unet_",
            lora=lora_model,
            lora_weight=self._weight,
            patch=lora_model,
            patch_weight=self._weight,
            original_weights=original_weights,
        )
        del lora_model

        yield

    @classmethod
    @torch.no_grad()
    def patch_model(
        cls,
        model: torch.nn.Module,
        prefix: str,
        lora: LoRAModelRaw,
        lora_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """
        Apply one or more LoRAs to a model.
        :param model: The model to patch.
        :param lora: LoRA model to patch in.
        :param lora_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
        :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
        """

        if lora_weight == 0:
            return

        # assert lora.device.type == "cpu"
        for layer_key, layer in lora.layers.items():
            if not layer_key.startswith(prefix):
                continue

            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
            # should be improved in the following ways:
            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
            #    LoRA model is applied.
            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
            #    weights to have valid keys.
            assert isinstance(model, torch.nn.Module)
            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    # TODO: debug on lycoris
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= lora_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)
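Note: a short worked example of the flattened-key resolution performed by LoRAPatcher._get_submodule (and the legacy _resolve_lora_key above), using a toy module tree; the key below is illustrative, not taken from a real checkpoint:

    import torch

    from invokeai.backend.lora.lora_patcher import LoRAPatcher

    model = torch.nn.Module()
    model.down_blocks = torch.nn.ModuleList([torch.nn.Module()])
    model.down_blocks[0].proj_in = torch.nn.Linear(4, 4)

    # For "down_blocks_0_proj_in" (prefix already stripped), underscore-separated parts are joined
    # greedily until get_submodule() succeeds, resolving "down_blocks", then "0", then "proj_in".
    module_key, module = LoRAPatcher._get_submodule(model, "down_blocks_0_proj_in", layer_key_is_flattened=True)
    assert module_key == "down_blocks.0.proj_in"
    assert isinstance(module, torch.nn.Linear)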
@@ -173,11 +173,102 @@
|
||||
"comparing": "Comparing",
|
||||
"comparingDesc": "Comparing two images",
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled",
|
||||
"placeholderSelectAModel": "Select a model",
|
||||
"reset": "Reset",
|
||||
"disabled": "Disabled"
|
||||
},
|
||||
"controlnet": {
|
||||
"controlAdapter_one": "Control Adapter",
|
||||
"controlAdapter_other": "Control Adapters",
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"addControlNet": "Add $t(common.controlNet)",
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||
"addT2IAdapter": "Add $t(common.t2iAdapter)",
|
||||
"amult": "a_mult",
|
||||
"autoConfigure": "Auto configure processor",
|
||||
"balanced": "Balanced",
|
||||
"base": "Base",
|
||||
"beginEndStepPercent": "Begin / End Step Percentage",
|
||||
"beginEndStepPercentShort": "Begin/End %",
|
||||
"bgth": "bg_th",
|
||||
"canny": "Canny",
|
||||
"cannyDescription": "Canny edge detection",
|
||||
"colorMap": "Color",
|
||||
"colorMapDescription": "Generates a color map from the image",
|
||||
"coarse": "Coarse",
|
||||
"contentShuffle": "Content Shuffle",
|
||||
"contentShuffleDescription": "Shuffles the content in an image",
|
||||
"control": "Control",
|
||||
"controlMode": "Control Mode",
|
||||
"crop": "Crop",
|
||||
"delete": "Delete",
|
||||
"depthAnything": "Depth Anything",
|
||||
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
|
||||
"depthAnythingSmallV2": "Small V2",
|
||||
"depthMidas": "Depth (Midas)",
|
||||
"depthMidasDescription": "Depth map generation using Midas",
|
||||
"depthZoe": "Depth (Zoe)",
|
||||
"depthZoeDescription": "Depth map generation using Zoe",
|
||||
"detectResolution": "Detect Resolution",
|
||||
"duplicate": "Duplicate",
|
||||
"f": "F",
|
||||
"fill": "Fill",
|
||||
"h": "H",
|
||||
"face": "Face",
|
||||
"body": "Body",
|
||||
"hands": "Hands",
|
||||
"hed": "HED",
|
||||
"hedDescription": "Holistically-Nested Edge Detection",
|
||||
"hideAdvanced": "Hide Advanced",
|
||||
"highThreshold": "High Threshold",
|
||||
"imageResolution": "Image Resolution",
|
||||
"colorMapTileSize": "Tile Size",
|
||||
"importImageFromCanvas": "Import Image From Canvas",
|
||||
"importMaskFromCanvas": "Import Mask From Canvas",
|
||||
"large": "Large",
|
||||
"lineart": "Lineart",
|
||||
"lineartAnime": "Lineart Anime",
|
||||
"lineartAnimeDescription": "Anime-style lineart processing",
|
||||
"lineartDescription": "Converts image to lineart",
|
||||
"lowThreshold": "Low Threshold",
|
||||
"maxFaces": "Max Faces",
|
||||
"mediapipeFace": "Mediapipe Face",
|
||||
"mediapipeFaceDescription": "Face detection using Mediapipe",
|
||||
"megaControl": "Mega Control",
|
||||
"minConfidence": "Min Confidence",
|
||||
"mlsd": "M-LSD",
|
||||
"mlsdDescription": "Minimalist Line Segment Detector",
|
||||
"modelSize": "Model Size",
|
||||
"none": "None",
|
||||
"new": "New"
|
||||
"noneDescription": "No processing applied",
|
||||
"normalBae": "Normal BAE",
|
||||
"normalBaeDescription": "Normal BAE processing",
|
||||
"dwOpenpose": "DW Openpose",
|
||||
"dwOpenposeDescription": "Human pose estimation using DW Openpose",
|
||||
"pidi": "PIDI",
|
||||
"pidiDescription": "PIDI image processing",
|
||||
"processor": "Processor",
|
||||
"prompt": "Prompt",
|
||||
"resetControlImage": "Reset Control Image",
|
||||
"resize": "Resize",
|
||||
"resizeSimple": "Resize (Simple)",
|
||||
"resizeMode": "Resize Mode",
|
||||
"ipAdapterMethod": "Method",
|
||||
"full": "Full",
|
||||
"style": "Style Only",
|
||||
"composition": "Composition Only",
|
||||
"safe": "Safe",
|
||||
"saveControlImage": "Save Control Image",
|
||||
"scribble": "Scribble",
|
||||
"selectModel": "Select a model",
|
||||
"selectCLIPVisionModel": "Select a CLIP Vision model",
|
||||
"setControlImageDimensions": "Copy size to W/H (optimize for model)",
|
||||
"setControlImageDimensionsForce": "Copy size to W/H (ignore model)",
|
||||
"showAdvanced": "Show Advanced",
|
||||
"small": "Small",
|
||||
"toggleControlNet": "Toggle this ControlNet",
|
||||
"w": "W",
|
||||
"weight": "Weight"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -907,7 +998,6 @@
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"globalSettings": "Global Settings",
|
||||
"guidance": "Guidance",
|
||||
"height": "Height",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"images": "Images",
|
||||
@@ -930,14 +1020,8 @@
|
||||
"noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
|
||||
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
|
||||
"noModelSelected": "No model selected",
|
||||
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
|
||||
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
|
||||
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
|
||||
"canvasManagerNotLoaded": "Canvas Manager not loaded",
|
||||
"canvasIsFiltering": "Canvas is filtering",
|
||||
"canvasIsTransforming": "Canvas is transforming",
|
||||
"canvasIsRasterizing": "Canvas is rasterizing",
|
||||
"canvasIsCompositing": "Canvas is compositing",
|
||||
"canvasBusy": "Canvas is busy",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected",
|
||||
@@ -1300,13 +1384,6 @@
|
||||
"High CFG Scale values can result in over-saturation and distorted generation results. "
|
||||
]
|
||||
},
|
||||
"paramGuidance": {
|
||||
"heading": "Guidance",
|
||||
"paragraphs": [
|
||||
"Controls how much the prompt influences the generation process.",
|
||||
"High guidance values can result in over-saturation and high or low guidance may result in distorted generation results. Guidance only applies to FLUX DEV models."
|
||||
]
|
||||
},
|
||||
"paramCFGRescaleMultiplier": {
|
||||
"heading": "CFG Rescale Multiplier",
|
||||
"paragraphs": [
|
||||
@@ -1584,36 +1661,21 @@
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
},
|
||||
"controlLayers": {
|
||||
"regional": "Regional",
|
||||
"global": "Global",
|
||||
"canvas": "Canvas",
|
||||
"bookmark": "Bookmark for Quick Switch",
|
||||
"fitBboxToLayers": "Fit Bbox To Layers",
|
||||
"removeBookmark": "Remove Bookmark",
|
||||
"saveCanvasToGallery": "Save Canvas to Gallery",
|
||||
"saveBboxToGallery": "Save Bbox to Gallery",
|
||||
"newControlLayerFromBbox": "New Control Layer from Bbox",
|
||||
"newRasterLayerFromBbox": "New Raster Layer from Bbox",
|
||||
"sendBboxToRegionalIPAdapter": "Send Bbox to Regional IP Adapter",
|
||||
"sendBboxToGlobalIPAdapter": "Send Bbox to Global IP Adapter",
|
||||
"sendBboxToControlLayer": "Send Bbox to Control Layer",
|
||||
"sendBboxToRasterLayer": "Send Bbox to Raster Layer",
|
||||
"savedToGalleryOk": "Saved to Gallery",
|
||||
"savedToGalleryError": "Error saving to gallery",
|
||||
"newGlobalReferenceImageOk": "Created Global Reference Image",
|
||||
"newGlobalReferenceImageError": "Problem Creating Global Reference Image",
|
||||
"newRegionalReferenceImageOk": "Created Regional Reference Image",
|
||||
"newRegionalReferenceImageError": "Problem Creating Regional Reference Image",
|
||||
"newControlLayerOk": "Created Control Layer",
|
||||
"newControlLayerError": "Problem Creating Control Layer",
|
||||
"newRasterLayerOk": "Created Raster Layer",
|
||||
"newRasterLayerError": "Problem Creating Raster Layer",
|
||||
"pullBboxIntoLayerOk": "Bbox Pulled Into Layer",
|
||||
"pullBboxIntoLayerError": "Problem Pulling BBox Into Layer",
|
||||
"pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage",
|
||||
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
|
||||
"regionIsEmpty": "Selected region is empty",
|
||||
"mergeVisible": "Merge Visible",
|
||||
"mergeVisibleOk": "Merged visible layers",
|
||||
"mergeVisibleError": "Error merging visible layers",
|
||||
"clearHistory": "Clear History",
|
||||
"bboxOverlay": "Show Bbox Overlay",
|
||||
"generateMode": "Generate",
|
||||
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
|
||||
"composeMode": "Compose",
|
||||
@@ -1624,7 +1686,7 @@
|
||||
"clearCaches": "Clear Caches",
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"outputOnlyMaskedRegions": "Output Only Masked Regions",
|
||||
"compositeMaskedRegions": "Composite Masked Regions",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -1641,29 +1703,25 @@
|
||||
"enableAutoNegative": "Enable Auto Negative",
|
||||
"disableAutoNegative": "Disable Auto Negative",
|
||||
"deletePrompt": "Delete Prompt",
|
||||
"deleteReferenceImage": "Delete Reference Image",
|
||||
"resetRegion": "Reset Region",
|
||||
"debugLayers": "Debug Layers",
|
||||
"showHUD": "Show HUD",
|
||||
"rectangle": "Rectangle",
|
||||
"maskFill": "Mask Fill",
|
||||
"addPositivePrompt": "Add $t(controlLayers.prompt)",
|
||||
"addNegativePrompt": "Add $t(controlLayers.negativePrompt)",
|
||||
"addReferenceImage": "Add $t(controlLayers.referenceImage)",
|
||||
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
|
||||
"addControlLayer": "Add $t(controlLayers.controlLayer)",
|
||||
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
|
||||
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
|
||||
"addGlobalReferenceImage": "Add $t(controlLayers.globalReferenceImage)",
|
||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||
"raster": "Raster",
|
||||
"rasterLayer": "Raster Layer",
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"referenceImage": "Reference Image",
|
||||
"regionalReferenceImage": "Regional Reference Image",
|
||||
"globalReferenceImage": "Global Reference Image",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"sendingToCanvas": "Sending to Canvas",
|
||||
"sendingToGallery": "Sending to Gallery",
|
||||
"sendToGallery": "Send To Gallery",
|
||||
@@ -1676,23 +1734,29 @@
|
||||
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
|
||||
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
|
||||
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
|
||||
"globalReferenceImage_withCount_one": "$t(controlLayers.globalReferenceImage)",
|
||||
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
|
||||
"rasterLayer_withCount_other": "Raster Layers",
|
||||
"controlLayer_withCount_other": "Control Layers",
|
||||
"inpaintMask_withCount_other": "Inpaint Masks",
|
||||
"regionalGuidance_withCount_other": "Regional Guidance",
|
||||
"globalReferenceImage_withCount_other": "Global Reference Images",
|
||||
"ipAdapter_withCount_other": "IP Adapters",
|
||||
"opacity": "Opacity",
|
||||
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
|
||||
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
|
||||
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
|
||||
"globalReferenceImages_withCount_hidden": "Global Reference Images ({{count}} hidden)",
|
||||
"globalIPAdapters_withCount_hidden": "Global IP Adapters ({{count}} hidden)",
|
||||
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
|
||||
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
|
||||
"controlLayers_withCount_visible": "Control Layers ({{count}})",
|
||||
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
|
||||
"globalReferenceImages_withCount_visible": "Global Reference Images ({{count}})",
|
||||
"globalIPAdapters_withCount_visible": "Global IP Adapters ({{count}})",
|
||||
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
|
||||
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
||||
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||
"globalIPAdapter": "Global $t(common.ipAdapter)",
|
||||
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
|
||||
"globalInitialImage": "Global Initial Image",
|
||||
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
|
||||
"layer": "Layer",
|
||||
"opacityFilter": "Opacity Filter",
|
||||
"clearProcessor": "Clear Processor",
|
||||
@@ -1723,27 +1787,7 @@
|
||||
"stagingOnCanvas": "Staging images on",
|
||||
"replaceLayer": "Replace Layer",
|
||||
"pullBboxIntoLayer": "Pull Bbox into Layer",
|
||||
"pullBboxIntoReferenceImage": "Pull Bbox into Reference Image",
|
||||
"showProgressOnCanvas": "Show Progress on Canvas",
|
||||
"prompt": "Prompt",
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"beginEndStepPercentShort": "Begin/End %",
|
||||
"weight": "Weight",
|
||||
"controlMode": {
|
||||
"controlMode": "Control Mode",
|
||||
"balanced": "Balanced",
|
||||
"prompt": "Prompt",
|
||||
"control": "Control",
|
||||
"megaControl": "Mega Control"
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"ipAdapterMethod": "IP Adapter Method",
|
||||
"full": "Full",
|
||||
"style": "Style Only",
|
||||
"composition": "Composition Only"
|
||||
},
|
||||
"useSizeOptimizeForModel": "Copy size to W/H (optimize for model)",
|
||||
"useSizeIgnoreModel": "Copy size to W/H (ignore model)",
|
||||
"pullBboxIntoIPAdapter": "Pull Bbox into IP Adapter",
|
||||
"fill": {
|
||||
"fillColor": "Fill Color",
|
||||
"fillStyle": "Fill Style",
|
||||
@@ -1861,10 +1905,6 @@
|
||||
"label": "Snap to Grid",
|
||||
"on": "On",
|
||||
"off": "Off"
|
||||
},
|
||||
"preserveMask": {
|
||||
"label": "Preserve Masked Region",
|
||||
"alert": "Preserving Masked Region"
|
||||
}
|
||||
},
|
||||
"HUD": {
|
||||
@@ -1872,23 +1912,15 @@
|
||||
"scaledBbox": "Scaled Bbox",
|
||||
"autoSave": "Auto Save",
|
||||
"entityStatus": {
|
||||
"isFiltering": "{{title}} is filtering",
|
||||
"isTransforming": "{{title}} is transforming",
|
||||
"isLocked": "{{title}} is locked",
|
||||
"isHidden": "{{title}} is hidden",
|
||||
"isDisabled": "{{title}} is disabled",
|
||||
"isEmpty": "{{title}} is empty"
|
||||
"selectedEntity": "Selected Entity",
|
||||
"selectedEntityIs": "Selected Entity is",
|
||||
"isFiltering": "is filtering",
|
||||
"isTransforming": "is transforming",
|
||||
"isLocked": "is locked",
|
||||
"isHidden": "is hidden",
|
||||
"isDisabled": "is disabled",
|
||||
"enabled": "Enabled"
|
||||
}
|
||||
},
|
||||
"canvasContextMenu": {
|
||||
"saveToGalleryGroup": "Save To Gallery",
|
||||
"saveCanvasToGallery": "Save Canvas To Gallery",
|
||||
"saveBboxToGallery": "Save Bbox To Gallery",
|
||||
"bboxGroup": "Create From Bbox",
|
||||
"newGlobalReferenceImage": "New Global Reference Image",
|
||||
"newRegionalReferenceImage": "New Regional Reference Image",
|
||||
"newControlLayer": "New Control Layer",
|
||||
"newRasterLayer": "New Raster Layer"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
|
||||
@@ -19,6 +19,7 @@ import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/Cl
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { selectLanguage } from 'features/system/store/systemSelectors';
|
||||
import { AppContent } from 'features/ui/components/AppContent';
|
||||
@@ -137,6 +138,7 @@ const App = ({
|
||||
<StylePresetModal />
|
||||
<ClearQueueConfirmationsAlertDialog />
|
||||
<PreselectedImage selectedImage={selectedImage} />
|
||||
<SettingsModal />
|
||||
<RefreshAfterResetModal />
|
||||
<DeleteBoardModal />
|
||||
</ErrorBoundary>
|
||||
|
||||
@@ -5,7 +5,7 @@ import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canv
|
||||
import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
@@ -3,7 +3,7 @@ import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { isErr, withResult, withResultAsync } from 'common/util/result';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
selectIsStaging,
|
||||
@@ -11,7 +11,6 @@ import {
|
||||
stagingAreaStartedStaging,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
@@ -25,7 +24,7 @@ const log = logger('generation');
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'generation',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const model = state.params.model;
|
||||
@@ -48,11 +47,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
};
|
||||
|
||||
let buildGraphResult: Result<
|
||||
{
|
||||
g: Graph;
|
||||
noise: Invocation<'noise' | 'flux_denoise'>;
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
|
||||
},
|
||||
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
|
||||
Error
|
||||
>;
|
||||
|
||||
@@ -67,14 +62,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
case `sd-2`:
|
||||
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
|
||||
break;
|
||||
case `flux`:
|
||||
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
|
||||
break;
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
if (isErr(buildGraphResult)) {
|
||||
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
|
||||
abortStaging();
|
||||
return;
|
||||
@@ -88,7 +80,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
if (isErr(prepareBatchResult)) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
abortStaging();
|
||||
return;
|
||||
@@ -103,7 +95,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
const enqueueResult = await withResultAsync(() => req.unwrap());
|
||||
|
||||
if (enqueueResult.isErr()) {
|
||||
if (isErr(enqueueResult)) {
|
||||
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
|
||||
abortStaging();
|
||||
return;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
@@ -53,9 +53,9 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
|
||||
// };
|
||||
|
||||
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
|
||||
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
|
||||
dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,27 +1,17 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
entityRasterized,
|
||||
entitySelected,
|
||||
ipaImageChanged,
|
||||
rasterLayerAdded,
|
||||
referenceImageAdded,
|
||||
referenceImageIPAdapterImageChanged,
|
||||
rgAdded,
|
||||
rgIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
@@ -65,10 +55,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
) {
|
||||
const { id } = overData.context;
|
||||
dispatch(
|
||||
referenceImageIPAdapterImageChanged({
|
||||
entityIdentifier: { id, type: 'reference_image' },
|
||||
imageDTO: activeData.payload.imageDTO,
|
||||
})
|
||||
ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO: activeData.payload.imageDTO })
|
||||
);
|
||||
return;
|
||||
}
|
||||
@@ -81,11 +68,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { id, referenceImageId } = overData.context;
|
||||
const { id, ipAdapterId } = overData.context;
|
||||
dispatch(
|
||||
rgIPAdapterImageChanged({
|
||||
entityIdentifier: { id, type: 'regional_guidance' },
|
||||
referenceImageId,
|
||||
ipAdapterId,
|
||||
imageDTO: activeData.payload.imageDTO,
|
||||
})
|
||||
);
|
||||
@@ -131,36 +118,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
|
||||
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
|
||||
const overrides: Partial<CanvasRegionalGuidanceState> = {
|
||||
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
|
||||
};
|
||||
dispatch(rgAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
|
||||
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
|
||||
const overrides: Partial<CanvasReferenceImageState> = {
|
||||
ipAdapter,
|
||||
};
|
||||
dispatch(referenceImageAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on Raster layer
|
||||
*/
|
||||
@@ -170,7 +127,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import {
|
||||
entityRasterized,
|
||||
entitySelected,
|
||||
referenceImageIPAdapterImageChanged,
|
||||
rgIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
@@ -101,15 +94,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
|
||||
if (postUploadAction?.type === 'SET_IPA_IMAGE') {
|
||||
const { id } = postUploadAction;
|
||||
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
|
||||
dispatch(ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
|
||||
const { id, referenceImageId } = postUploadAction;
|
||||
const { id, ipAdapterId } = postUploadAction;
|
||||
dispatch(
|
||||
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
|
||||
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, ipAdapterId, imageDTO })
|
||||
);
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
return;
|
||||
@@ -121,17 +114,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'REPLACE_LAYER_WITH_IMAGE') {
|
||||
const { entityIdentifier } = postUploadAction;
|
||||
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -6,18 +6,11 @@ import {
|
||||
bboxHeightChanged,
|
||||
bboxWidthChanged,
|
||||
controlLayerModelChanged,
|
||||
referenceImageIPAdapterModelChanged,
|
||||
ipaModelChanged,
|
||||
rgIPAdapterModelChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
clipEmbedModelSelected,
|
||||
fluxVAESelected,
|
||||
modelChanged,
|
||||
refinerModelChanged,
|
||||
t5EncoderModelSelected,
|
||||
vaeSelected,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
|
||||
@@ -28,16 +21,13 @@ import type { Logger } from 'roarr';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonFluxVAEModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const log = logger('models');
|
||||
@@ -60,9 +50,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
handleSpandrelImageToImageModels(models, state, dispatch, log);
|
||||
handleIPAdapterModels(models, state, dispatch, log);
|
||||
handleT5EncoderModels(models, state, dispatch, log);
|
||||
handleCLIPEmbedModels(models, state, dispatch, log);
|
||||
handleFLUXVAEModels(models, state, dispatch, log);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -144,7 +131,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
// null is a valid VAE! it means "use the default with the main model"
|
||||
return;
|
||||
}
|
||||
const vaeModels = models.filter(isNonFluxVAEModelConfig);
|
||||
const vaeModels = models.filter(isVAEModelConfig);
|
||||
|
||||
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
|
||||
|
||||
@@ -194,22 +181,22 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
|
||||
|
||||
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const ipaModels = models.filter(isIPAdapterModelConfig);
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
|
||||
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
});
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
|
||||
const isModelAvailable = ipaModels.some((m) => m.key === ipAdapter.model?.key);
|
||||
selectCanvasSlice(state).regions.entities.forEach((entity) => {
|
||||
entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
|
||||
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), ipAdapterId, modelConfig: null })
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -236,45 +223,3 @@ const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch,
dispatch(postProcessingModelChanged(firstModel));
}
};

const handleT5EncoderModels: ModelHandler = (models, state, dispatch, _log) => {
const { t5EncoderModel: currentT5EncoderModel } = state.params;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const firstModel = t5EncoderModels[0] || null;

const isCurrentT5EncoderModelAvailable = currentT5EncoderModel
? t5EncoderModels.some((m) => m.key === currentT5EncoderModel.key)
: false;

if (!isCurrentT5EncoderModelAvailable) {
dispatch(t5EncoderModelSelected(firstModel));
}
};

const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, _log) => {
const { clipEmbedModel: currentCLIPEmbedModel } = state.params;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const firstModel = CLIPEmbedModels[0] || null;

const isCurrentCLIPEmbedModelAvailable = currentCLIPEmbedModel
? CLIPEmbedModels.some((m) => m.key === currentCLIPEmbedModel.key)
: false;

if (!isCurrentCLIPEmbedModelAvailable) {
dispatch(clipEmbedModelSelected(firstModel));
}
};

const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, _log) => {
const { fluxVAE: currentFLUXVAEModel } = state.params;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const firstModel = fluxVAEModels[0] || null;

const isCurrentFLUXVAEModelAvailable = currentFLUXVAEModel
? fluxVAEModels.some((m) => m.key === currentFLUXVAEModel.key)
: false;

if (!isCurrentFLUXVAEModelAvailable) {
dispatch(fluxVAESelected(firstModel));
}
};

@@ -1,13 +0,0 @@
import type { ReadableAtom } from 'nanostores';
import { atom } from 'nanostores';

/**
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
* in a hook or component.
*/
export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
*/
export const $true: ReadableAtom<boolean> = atom(true);
@@ -114,9 +114,6 @@ export type AppConfig = {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: NumericalParameterConfig;
};
};

export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;

@@ -155,19 +155,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
return styles;
}, [isUploadDisabled, minSize]);

const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
return;
}
if (e.button !== 1) {
return;
}
window.open(imageDTO.image_url, '_blank');
},
[imageDTO]
);

return (
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
@@ -225,12 +212,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
<IAIDraggable data={draggableData} disabled={isDragDisabled || !imageDTO} onClick={onClick} />
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}

@@ -10,9 +10,7 @@ const sx: SystemStyleObject = {
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: `drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-900))
drop-shadow(0px 0px 0.3rem var(--invoke-colors-base-900))
drop-shadow(0px 0px 0.3rem var(--invoke-colors-base-900))`,
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};

@@ -29,6 +27,7 @@ const IAIDndImageIcon = (props: Props) => {
onClick={onClick}
aria-label={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}

@@ -1,4 +1,4 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import type { AnimationProps } from 'framer-motion';
import { motion } from 'framer-motion';
import type { ReactNode } from 'react';
@@ -28,13 +28,11 @@ const IAIDropOverlay = (props: Props) => {
const motionId = useRef(uuidv4());
return (
<motion.div key={motionId.current} initial={initial} animate={animate} exit={exit}>
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
<Flex position="absolute" top={0} insetInlineStart={0} w="full" h="full">
<Flex
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
insetInlineStart={0}
w="full"
h="full"
bg="base.900"
@@ -49,30 +47,29 @@ const IAIDropOverlay = (props: Props) => {
<Flex
position="absolute"
top={0.5}
right={0.5}
insetInlineStart={0.5}
insetInlineEnd={0.5}
bottom={0.5}
left={0.5}
opacity={1}
borderWidth={1.5}
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
borderWidth={2}
borderColor={isOver ? 'base.300' : 'base.500'}
borderRadius="base"
borderStyle="dashed"
transitionProperty="common"
transitionDuration="0.1s"
alignItems="center"
justifyContent="center"
p={4}
>
<Text
fontSize="xl"
<Box
fontSize="2xl"
fontWeight="semibold"
color={isOver ? 'invokeYellow.300' : 'base.500'}
transform={isOver ? 'scale(1.1)' : 'scale(1)'}
color={isOver ? 'base.50' : 'base.300'}
transitionProperty="common"
transitionDuration="0.1s"
textAlign="center"
>
{label}
</Text>
</Box>
</Flex>
</Flex>
</motion.div>

@@ -30,9 +30,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
ref={setNodeRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
insetInlineStart={0}
w="full"
h="full"
pointerEvents={active ? 'auto' : 'none'}

@@ -30,7 +30,6 @@ export type Feature =
| 'noiseUseCPU'
| 'paramAspect'
| 'paramCFGScale'
| 'paramGuidance'
| 'paramCFGRescaleMultiplier'
| 'paramDenoisingStrength'
| 'paramHeight'

@@ -68,7 +68,7 @@ export const useGlobalHotkeys = () => {
useHotkeys(
'1',
() => {
dispatch(setActiveTab('canvas'));
dispatch(setActiveTab('generation'));
addScope('canvas');
removeScope('workflows');
},

@@ -1,7 +1,6 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $true } from 'app/store/nanostores/util';
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -19,25 +18,19 @@ import { selectSystemSlice } from 'features/system/store/systemSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { atom } from 'nanostores';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';

const LAYER_TYPE_TO_TKEY = {
reference_image: 'controlLayers.referenceImage',
ip_adapter: 'controlLayers.ipAdapter',
inpaint_mask: 'controlLayers.inpaintMask',
regional_guidance: 'controlLayers.regionalGuidance',
raster_layer: 'controlLayers.rasterLayer',
control_layer: 'controlLayers.controlLayer',
raster_layer: 'controlLayers.raster',
control_layer: 'controlLayers.globalControlAdapter',
} as const;

const createSelector = (
templates: Templates,
isConnected: boolean,
canvasIsFiltering: boolean,
canvasIsTransforming: boolean,
canvasIsRasterizing: boolean,
canvasIsCompositing: boolean
) =>
const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy: boolean) =>
createMemoizedSelector(
[
selectSystemSlice,
@@ -126,17 +119,8 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
} else {
if (canvasIsFiltering) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsFiltering') });
}
if (canvasIsTransforming) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsTransforming') });
}
if (canvasIsRasterizing) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsRasterizing') });
}
if (canvasIsCompositing) {
reasons.push({ content: i18n.t('parameters.invoke.canvasIsCompositing') });
if (canvasIsBusy) {
reasons.push({ content: i18n.t('parameters.invoke.canvasBusy') });
}

if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
@@ -147,18 +131,6 @@ const createSelector = (
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}

if (model?.base === 'flux') {
if (!params.t5EncoderModel) {
reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
}
if (!params.clipEmbedModel) {
reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
}
if (!params.fluxVAE) {
reasons.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
}
}

canvas.controlLayers.entities
.filter((controlLayer) => controlLayer.isEnabled)
.forEach((controlLayer, i) => {
@@ -189,7 +161,7 @@ const createSelector = (
}
});

canvas.referenceImages.entities
canvas.ipAdapters.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
@@ -217,7 +189,7 @@ const createSelector = (
}
});

canvas.regionalGuidance.entities
canvas.regions.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
@@ -230,14 +202,10 @@ const createSelector = (
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (
entity.positivePrompt === null &&
entity.negativePrompt === null &&
entity.referenceImages.length === 0
) {
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
entity.referenceImages.forEach(({ ipAdapter }) => {
entity.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
@@ -278,25 +246,16 @@ const createSelector = (
}
);

const dummyAtom = atom(true);

export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const isConnected = useStore($isConnected);
const canvasManager = useCanvasManagerSafe();
const canvasIsFiltering = useStore(canvasManager?.stateApi.$isFiltering ?? $true);
const canvasIsTransforming = useStore(canvasManager?.stateApi.$isTransforming ?? $true);
const canvasIsRasterizing = useStore(canvasManager?.stateApi.$isRasterizing ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const canvasIsBusy = useStore(canvasManager?.$isBusy ?? dummyAtom);
const selector = useMemo(
() =>
createSelector(
templates,
isConnected,
canvasIsFiltering,
canvasIsTransforming,
canvasIsRasterizing,
canvasIsCompositing
),
[templates, isConnected, canvasIsFiltering, canvasIsTransforming, canvasIsRasterizing, canvasIsCompositing]
() => createSelector(templates, isConnected, canvasIsBusy),
[templates, isConnected, canvasIsBusy]
);
const value = useAppSelector(selector);
return value;

@@ -1,14 +0,0 @@
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';

export const useNanoid = (prefix?: string) => {
const id = useMemo(() => {
if (prefix) {
return getPrefixedId(prefix);
} else {
return nanoid();
}
}, [prefix]);

return id;
};
@@ -2,7 +2,8 @@ import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, expect, it } from 'vitest';

import { Err, ErrResult, Ok, OkResult, withResult, withResultAsync } from './result';
import type { ErrResult, OkResult } from './result';
import { Err, isErr, isOk, Ok, withResult, withResultAsync } from './result'; // Adjust import as needed

const promiseify = <T>(fn: () => T): (() => Promise<T>) => {
return () =>
@@ -12,30 +13,28 @@ const promiseify = <T>(fn: () => T): (() => Promise<T>) => {
};

describe('Result Utility Functions', () => {
it('OkResult() should create an Ok result', () => {
const result = OkResult(42);
expect(result).toBeInstanceOf(Ok);
expect(result.isOk()).toBe(true);
expect(result.isErr()).toBe(false);
expect(result.value).toBe(42);
assert<Equals<Ok<number>, typeof result>>(result);
it('Ok() should create an OkResult', () => {
const result = Ok(42);
expect(result).toEqual({ type: 'Ok', value: 42 });
expect(isOk(result)).toBe(true);
expect(isErr(result)).toBe(false);
assert<Equals<OkResult<number>, typeof result>>(result);
});

it('ErrResult() should create an Err result', () => {
it('Err() should create an ErrResult', () => {
const error = new Error('Something went wrong');
const result = ErrResult(error);
expect(result).toBeInstanceOf(Err);
expect(result.isOk()).toBe(false);
expect(result.isErr()).toBe(true);
expect(result.error).toBe(error);
assert<Equals<Err<Error>, typeof result>>(result);
const result = Err(error);
expect(result).toEqual({ type: 'Err', error });
expect(isOk(result)).toBe(false);
expect(isErr(result)).toBe(true);
assert<Equals<ErrResult<Error>, typeof result>>(result);
});

it('withResult() should return Ok on success', () => {
const fn = () => 42;
const result = withResult(fn);
expect(result.isOk()).toBe(true);
if (result.isOk()) {
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
@@ -45,8 +44,8 @@ describe('Result Utility Functions', () => {
throw new Error('Failure');
};
const result = withResult(fn);
expect(result.isErr()).toBe(true);
if (result.isErr()) {
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Failure');
}
});
@@ -54,8 +53,8 @@ describe('Result Utility Functions', () => {
it('withResultAsync() should return Ok on success', async () => {
const fn = promiseify(() => 42);
const result = await withResultAsync(fn);
expect(result.isOk()).toBe(true);
if (result.isOk()) {
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
@@ -65,8 +64,8 @@ describe('Result Utility Functions', () => {
throw new Error('Async failure');
});
const result = await withResultAsync(fn);
expect(result.isErr()).toBe(true);
if (result.isErr()) {
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Async failure');
}
});

@@ -2,81 +2,39 @@
* Represents a successful result.
* @template T The type of the value.
*/
export class Ok<T> {
readonly value: T;
constructor(value: T) {
this.value = value;
}

/**
* Type guard to check if this result is an `Ok` result.
* @returns {this is Ok<T>} `true` if the result is an `Ok` result, otherwise `false`.
*/
isOk(): this is Ok<T> {
return true;
}

/**
* Type guard to check if this result is an `Err` result.
* @returns {this is Err<never>} `true` if the result is an `Err` result, otherwise `false`.
*/
isErr(): this is Err<never> {
return false;
}
}
export type OkResult<T> = { type: 'Ok'; value: T };

/**
* Represents a failed result.
* @template E The type of the error.
*/
export class Err<E> {
readonly error: E;
constructor(error: E) {
this.error = error;
}

/**
* Type guard to check if this result is an `Ok` result.
* @returns {this is Ok<never>} `true` if the result is an `Ok` result, otherwise `false`.
*/
isOk(): this is Ok<never> {
return false;
}

/**
* Type guard to check if this result is an `Err` result.
* @returns {this is Err<E>} `true` if the result is an `Err` result, otherwise `false`.
*/
isErr(): this is Err<E> {
return true;
}
}
export type ErrResult<E> = { type: 'Err'; error: E };

/**
* A union type that represents either a successful result (`Ok`) or a failed result (`Err`).
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
export type Result<T, E = Error> = Ok<T> | Err<E>;
export type Result<T, E = Error> = OkResult<T> | ErrResult<E>;

/**
* Creates a successful result.
* @template T The type of the value.
* @param {T} value The value to wrap in an `Ok` result.
* @returns {Ok<T>} The `Ok` result containing the value.
* @returns {OkResult<T>} The `Ok` result containing the value.
*/
export function OkResult<T>(value: T): Ok<T> {
return new Ok(value);
export function Ok<T>(value: T): OkResult<T> {
return { type: 'Ok', value };
}

/**
* Creates a failed result.
* @template E The type of the error.
* @param {E} error The error to wrap in an `Err` result.
* @returns {Err<E>} The `Err` result containing the error.
* @returns {ErrResult<E>} The `Err` result containing the error.
*/
export function ErrResult<E>(error: E): Err<E> {
return new Err(error);
export function Err<E>(error: E): ErrResult<E> {
return { type: 'Err', error };
}

/**
@@ -87,9 +45,9 @@ export function ErrResult<E>(error: E): Err<E> {
*/
export function withResult<T>(fn: () => T): Result<T> {
try {
return new Ok(fn());
return Ok(fn());
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return Err(error instanceof Error ? error : new Error(String(error)));
}
}

@@ -102,8 +60,30 @@ export function withResult<T>(fn: () => T): Result<T> {
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T>> {
try {
const result = await fn();
return new Ok(result);
return Ok(result);
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return Err(error instanceof Error ? error : new Error(String(error)));
}
}

/**
* Type guard to check if a `Result` is an `Ok` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is OkResult<T>} `true` if the result is an `Ok` result, otherwise `false`.
*/
export function isOk<T, E>(result: Result<T, E>): result is OkResult<T> {
return result.type === 'Ok';
}

/**
* Type guard to check if a `Result` is an `Err` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is ErrResult<E>} `true` if the result is an `Err` result, otherwise `false`.
*/
export function isErr<T, E>(result: Result<T, E>): result is ErrResult<E> {
return result.type === 'Err';
}

@@ -1,14 +1,11 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -19,82 +16,27 @@ export const CanvasAddEntityButtons = memo(() => {
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const addIPAdapter = useAddIPAdapter();

return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
<Flex position="relative" flexDir="column" gap={4} top="20%">
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.global')}</Heading>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" gap={2}>
<Heading size="xs">{t('controlLayers.regional')}</Heading>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
>
{t('controlLayers.inpaintMask')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
>
{t('controlLayers.regionalGuidance')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
</Flex>
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.layer_other')}</Heading>

<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRasterLayer}
>
{t('controlLayers.rasterLayer')}
</Button>
</Flex>
</Flex>
<Flex flexDir="column" w="full" h="full" alignItems="center">
<ButtonGroup position="relative" orientation="vertical" isAttached={false} top="20%">
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRegionalGuidance}>
{t('controlLayers.regionalGuidance')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.globalIPAdapter')}
</Button>
</ButtonGroup>
</Flex>
);
});

@@ -1,23 +0,0 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectPreserveMask } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';

export const CanvasAlertsPreserveMask = memo(() => {
const { t } = useTranslation();
const preserveMask = useAppSelector(selectPreserveMask);

if (!preserveMask) {
return null;
}

return (
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.preserveMask.alert')}</AlertTitle>
</Alert>
);
});

CanvasAlertsPreserveMask.displayName = 'CanvasAlertsPreserveMask';
@@ -1,53 +0,0 @@
import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
import {
useNewControlLayerFromBbox,
useNewGlobalReferenceImageFromBbox,
useNewRasterLayerFromBbox,
useNewRegionalReferenceImageFromBbox,
useSaveBboxToGallery,
useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold, PiStackPlusFill } from 'react-icons/pi';

export const CanvasContextMenuGlobalMenuItems = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
const newRegionalReferenceImageFromBbox = useNewRegionalReferenceImageFromBbox();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
const newRasterLayerFromBbox = useNewRasterLayerFromBbox();
const newControlLayerFromBbox = useNewControlLayerFromBbox();

return (
<>
<MenuGroup title={t('controlLayers.canvasContextMenu.saveToGalleryGroup')}>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
{t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
</MenuItem>
<MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
{t('controlLayers.canvasContextMenu.saveBboxToGallery')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.canvasContextMenu.bboxGroup')}>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
{t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newControlLayer')}
</MenuItem>
<MenuItem icon={<PiStackPlusFill />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
{t('controlLayers.canvasContextMenu.newRasterLayer')}
</MenuItem>
</MenuGroup>
</>
);
});

CanvasContextMenuGlobalMenuItems.displayName = 'CanvasContextMenuGlobalMenuItems';
@@ -0,0 +1,49 @@
import { MenuItem } from '@invoke-ai/ui-library';
import {
useIsSavingCanvas,
useSaveBboxAsControlLayer,
useSaveBboxAsGlobalIPAdapter,
useSaveBboxAsRasterLayer,
useSaveBboxAsRegionalGuidanceIPAdapter,
useSaveBboxToGallery,
useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold, PiShareFatBold } from 'react-icons/pi';

export const CanvasContextMenuItems = memo(() => {
const { t } = useTranslation();
const isSaving = useIsSavingCanvas();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
const saveBboxAsRegionalGuidanceIPAdapter = useSaveBboxAsRegionalGuidanceIPAdapter();
const saveBboxAsIPAdapter = useSaveBboxAsGlobalIPAdapter();
const saveBboxAsRasterLayer = useSaveBboxAsRasterLayer();
const saveBboxAsControlLayer = useSaveBboxAsControlLayer();

return (
<>
<MenuItem icon={<PiFloppyDiskBold />} isLoading={isSaving.isTrue} onClick={saveCanvasToGallery}>
{t('controlLayers.saveCanvasToGallery')}
</MenuItem>
<MenuItem icon={<PiFloppyDiskBold />} isLoading={isSaving.isTrue} onClick={saveBboxToGallery}>
{t('controlLayers.saveBboxToGallery')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsIPAdapter}>
{t('controlLayers.sendBboxToGlobalIPAdapter')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsRegionalGuidanceIPAdapter}>
{t('controlLayers.sendBboxToRegionalIPAdapter')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsControlLayer}>
{t('controlLayers.sendBboxToControlLayer')}
</MenuItem>
<MenuItem icon={<PiShareFatBold />} isLoading={isSaving.isTrue} onClick={saveBboxAsRasterLayer}>
{t('controlLayers.sendBboxToRasterLayer')}
</MenuItem>
</>
);
});

CanvasContextMenuItems.displayName = 'CanvasContextMenuItems';
@@ -1,43 +0,0 @@
import { MenuGroup } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import {
EntityIdentifierContext,
useEntityIdentifierContext,
} from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityTitle } from 'features/controlLayers/hooks/useEntityTitle';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { isFilterableEntityIdentifier, isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';

const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const title = useEntityTitle(entityIdentifier);

return (
<MenuGroup title={title}>
{isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
<CanvasEntityMenuItemsDelete />
</MenuGroup>
);
});
CanvasContextMenuSelectedEntityMenuItemsContent.displayName = 'CanvasContextMenuSelectedEntityMenuItemsContent';

export const CanvasContextMenuSelectedEntityMenuItems = memo(() => {
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);

if (!selectedEntityIdentifier) {
return null;
}

return (
<EntityIdentifierContext.Provider value={selectedEntityIdentifier}>
<CanvasContextMenuSelectedEntityMenuItemsContent />
</EntityIdentifierContext.Provider>
);
});

CanvasContextMenuSelectedEntityMenuItems.displayName = 'CanvasContextMenuSelectedEntityMenuItems';
@@ -1,11 +1,6 @@
import { Grid, GridItem } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type {
AddControlLayerFromImageDropData,
AddGlobalReferenceImageFromImageDropData,
AddRasterLayerFromImageDropData,
AddRegionalReferenceImageFromImageDropData,
} from 'features/dnd/types';
import type { AddControlLayerFromImageDropData, AddRasterLayerFromImageDropData } from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';

@@ -19,16 +14,6 @@ const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
};

const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
};

const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
id: 'add-control-layer-from-image-drop-data',
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
};

export const CanvasDropArea = memo(() => {
const imageViewer = useImageViewer();

@@ -38,29 +23,12 @@ export const CanvasDropArea = memo(() => {

return (
<>
<Grid
gridTemplateRows="1fr 1fr"
gridTemplateColumns="1fr 1fr"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
pointerEvents="none"
>
<GridItem position="relative">
<IAIDroppable dropLabel="New Raster Layer" data={addRasterLayerFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Control Layer" data={addControlLayerFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Regional Reference Image" data={addRegionalReferenceImageFromImageDropData} />
</GridItem>
<GridItem position="relative">
<IAIDroppable dropLabel="New Global Reference Image" data={addGlobalReferenceImageFromImageDropData} />
</GridItem>
</Grid>
<Flex position="absolute" top={0} right={0} bottom="50%" left={0} gap={2} pointerEvents="none">
<IAIDroppable dropLabel="Create Raster Layer" data={addRasterLayerFromImageDropData} />
</Flex>
<Flex position="absolute" top="50%" right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable dropLabel="Create Control Layer" data={addControlLayerFromImageDropData} />
</Flex>
</>
);
});

@@ -11,9 +11,9 @@ export const CanvasEntityList = memo(() => {
return (
<ScrollableContent>
<Flex flexDir="column" gap={2} data-testid="control-layers-layer-list" w="full" h="full">
<IPAdapterList />
<InpaintMaskList />
<RegionalGuidanceEntityList />
<IPAdapterList />
<ControlLayerEntityList />
<RasterLayerEntityList />
</Flex>

@@ -1,27 +1,22 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const { t } = useTranslation();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const addIPAdapter = useAddIPAdapter();

return (
<Menu>
@@ -36,30 +31,21 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
data-testid="control-layers-add-layer-menu-button"
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isFLUX}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</MenuItem>
</MenuGroup>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.globalIPAdapter')}
</MenuItem>
</MenuList>
</Menu>
);

@@ -1,6 +1,5 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { entityDuplicated } from 'features/controlLayers/store/canvasSlice';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { memo, useCallback } from 'react';
@@ -10,7 +9,6 @@ import { PiCopyFill } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const onClick = useCallback(() => {
if (!selectedEntityIdentifier) {
@@ -22,7 +20,7 @@ export const EntityListSelectedEntityActionBarDuplicateButton = memo(() => {
return (
<IconButton
onClick={onClick}
isDisabled={!selectedEntityIdentifier || isBusy}
isDisabled={!selectedEntityIdentifier}
size="sm"
variant="link"
alignSelf="stretch"

@@ -50,14 +50,7 @@ export const EntityListSelectedEntityActionBarFill = memo(() => {
<Flex role="button" aria-label={t('controlLayers.maskFill')} tabIndex={-1} w={8} h={8}>
<Tooltip label={t('controlLayers.maskFill')}>
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Box
borderRadius="full"
borderColor="base.300"
w={6}
h={6}
borderWidth={1}
bg={rgbColorToString(fill.color)}
/>
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbColorToString(fill.color)} />
</Flex>
</Tooltip>
</Flex>

@@ -137,7 +137,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
<FormControl
w="min-content"
gap={2}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
>
<FormLabel m={0}>{t('controlLayers.opacity')}</FormLabel>
<PopoverAnchor>
@@ -167,7 +167,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
position="absolute"
insetInlineEnd={0}
h="full"
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
/>
</PopoverTrigger>
</NumberInput>
@@ -185,7 +185,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
marks={marks}
formatValue={formatSliderValue}
alwaysShowMarks
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'ip_adapter'}
/>
</PopoverBody>
</PopoverContent>

@@ -1,20 +1,19 @@
import { ContextMenu, Flex, MenuList } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useScopeOnFocus } from 'common/hooks/interactionScopes';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasAlertsSendingToGallery } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasContextMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { CanvasSelectedEntityStatusAlert } from 'features/controlLayers/components/HUD/CanvasSelectedEntityStatusAlert';
import { SendingToGalleryAlert } from 'features/controlLayers/components/HUD/CanvasSendingToGalleryAlert';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { TRANSPARENCY_CHECKERBOARD_PATTERN_DATAURL } from 'features/controlLayers/konva/patterns/transparency-checkerboard-pattern';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { memo, useCallback, useRef } from 'react';
@@ -28,8 +27,7 @@ export const CanvasMainPanelContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuGlobalMenuItems />
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
@@ -50,9 +48,7 @@ export const CanvasMainPanelContent = memo(() => {
alignItems="center"
justifyContent="center"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<CanvasToolbar />
<ContextMenu<HTMLDivElement> renderMenu={renderMenu}>
{(ref) => (
<Flex
@@ -63,6 +59,18 @@ export const CanvasMainPanelContent = memo(() => {
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
>
{!dynamicGrid && (
<Flex
position="absolute"
borderRadius="base"
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DATAURL}
top={0}
right={0}
bottom={0}
left={0}
opacity={0.1}
/>
)}
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
{showHUD && (
@@ -71,9 +79,8 @@ export const CanvasMainPanelContent = memo(() => {
</Flex>
)}
<Flex flexDir="column" position="absolute" top={1} insetInlineEnd={1} pointerEvents="none" gap={2}>
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasSelectedEntityStatusAlert />
<SendingToGalleryAlert />
</Flex>
</CanvasManagerProviderGate>
</Flex>

@@ -1,14 +1,12 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useIsSavingCanvas, usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
import {
controlLayerBeginEndStepPctChanged,
@@ -20,8 +18,8 @@ import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/s
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
import { PiBoundingBoxBold, PiShootingStarBold } from 'react-icons/pi';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';

const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
@@ -72,12 +70,7 @@ export const ControlLayerControlAdapter = memo(() => {
);

const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const isSaving = useIsSavingCanvas();

return (
<Flex flexDir="column" gap={3} position="relative" w="full">
@@ -86,34 +79,19 @@ export const ControlLayerControlAdapter = memo(() => {
<IconButton
onClick={filter.start}
isDisabled={filter.isDisabled}
size="sm"
alignSelf="stretch"
variant="link"
variant="ghost"
aria-label={t('controlLayers.filter.filter')}
tooltip={t('controlLayers.filter.filter')}
icon={<PiShootingStarBold />}
/>
<IconButton
onClick={pullBboxIntoLayer}
isDisabled={isBusy}
size="sm"
alignSelf="stretch"
variant="link"
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoLayer')}
tooltip={t('controlLayers.pullBboxIntoLayer')}
icon={<PiBoundingBoxBold />}
/>
<IconButton
isDisabled={isBusy}
size="sm"
alignSelf="stretch"
variant="link"
aria-label={t('accessibility.uploadImage')}
tooltip={t('accessibility.uploadImage')}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />

@@ -16,10 +16,10 @@ export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChan
const { t } = useTranslation();
const CONTROL_MODE_DATA = useMemo(
() => [
{ label: t('controlLayers.controlMode.balanced'), value: 'balanced' },
{ label: t('controlLayers.controlMode.prompt'), value: 'more_prompt' },
{ label: t('controlLayers.controlMode.control'), value: 'more_control' },
{ label: t('controlLayers.controlMode.megaControl'), value: 'unbalanced' },
{ label: t('controlnet.balanced'), value: 'balanced' },
{ label: t('controlnet.prompt'), value: 'more_prompt' },
{ label: t('controlnet.control'), value: 'more_control' },
{ label: t('controlnet.megaControl'), value: 'unbalanced' },
],
[t]
);
@@ -44,7 +44,7 @@ export const ControlLayerControlAdapterControlMode = memo(({ controlMode, onChan
return (
<FormControl>
<InformationalPopover feature="controlNetControlMode">
<FormLabel m={0}>{t('controlLayers.controlMode.controlMode')}</FormLabel>
<FormLabel m={0}>{t('controlnet.control')}</FormLabel>
</InformationalPopover>
<Combobox
value={value}

@@ -51,7 +51,7 @@ export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onCha
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

@@ -4,7 +4,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
import { ControlLayerMenuItemsControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsControlToRaster';
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
import { memo } from 'react';

@@ -13,7 +13,7 @@ export const ControlLayerMenuItems = memo(() => {
<>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<ControlLayerMenuItemsConvertControlToRaster />
<ControlLayerMenuItemsControlToRaster />
<ControlLayerMenuItemsTransparencyEffect />
<MenuDivider />
<CanvasEntityMenuItemsArrange />

@@ -1,27 +1,27 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { controlLayerConvertedToRasterLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';

export const ControlLayerMenuItemsConvertControlToRaster = memo(() => {
export const ControlLayerMenuItemsControlToRaster = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const entityIdentifier = useEntityIdentifierContext('control_layer');
const isInteractable = useIsEntityInteractable(entityIdentifier);

const convertControlLayerToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
}, [dispatch, entityIdentifier]);

return (
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={!isInteractable}>
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={isBusy}>
{t('controlLayers.convertToRasterLayer')}
</MenuItem>
);
});

ControlLayerMenuItemsConvertControlToRaster.displayName = 'ControlLayerMenuItemsConvertControlToRaster';
ControlLayerMenuItemsControlToRaster.displayName = 'ControlLayerMenuItemsControlToRaster';
@@ -2,7 +2,6 @@ import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
|
||||
import { controlLayerWithTransparencyEffectToggled } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -13,7 +12,6 @@ export const ControlLayerMenuItemsTransparencyEffect = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('control_layer');
|
||||
const isInteractable = useIsEntityInteractable(entityIdentifier);
|
||||
const selectWithTransparencyEffect = useMemo(
|
||||
() =>
|
||||
createSelector(selectCanvasSlice, (canvas) => {
|
||||
@@ -28,7 +26,7 @@ export const ControlLayerMenuItemsTransparencyEffect = memo(() => {
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={onToggle} icon={<PiDropHalfBold />} isDisabled={!isInteractable}>
|
||||
<MenuItem onClick={onToggle} icon={<PiDropHalfBold />}>
|
||||
{withTransparencyEffect
|
||||
? t('controlLayers.disableTransparencyEffect')
|
||||
: t('controlLayers.enableTransparencyEffect')}
|
||||
|
||||
@@ -75,7 +75,7 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
<Button
variant="ghost"
leftIcon={<PiShootingStarBold />}
onClick={adapter.filterer.processImmediate}
onClick={adapter.filterer.process}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.process')}
isDisabled={!isValid || autoProcessFilter}
@@ -108,6 +108,7 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
onClick={adapter.filterer.cancel}
isLoading={isProcessing}
loadingText={t('controlLayers.filter.cancel')}
isDisabled={!isValid}
>
{t('controlLayers.filter.cancel')}
</Button>

@@ -70,9 +70,8 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
);

useEffect(() => {
const firstModel = options[0];
if (!config.model && firstModel) {
onChangeModel(firstModel);
if (!config.model) {
onChangeModel(options[0] ?? null);
}
}, [config.model, onChangeModel, options]);

@@ -81,14 +80,14 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
<FormControl w="full" orientation="vertical">
<Flex w="full" alignItems="center">
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.filter.spandrel_filter.autoScale')}
{t('controlLayers.filter.spandrel.paramAutoScale')}
</FormLabel>
<Switch size="sm" isChecked={config.autoScale} onChange={onAutoscaleChanged} />
</Flex>
<FormHelperText>{t('controlLayers.filter.spandrel_filter.autoScaleDesc')}</FormHelperText>
<FormHelperText>{t('controlLayers.filter.spandrel.paramAutoScaleDesc')}</FormHelperText>
</FormControl>
<FormControl isDisabled={!config.autoScale}>
<FormLabel m={0}>{t('controlLayers.filter.spandrel_filter.scale')}</FormLabel>
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramScale')}</FormLabel>
<CompositeSlider
value={config.scale}
onChange={onScaleChanged}
@@ -105,7 +104,7 @@ export const FilterSpandrel = ({ onChange, config }: Props) => {
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.spandrel_filter.model')}</FormLabel>
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramModel')}</FormLabel>
<Tooltip label={tooltipLabel}>
<Box w="full">
<Combobox
@@ -1,5 +1,5 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
@@ -27,9 +27,10 @@ const $isFilteringFallback = atom(false);
type AlertData = {
status: AlertStatus;
title: string;
description: string;
};

const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapter }: ContentProps) => {
const CanvasSelectedEntityStatusAlertContent = memo(({ entityIdentifier, adapter }: ContentProps) => {
const { t } = useTranslation();
const title = useEntityTitle(entityIdentifier);
const selectIsEnabled = useMemo(
@@ -45,69 +46,67 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
const isHidden = useEntityTypeIsHidden(entityIdentifier.type);
const isFiltering = useStore(adapter.filterer?.$isFiltering ?? $isFilteringFallback);
const isTransforming = useStore(adapter.transformer.$isTransforming);
const isEmpty = useStore(adapter.$isEmpty);

const alert = useMemo<AlertData | null>(() => {
if (isFiltering) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isFiltering'),
};
}

if (isTransforming) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
};
}

if (isEmpty) {
return {
status: 'info',
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isTransforming'),
};
}

if (isHidden) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isHidden', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isHidden'),
};
}

if (isLocked) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isLocked', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isLocked'),
};
}

if (!isEnabled) {
return {
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isDisabled', { title }),
title,
description: t('controlLayers.HUD.entityStatus.isDisabled'),
};
}

return null;
}, [isFiltering, isTransforming, isEmpty, isHidden, isLocked, isEnabled, title, t]);
}, [isFiltering, isTransforming, isHidden, isLocked, isEnabled, title, t]);

if (!alert) {
return null;
}

return (
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md" w="fit-content" alignSelf="flex-end">
<Alert status={alert.status} borderRadius="base" fontSize="sm" shadow="md">
<AlertIcon />
<AlertTitle>{alert.title}</AlertTitle>
<AlertDescription>{alert.description}.</AlertDescription>
</Alert>
);
});

CanvasAlertsSelectedEntityStatusContent.displayName = 'CanvasAlertsSelectedEntityStatusContent';
CanvasSelectedEntityStatusAlertContent.displayName = 'CanvasSelectedEntityStatusAlertContent';

export const CanvasAlertsSelectedEntityStatus = memo(() => {
export const CanvasSelectedEntityStatusAlert = memo(() => {
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const adapter = useEntityAdapterSafe(selectedEntityIdentifier);

@@ -115,7 +114,7 @@ export const CanvasAlertsSelectedEntityStatus = memo(() => {
return null;
}

return <CanvasAlertsSelectedEntityStatusContent entityIdentifier={selectedEntityIdentifier} adapter={adapter} />;
return <CanvasSelectedEntityStatusAlertContent entityIdentifier={selectedEntityIdentifier} adapter={adapter} />;
});

CanvasAlertsSelectedEntityStatus.displayName = 'CanvasAlertsSelectedEntityStatus';
CanvasSelectedEntityStatusAlert.displayName = 'CanvasSelectedEntityStatusAlert';
@@ -38,7 +38,7 @@ const ActivateImageViewerButton = (props: PropsWithChildren) => {
);
};

export const CanvasAlertsSendingToGallery = () => {
export const SendingToGalleryAlert = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const isVisible = useMemo(() => {
@@ -68,7 +68,7 @@ const ActivateCanvasButton = (props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const onClick = useCallback(() => {
dispatch(setActiveTab('canvas'));
dispatch(setActiveTab('generation'));
setRightPanelTabToLayers();
imageViewer.close();
}, [dispatch, imageViewer]);
@@ -79,7 +79,7 @@ const ActivateCanvasButton = (props: PropsWithChildren) => {
);
};

export const CanvasAlertsSendingToCanvas = () => {
export const SendingToCanvasAlert = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const isVisible = useMemo(() => {
@@ -136,16 +136,7 @@ const AlertWrapper = ({
onMouseEnter={isHovered.setTrue}
onMouseLeave={isHovered.setFalse}
>
<Alert
status="warning"
flexDir="column"
pointerEvents="auto"
borderRadius="base"
fontSize="sm"
shadow="md"
w="fit-content"
alignSelf="flex-end"
>
<Alert status="warning" flexDir="column" pointerEvents="auto" borderRadius="base" fontSize="sm" shadow="md">
<Flex w="full" alignItems="center">
<AlertIcon />
<AlertTitle>{title}</AlertTitle>
@@ -13,7 +13,7 @@ type Props = {
};

export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'reference_image' }), [id]);
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]);

return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
@@ -5,7 +5,6 @@ import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { useNanoid } from 'common/hooks/useNanoid';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectOptimalDimension } from 'features/controlLayers/store/selectors';
import type { ImageWithDims } from 'features/controlLayers/store/types';
@@ -20,86 +19,90 @@ import type { ImageDTO, PostUploadAction } from 'services/api/types';
type Props = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
ipAdapterId: string; // required for the dnd/upload interactions
droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction;
};

export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();
const dndId = useNanoid('ip_adapter_image_preview');
export const IPAdapterImagePreview = memo(
({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isConnected = useStore($isConnected);
const optimalDimension = useAppSelector(selectOptimalDimension);
const shift = useShiftModifier();

const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
image?.image_name ?? skipToken
);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);

const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}

const options = { updateAspectRatio: true, clamp: true };
if (shift) {
const { width, height } = controlImage;
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
}
}, [controlImage, dispatch, optimalDimension, shift]);
const options = { updateAspectRatio: true, clamp: true };
if (shift) {
const { width, height } = controlImage;
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
} else {
const { width, height } = calculateNewSize(
controlImage.width / controlImage.height,
optimalDimension * optimalDimension
);
dispatch(bboxWidthChanged({ width, ...options }));
dispatch(bboxHeightChanged({ height, ...options }));
}
}, [controlImage, dispatch, optimalDimension, shift]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: dndId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, dndId]);
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
if (controlImage) {
return {
id: ipAdapterId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, ipAdapterId]);

useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);
useEffect(() => {
if (isConnected && isErrorControlImage) {
handleResetControlImage();
}
}, [handleResetControlImage, isConnected, isErrorControlImage]);

return (
<Flex position="relative" w="full" h="full" alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>
return (
<Flex position="relative" w={36} h={36} alignItems="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
postUploadAction={postUploadAction}
/>

{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={shift ? t('controlLayers.useSizeIgnoreModel') : t('controlLayers.useSizeOptimizeForModel')}
/>
</Flex>
)}
</Flex>
);
});
{controlImage && (
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
);
}
);

IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
@@ -9,10 +9,10 @@ import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/cont
import { memo } from 'react';

const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(mapId).reverse();
return canvas.ipAdapters.entities.map(mapId).reverse();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
return selectedEntityIdentifier?.type === 'ip_adapter';
});

export const IPAdapterList = memo(() => {
@@ -25,7 +25,7 @@ export const IPAdapterList = memo(() => {

if (ipaIds.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
<CanvasEntityGroupList type="ip_adapter" isSelected={isSelected}>
{ipaIds.map((id) => (
<IPAdapter key={id} id={id} />
))}

@@ -16,9 +16,9 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
const { t } = useTranslation();
const options: { label: string; value: IPMethodV2 }[] = useMemo(
() => [
{ label: t('controlLayers.ipAdapterMethod.full'), value: 'full' },
{ label: `${t('controlLayers.ipAdapterMethod.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlLayers.ipAdapterMethod.composition')} (${t('common.beta')})`, value: 'composition' },
{ label: t('controlnet.full'), value: 'full' },
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
],
[t]
);
@@ -34,7 +34,7 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
return (
<FormControl>
<InformationalPopover feature="ipAdapterMethod">
<FormLabel>{t('controlLayers.ipAdapterMethod.ipAdapterMethod')}</FormLabel>
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} onChange={_onChange} />
</FormControl>

@@ -70,12 +70,12 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
);

return (
<Flex gap={2}>
<Flex gap={4}>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
@@ -86,7 +86,7 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
<Combobox
options={CLIP_VISION_OPTIONS}
placeholder={t('common.placeholderSelectAModel')}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
@@ -6,15 +6,14 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useIsSavingCanvas, usePullBboxIntoIPAdapter } from 'features/controlLayers/hooks/saveCanvasHooks';
import {
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
referenceImageIPAdapterImageChanged,
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterWeightChanged,
ipaBeginEndStepPctChanged,
ipaCLIPVisionModelChanged,
ipaImageChanged,
ipaMethodChanged,
ipaModelChanged,
ipaWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
@@ -30,7 +29,7 @@ import { IPAdapterModel } from './IPAdapterModel';
export const IPAdapterSettings = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const entityIdentifier = useEntityIdentifierContext('ip_adapter');
const selectIPAdapter = useMemo(
() => createSelector(selectCanvasSlice, (s) => selectEntityOrThrow(s, entityIdentifier).ipAdapter),
[entityIdentifier]
@@ -39,42 +38,42 @@ export const IPAdapterSettings = memo(() => {

const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(referenceImageIPAdapterBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
dispatch(ipaBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
},
[dispatch, entityIdentifier]
);

const onChangeWeight = useCallback(
(weight: number) => {
dispatch(referenceImageIPAdapterWeightChanged({ entityIdentifier, weight }));
dispatch(ipaWeightChanged({ entityIdentifier, weight }));
},
[dispatch, entityIdentifier]
);

const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(referenceImageIPAdapterMethodChanged({ entityIdentifier, method }));
dispatch(ipaMethodChanged({ entityIdentifier, method }));
},
[dispatch, entityIdentifier]
);

const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
dispatch(ipaModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
);

const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(referenceImageIPAdapterCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
dispatch(ipaCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
},
[dispatch, entityIdentifier]
);

const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
dispatch(ipaImageChanged({ entityIdentifier, imageDTO }));
},
[dispatch, entityIdentifier]
);
@@ -87,13 +86,13 @@ export const IPAdapterSettings = memo(() => {
() => ({ type: 'SET_IPA_IMAGE', id: entityIdentifier.id }),
[entityIdentifier.id]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
const pullBboxIntoIPAdapter = usePullBboxIntoIPAdapter(entityIdentifier);
const isSaving = useIsSavingCanvas();

return (
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
@@ -104,23 +103,24 @@ export const IPAdapterSettings = memo(() => {
</Box>
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
aria-label={t('controlLayers.pullBboxIntoIPAdapter')}
tooltip={t('controlLayers.pullBboxIntoIPAdapter')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full" alignItems="center">
<Flex flexDir="column" gap={2} w="full">
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={entityIdentifier.id}
droppableData={droppableData}
postUploadAction={postUploadAction}
/>
@@ -4,7 +4,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
import { RasterLayerMenuItemsRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsRasterToControl';
import { memo } from 'react';

export const RasterLayerMenuItems = memo(() => {
@@ -12,7 +12,7 @@ export const RasterLayerMenuItems = memo(() => {
<>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<RasterLayerMenuItemsConvertRasterToControl />
<RasterLayerMenuItemsRasterToControl />
<MenuDivider />
<CanvasEntityMenuItemsArrange />
<MenuDivider />

@@ -2,20 +2,20 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rasterLayerConvertedToControlLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';

export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
export const RasterLayerMenuItemsRasterToControl = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isInteractable = useIsEntityInteractable(entityIdentifier);
const isBusy = useCanvasIsBusy();

const onClick = useCallback(() => {
const convertRasterLayerToControlLayer = useCallback(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
@@ -27,10 +27,10 @@ export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
}, [defaultControlAdapter, dispatch, entityIdentifier]);

return (
<MenuItem onClick={onClick} icon={<PiLightningBold />} isDisabled={!isInteractable}>
<MenuItem onClick={convertRasterLayerToControlLayer} icon={<PiLightningBold />} isDisabled={isBusy}>
{t('controlLayers.convertToControlLayer')}
</MenuItem>
);
});

RasterLayerMenuItemsConvertRasterToControl.displayName = 'RasterLayerMenuItemsConvertRasterToControl';
RasterLayerMenuItemsRasterToControl.displayName = 'RasterLayerMenuItemsRasterToControl';
@@ -33,7 +33,7 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
onClick={addRegionalGuidancePositivePrompt}
isDisabled={!validActions.canAddPositivePrompt}
>
{t('controlLayers.prompt')}
{t('common.positivePrompt')}
</Button>
<Button
size="sm"
@@ -42,10 +42,10 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
onClick={addRegionalGuidanceNegativePrompt}
isDisabled={!validActions.canAddNegativePrompt}
>
{t('controlLayers.negativePrompt')}
{t('common.negativePrompt')}
</Button>
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addRegionalGuidanceIPAdapter}>
{t('controlLayers.referenceImage')}
{t('common.ipAdapter')}
</Button>
</Flex>
);

@@ -1,7 +1,7 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { PiTrashSimpleBold } from 'react-icons/pi';

type Props = {
onDelete: () => void;
@@ -14,12 +14,11 @@ export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) =>
<IconButton
variant="link"
aria-label={t('controlLayers.deletePrompt')}
icon={<PiTrashSimpleFill />}
icon={<PiTrashSimpleBold />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
/>
</Tooltip>
);

@@ -8,7 +8,7 @@ import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/cont
import { memo } from 'react';

const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.regionalGuidance.entities.map(mapId).reverse();
return canvas.regions.entities.map(mapId).reverse();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'regional_guidance';
@@ -7,8 +7,10 @@ import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapt
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
useIsSavingCanvas,
usePullBboxIntoRegionalGuidanceIPAdapter,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import {
rgIPAdapterBeginEndStepPctChanged,
rgIPAdapterCLIPVisionModelChanged,
@@ -18,114 +20,111 @@ import {
rgIPAdapterModelChanged,
rgIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectRegionalGuidanceIPAdapter } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { RGIPAdapterImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiTrashSimpleFill } from 'react-icons/pi';
import { PiBoundingBoxBold, PiTrashSimpleBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig, RGIPAdapterImagePostUploadAction } from 'services/api/types';
import { assert } from 'tsafe';

type Props = {
referenceImageId: string;
ipAdapterId: string;
ipAdapterNumber: number;
};

export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Props) => {
export const RegionalGuidanceIPAdapterSettings = memo(({ ipAdapterId, ipAdapterNumber }: Props) => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
dispatch(rgIPAdapterDeleted({ entityIdentifier, ipAdapterId }));
}, [dispatch, entityIdentifier, ipAdapterId]);
const selectIPAdapter = useMemo(
() =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectRegionalGuidanceReferenceImage(canvas, entityIdentifier, referenceImageId);
assert(referenceImage, `Regional Guidance IP Adapter with id ${referenceImageId} not found`);
return referenceImage.ipAdapter;
const ipAdapter = selectRegionalGuidanceIPAdapter(canvas, entityIdentifier, ipAdapterId);
assert(ipAdapter, `Regional GuidanceIP Adapter with id ${ipAdapterId} not found`);
return ipAdapter;
}),
[entityIdentifier, referenceImageId]
[entityIdentifier, ipAdapterId]
);
const ipAdapter = useAppSelector(selectIPAdapter);

const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(rgIPAdapterBeginEndStepPctChanged({ entityIdentifier, referenceImageId, beginEndStepPct }));
dispatch(rgIPAdapterBeginEndStepPctChanged({ entityIdentifier, ipAdapterId, beginEndStepPct }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const onChangeWeight = useCallback(
(weight: number) => {
dispatch(rgIPAdapterWeightChanged({ entityIdentifier, referenceImageId, weight }));
dispatch(rgIPAdapterWeightChanged({ entityIdentifier, ipAdapterId, weight }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(rgIPAdapterMethodChanged({ entityIdentifier, referenceImageId, method }));
dispatch(rgIPAdapterMethodChanged({ entityIdentifier, ipAdapterId, method }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
dispatch(rgIPAdapterModelChanged({ entityIdentifier, ipAdapterId, modelConfig }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(rgIPAdapterCLIPVisionModelChanged({ entityIdentifier, referenceImageId, clipVisionModel }));
dispatch(rgIPAdapterCLIPVisionModelChanged({ entityIdentifier, ipAdapterId, clipVisionModel }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(rgIPAdapterImageChanged({ entityIdentifier, referenceImageId, imageDTO }));
dispatch(rgIPAdapterImageChanged({ entityIdentifier, ipAdapterId, imageDTO }));
},
[dispatch, entityIdentifier, referenceImageId]
[dispatch, entityIdentifier, ipAdapterId]
);

const droppableData = useMemo<RGIPAdapterImageDropData>(
() => ({
actionType: 'SET_RG_IP_ADAPTER_IMAGE',
context: { id: entityIdentifier.id, referenceImageId: referenceImageId },
context: { id: entityIdentifier.id, ipAdapterId },
id: entityIdentifier.id,
}),
[entityIdentifier.id, referenceImageId]
[entityIdentifier.id, ipAdapterId]
);
const postUploadAction = useMemo<RGIPAdapterImagePostUploadAction>(
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, referenceImageId: referenceImageId }),
[entityIdentifier.id, referenceImageId]
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, ipAdapterId }),
[entityIdentifier.id, ipAdapterId]
);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
const isBusy = useCanvasIsBusy();
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceIPAdapter(entityIdentifier, ipAdapterId);
const isSaving = useIsSavingCanvas();
return (
<Flex flexDir="column" gap={2}>
<Flex alignItems="center" gap={2}>
<Text fontWeight="semibold" color="base.400">
{t('controlLayers.referenceImage')}
</Text>
<Flex flexDir="column" gap={3}>
<Flex alignItems="center" gap={3}>
<Text fontWeight="semibold" color="base.400">{`IP Adapter ${ipAdapterNumber}`}</Text>
<Spacer />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
icon={<PiTrashSimpleFill />}
tooltip={t('controlLayers.deleteReferenceImage')}
aria-label={t('controlLayers.deleteReferenceImage')}
icon={<PiTrashSimpleBold />}
aria-label="Delete IP Adapter"
onClick={onDeleteIPAdapter}
variant="ghost"
colorScheme="error"
/>
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
@@ -136,23 +135,24 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
</Box>
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
isLoading={isSaving.isTrue}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
aria-label={t('controlLayers.pullBboxIntoIPAdapter')}
tooltip={t('controlLayers.pullBboxIntoIPAdapter')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full">
<Flex flexDir="column" gap={2} w="full">
<Flex gap={4} w="full" alignItems="center">
<Flex flexDir="column" gap={3} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
droppableData={droppableData}
postUploadAction={postUploadAction}
/>

@@ -13,7 +13,7 @@ export const RegionalGuidanceIPAdapters = memo(() => {
const selectIPAdapterIds = useMemo(
() =>
createMemoizedSelector(selectCanvasSlice, (canvas) => {
const ipAdapterIds = selectEntityOrThrow(canvas, entityIdentifier).referenceImages.map(({ id }) => id);
const ipAdapterIds = selectEntityOrThrow(canvas, entityIdentifier).ipAdapters.map(({ id }) => id);
if (ipAdapterIds.length === 0) {
return EMPTY_ARRAY;
}
@@ -33,7 +33,7 @@ export const RegionalGuidanceIPAdapters = memo(() => {
{ipAdapterIds.map((ipAdapterId, index) => (
<Fragment key={ipAdapterId}>
{index > 0 && <Divider />}
<RegionalGuidanceIPAdapterSettings referenceImageId={ipAdapterId} />
<RegionalGuidanceIPAdapterSettings ipAdapterId={ipAdapterId} ipAdapterNumber={index + 1} />
</Fragment>
))}
</>

@@ -33,7 +33,7 @@ export const RegionalGuidanceMenuItemsAddPromptsAndIPAdapter = memo(() => {
{t('controlLayers.addNegativePrompt')}
</MenuItem>
<MenuItem onClick={addRegionalGuidanceIPAdapter} isDisabled={isBusy}>
{t('controlLayers.addReferenceImage')}
{t('controlLayers.addIPAdapter')}
</MenuItem>
</>
);

@@ -20,7 +20,7 @@ export const RegionalGuidanceSettings = memo(() => {
return {
hasPositivePrompt: entity.positivePrompt !== null,
hasNegativePrompt: entity.negativePrompt !== null,
hasIPAdapters: entity.referenceImages.length > 0,
hasIPAdapters: entity.ipAdapters.length > 0,
};
}),
[entityIdentifier]
@@ -1,25 +0,0 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectBboxOverlay, settingsBboxOverlayToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsBboxOverlaySwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxOverlay = useAppSelector(selectBboxOverlay);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsBboxOverlayToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel m={0} flexGrow={1}>
|
||||
{t('controlLayers.bboxOverlay')}
|
||||
</FormLabel>
|
||||
<Switch size="sm" isChecked={bboxOverlay} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsBboxOverlaySwitch.displayName = 'CanvasSettingsBboxOverlaySwitch';
|
||||
@@ -0,0 +1,33 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectCanvasSettingsSlice,
|
||||
settingsCompositeMaskedRegionsChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectCompositeMaskedRegions = createSelector(
|
||||
selectCanvasSettingsSlice,
|
||||
(canvasSettings) => canvasSettings.compositeMaskedRegions
|
||||
);
|
||||
|
||||
export const CanvasSettingsCompositeMaskedRegionsCheckbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const compositeMaskedRegions = useAppSelector(selectCompositeMaskedRegions);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(settingsCompositeMaskedRegionsChanged(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.compositeMaskedRegions')}</FormLabel>
|
||||
<Checkbox isChecked={compositeMaskedRegions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsCompositeMaskedRegionsCheckbox.displayName = 'CanvasSettingsCompositeMaskedRegionsCheckbox';
|
||||
@@ -1,25 +0,0 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectOutputOnlyMaskedRegions,
|
||||
settingsOutputOnlyMaskedRegionsToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsOutputOnlyMaskedRegionsCheckbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const outputOnlyMaskedRegions = useAppSelector(selectOutputOnlyMaskedRegions);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsOutputOnlyMaskedRegionsToggled());
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.outputOnlyMaskedRegions')}</FormLabel>
|
||||
<Checkbox isChecked={outputOnlyMaskedRegions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsOutputOnlyMaskedRegionsCheckbox.displayName = 'CanvasSettingsOutputOnlyMaskedRegionsCheckbox';
|
||||
@@ -10,29 +10,27 @@ import {
|
||||
useShiftModifier,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { CanvasSettingsAutoSaveCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsAutoSaveCheckbox';
|
||||
import { CanvasSettingsBboxOverlaySwitch } from 'features/controlLayers/components/Settings/CanvasSettingsBboxOverlaySwitch';
|
||||
import { CanvasSettingsClearCachesButton } from 'features/controlLayers/components/Settings/CanvasSettingsClearCachesButton';
|
||||
import { CanvasSettingsClearHistoryButton } from 'features/controlLayers/components/Settings/CanvasSettingsClearHistoryButton';
|
||||
import { CanvasSettingsClipToBboxCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsClipToBboxCheckbox';
|
||||
import { CanvasSettingsCompositeMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsCompositeMaskedRegionsCheckbox';
|
||||
import { CanvasSettingsDynamicGridSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsDynamicGridSwitch';
|
||||
import { CanvasSettingsSnapToGridCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsGridSize';
|
||||
import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsInvertScrollCheckbox';
|
||||
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
|
||||
import { CanvasSettingsOutputOnlyMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsOutputOnlyMaskedRegionsCheckbox';
|
||||
import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPreserveMaskCheckbox';
|
||||
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
|
||||
import { CanvasSettingsResetButton } from 'features/controlLayers/components/Settings/CanvasSettingsResetButton';
|
||||
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
|
||||
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
import { RiSettings4Fill } from 'react-icons/ri';
|
||||
|
||||
export const CanvasSettingsPopover = memo(() => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton aria-label={t('common.settingsLabel')} icon={<PiGearSixFill />} variant="ghost" />
<IconButton aria-label={t('common.settingsLabel')} icon={<RiSettings4Fill />} variant="ghost" />
</PopoverTrigger>
<PopoverContent>
<PopoverArrow />
@@ -40,14 +38,12 @@ export const CanvasSettingsPopover = memo(() => {
<Flex direction="column" gap={2}>
<CanvasSettingsAutoSaveCheckbox />
<CanvasSettingsInvertScrollCheckbox />
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsCompositeMaskedRegionsCheckbox />
<CanvasSettingsSnapToGridCheckbox />
<CanvasSettingsShowProgressOnCanvas />
<CanvasSettingsDynamicGridSwitch />
<CanvasSettingsBboxOverlaySwitch />
<CanvasSettingsShowHUDSwitch />
<CanvasSettingsResetButton />
<DebugSettings />
</Flex>
</PopoverBody>

@@ -1,20 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectPreserveMask, settingsPreserveMaskToggled } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

export const CanvasSettingsPreserveMaskCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const preserveMask = useAppSelector(selectPreserveMask);
const onChange = useCallback(() => dispatch(settingsPreserveMaskToggled()), [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.settings.preserveMask.label')}</FormLabel>
<Checkbox isChecked={preserveMask} onChange={onChange} />
</FormControl>
);
});

CanvasSettingsPreserveMaskCheckbox.displayName = 'CanvasSettingsPreserveMaskCheckbox';
@@ -1,12 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { canvasReset } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashBold } from 'react-icons/pi';

export const CanvasToolbarResetCanvasButton = memo(() => {
export const CanvasSettingsResetButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const canvasManager = useCanvasManager();
@@ -15,15 +14,10 @@ export const CanvasToolbarResetCanvasButton = memo(() => {
canvasManager.stage.fitLayersToStage();
}, [canvasManager.stage, dispatch]);
return (
<IconButton
aria-label={t('controlLayers.resetCanvas')}
tooltip={t('controlLayers.resetCanvas')}
onClick={onClick}
colorScheme="error"
icon={<PiTrashBold />}
variant="ghost"
/>
<Button onClick={onClick} colorScheme="error" size="sm">
{t('controlLayers.resetCanvas')}
</Button>
);
});

CanvasToolbarResetCanvasButton.displayName = 'CanvasToolbarResetCanvasButton';
CanvasSettingsResetButton.displayName = 'CanvasSettingsResetButton';
@@ -1,28 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectShowProgressOnCanvas,
settingsShowProgressOnCanvasToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

export const CanvasSettingsShowProgressOnCanvas = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const showProgressOnCanvas = useAppSelector(selectShowProgressOnCanvas);
const onChange = useCallback(() => {
dispatch(settingsShowProgressOnCanvasToggled());
}, [dispatch]);

return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.showProgressOnCanvas')}
</FormLabel>
<Switch size="sm" isChecked={showProgressOnCanvas} onChange={onChange} />
</FormControl>
);
});

CanvasSettingsShowProgressOnCanvas.displayName = 'CanvasSettingsShowProgressOnCanvas';
@@ -26,14 +26,7 @@ export const ToolColorPicker = memo(() => {
<Flex role="button" aria-label={t('controlLayers.fill.fillColor')} tabIndex={-1} w={8} h={8}>
<Tooltip label={t('controlLayers.fill.fillColor')}>
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Box
borderRadius="full"
borderColor="base.300"
w={6}
h={6}
borderWidth={1}
bg={rgbaColorToString(fill)}
/>
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbaColorToString(fill)} />
</Flex>
</Tooltip>
</Flex>
@@ -5,10 +5,10 @@ import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser'
import { ToolColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker';
import { ToolSettings } from 'features/controlLayers/components/Tool/ToolSettings';
import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton';
import { CanvasToolbarResetCanvasButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetCanvasButton';
import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton';
import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton';
import { CanvasToolbarScale } from 'features/controlLayers/components/Toolbar/CanvasToolbarScale';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasDeleteLayerHotkey } from 'features/controlLayers/hooks/useCanvasDeleteLayerHotkey';
import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/useCanvasEntityQuickSwitchHotkey';
import { useCanvasResetLayerHotkey } from 'features/controlLayers/hooks/useCanvasResetLayerHotkey';
@@ -24,20 +24,21 @@ export const CanvasToolbar = memo(() => {
  useNextPrevEntityHotkeys();

  return (
    <Flex w="full" gap={2} alignItems="center">
      <ToolChooser />
      <Spacer />
      <ToolSettings />
      <Spacer />
      <CanvasToolbarScale />
      <CanvasToolbarResetViewButton />
      <Spacer />
      <ToolColorPicker />
      <CanvasToolbarFitBboxToLayersButton />
      <CanvasToolbarSaveToGalleryButton />
      <CanvasToolbarResetCanvasButton />
      <CanvasSettingsPopover />
    </Flex>
    <CanvasManagerProviderGate>
      <Flex w="full" gap={2} alignItems="center">
        <ToolChooser />
        <Spacer />
        <ToolSettings />
        <Spacer />
        <CanvasToolbarScale />
        <CanvasToolbarResetViewButton />
        <Spacer />
        <ToolColorPicker />
        <CanvasToolbarFitBboxToLayersButton />
        <CanvasToolbarSaveToGalleryButton />
        <CanvasSettingsPopover />
      </Flex>
    </CanvasManagerProviderGate>
  );
});

@@ -1,6 +1,5 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsOut } from 'react-icons/pi';
@@ -8,7 +7,6 @@ import { PiArrowsOut } from 'react-icons/pi';
export const CanvasToolbarFitBboxToLayersButton = memo(() => {
  const { t } = useTranslation();
  const canvasManager = useCanvasManager();
  const isBusy = useCanvasIsBusy();
  const onClick = useCallback(() => {
    canvasManager.bbox.fitToLayers();
  }, [canvasManager.bbox]);
@@ -20,7 +18,6 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
      aria-label={t('controlLayers.fitBboxToLayers')}
      tooltip={t('controlLayers.fitBboxToLayers')}
      icon={<PiArrowsOut />}
      isDisabled={isBusy}
    />
  );
});

@@ -1,6 +1,9 @@
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
import { useSaveBboxToGallery, useSaveCanvasToGallery } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
  useIsSavingCanvas,
  useSaveBboxToGallery,
  useSaveCanvasToGallery,
} from 'features/controlLayers/hooks/saveCanvasHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
@@ -8,7 +11,7 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
export const CanvasToolbarSaveToGalleryButton = memo(() => {
  const { t } = useTranslation();
  const shift = useShiftModifier();
  const isBusy = useCanvasIsBusy();
  const isSaving = useIsSavingCanvas();
  const saveCanvasToGallery = useSaveCanvasToGallery();
  const saveBboxToGallery = useSaveBboxToGallery();

@@ -17,9 +20,9 @@ export const CanvasToolbarSaveToGalleryButton = memo(() => {
      variant="ghost"
      onClick={shift ? saveBboxToGallery : saveCanvasToGallery}
      icon={<PiFloppyDiskBold />}
      isLoading={isSaving.isTrue}
      aria-label={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
      tooltip={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
      isDisabled={isBusy}
    />
  );
});

@@ -31,7 +31,7 @@ const TransformBox = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
      leftIcon={<PiArrowsOutBold />}
      onClick={adapter.transformer.fitProxyRectToBbox}
      isLoading={isProcessing}
      loadingText={t('controlLayers.transform.fitToBbox')}
      loadingText={t('controlLayers.transform.reset')}
      variant="ghost"
    >
      {t('controlLayers.transform.fitToBbox')}

@@ -18,9 +18,9 @@ export const BeginEndStepPct = memo(({ beginEndStepPct, onChange }: Props) => {
  }, [onChange]);

  return (
    <FormControl orientation="horizontal" pe={2}>
    <FormControl orientation="horizontal" pe={1}>
      <InformationalPopover feature="controlNetBeginEnd">
        <FormLabel m={0}>{t('controlLayers.beginEndStepPercentShort')}</FormLabel>
        <FormLabel m={0}>{t('controlnet.beginEndStepPercentShort')}</FormLabel>
      </InformationalPopover>
      <CompositeRangeSlider
        aria-label={ariaLabel}

@@ -1,12 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import {
  useAddControlLayer,
  useAddGlobalReferenceImage,
  useAddInpaintMask,
  useAddIPAdapter,
  useAddRasterLayer,
  useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -18,12 +17,11 @@ type Props = {

export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
  const { t } = useTranslation();
  const isBusy = useCanvasIsBusy();
  const addInpaintMask = useAddInpaintMask();
  const addRegionalGuidance = useAddRegionalGuidance();
  const addRasterLayer = useAddRasterLayer();
  const addControlLayer = useAddControlLayer();
  const addGlobalReferenceImage = useAddGlobalReferenceImage();
  const addIPAdapter = useAddIPAdapter();

  const onClick = useCallback(() => {
    switch (type) {
@@ -39,11 +37,11 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
      case 'control_layer':
        addControlLayer();
        break;
      case 'reference_image':
        addGlobalReferenceImage();
      case 'ip_adapter':
        addIPAdapter();
        break;
    }
  }, [addControlLayer, addGlobalReferenceImage, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
  }, [addControlLayer, addIPAdapter, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);

  const label = useMemo(() => {
    switch (type) {
@@ -55,8 +53,8 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
        return t('controlLayers.addRasterLayer');
      case 'control_layer':
        return t('controlLayers.addControlLayer');
      case 'reference_image':
        return t('controlLayers.addGlobalReferenceImage');
      case 'ip_adapter':
        return t('controlLayers.addIPAdapter');
    }
  }, [type, t]);

@@ -69,7 +67,6 @@ export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
      icon={<PiPlusBold />}
      onClick={onClick}
      alignSelf="stretch"
      isDisabled={isBusy}
    />
  );
});

@@ -1,7 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -11,7 +10,6 @@ export const CanvasEntityDeleteButton = memo(() => {
  const { t } = useTranslation();
  const entityIdentifier = useEntityIdentifierContext();
  const dispatch = useAppDispatch();
  const isBusy = useCanvasIsBusy();
  const onClick = useCallback(() => {
    dispatch(entityDeleted({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);
@@ -26,7 +24,6 @@ export const CanvasEntityDeleteButton = memo(() => {
      icon={<PiTrashSimpleFill />}
      onClick={onClick}
      colorScheme="error"
      isDisabled={isBusy}
    />
  );
});
Some files were not shown because too many files have changed in this diff.