Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-16 17:48:13 -05:00)
Compare commits
129 Commits
Commit SHAs in this compare range (the author, date, and message columns are not available in this view):

5edee6997e, 9aaecf5b5c, b4a2244943, 155bf13d2b, 9f7b5f7a85, b3d16b4979, 10b2567fcb, 04feb74f81,
a7d8db8c15, b3b930a6f5, 43f108fe9f, f1f2525ed0, afd7b50343, 3583d03b70, bc954b9996, c08075946a,
df8df914e8, 33924e8491, 7e5ce1d69d, 6a24594140, 61d26cffe6, fdbc244dbe, 0eea84c90d, e079a91800,
eb20173487, 20dd0779b5, b384a92f5c, 116d32fbbe, b044f31a61, 6c3c24403b, 591f48bb95, dc6e45485c,
829820479d, 48a471bfb8, ff72315db2, 790846297a, 230b455a13, 71f0fff55b, 7f2c83b9e6, bc85bd4bd4,
38b09d73e4, 606c4ae88c, f666bac77f, c9bf7da23a, dfc65b93e9, 9ca40b4cf5, d571e71d5e, ad1e6c3fe6,
21d02911dd, 43afe0bd9a, e7a68c446d, b9c68a2e7e, 371a1b1af3, dae4591de6, 8ccb2e30ce, b8106a4613,
ce51e9582a, 00848eb631, b48430a892, f94a218561, 9b6ed40875, 26553dbb0e, 9eb695d0b4, babab17e1d,
d0a80f3347, 9b30363177, 89bde36b0c, 86a8476d97, afa0661e55, ba09c1277f, 80bf9ddb71, 1dbc98d747,
0698188ea2, 59d0ad4505, 074a5692dd, bb0741146a, 1845d9a87a, 748c393e71, 9bd17ea02f, 24f9b46fbc,
54b3aa1d01, d85733f22b, aff6ad0316, 61496fdcbc, ee8975401a, bf3260446d, f53823b45e, 5cbe89afdd,
c466d50c3d, d20b894a61, 20362448b9, 5df10cc494, da171114ea, 62919a443c, ffcec91d87, 0a96466b60,
e48cab0276, 740f6eb19f, d1bb4c2c70, e545f18a45, e8cd1bb3d8, 90a906e203, 5546110127, 73bbb12f7a,
dde54740c5, f70a8e2c1a, fdccdd52d5, 31ffd73423, 3fa1012879, c2a8fbd8d6, d6643d7263, 412e79d8e6,
f939dbdc33, 24a0ca86f5, 95c30f6a8b, ac7441e606, 9c9af312fe, 7bf5927c43, 32c7cdd856, bbd89d54b4,
ee61006a49, 0b43f5fd64, 6c61266990, 2d5afe8094, 2430137d19, 6df4ee5fc8, 371742d8f9, 5440c03767,
73d4c4d56d
@@ -38,7 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
    if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
    elif [ "$GPU_DRIVER" = "rocm" ]; then \
        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.6"; \
        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
    else \
        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
    fi &&\
@@ -12,7 +12,7 @@ MINIMUM_PYTHON_VERSION=3.10.0
MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
for candidate in python3.11 python3.10 python3 python ; do
    if ppath=`which $candidate`; then
    if ppath=`which $candidate 2>/dev/null`; then
        # when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
        # we check that this found executable can actually run
        if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
@@ -30,10 +30,11 @@ done
if [ -z "$PYTHON" ]; then
    echo "A suitable Python interpreter could not be found"
    echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
    echo "For the best user experience we suggest enlarging or maximizing this window now."
    read -p "Press any key to exit"
    exit -1
fi

echo "For the best user experience we suggest enlarging or maximizing this window now."

exec $PYTHON ./lib/main.py ${@}
read -p "Press any key to exit"
@@ -245,6 +245,9 @@ class InvokeAiInstance:

        pip = local[self.pip]

        # Uninstall xformers if it is present; the correct version of it will be reinstalled if needed
        _ = pip["uninstall", "-yqq", "xformers"] & FG

        pipeline = pip[
            "install",
            "--require-virtualenv",
@@ -407,7 +410,7 @@ def get_torch_source() -> Tuple[str | None, str | None]:
    optional_modules: str | None = None
    if OS == "Linux":
        if device == GpuType.ROCM:
            url = "https://download.pytorch.org/whl/rocm5.6"
            url = "https://download.pytorch.org/whl/rocm6.1"
        elif device == GpuType.CPU:
            url = "https://download.pytorch.org/whl/cpu"
        elif device == GpuType.CUDA:
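For context, a minimal sketch of how an index URL chosen by get_torch_source() would typically be passed along to pip. The helper function and flags below are illustrative assumptions, not InvokeAI's installer code:

```python
# Hedged sketch: feeding a torch wheel index (e.g. the rocm6.1 URL selected above) to pip.
# `install_torch` is a hypothetical helper, not part of InvokeAI.
import subprocess
import sys


def install_torch(extra_index_url: str | None) -> None:
    cmd = [sys.executable, "-m", "pip", "install", "torch", "torchvision"]
    if extra_index_url is not None:
        # e.g. extra_index_url = "https://download.pytorch.org/whl/rocm6.1"
        cmd += ["--extra-index-url", extra_index_url]
    subprocess.check_call(cmd)
```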
@@ -547,7 +547,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
            if not isinstance(single_ipa_image_fields, list):
                single_ipa_image_fields = [single_ipa_image_fields]

            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
            single_ipa_images = [
                context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
            ]
            with image_encoder_model_info as image_encoder_model:
                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
                # Get image embeddings from CLIP and ImageProjModel.
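The mode="RGB" argument matters because the CLIP vision preprocessor expects 3-channel input, so palette, grayscale, or RGBA prompt images need to be converted up front. A small illustrative sketch using plain PIL (not the InvokeAI context API; the conversion shown is presumably what get_pil(..., mode="RGB") performs internally):

```python
# Illustrative only: normalizing an arbitrary image prompt to 3-channel RGB before CLIP.
from PIL import Image

ipa_image = Image.open("prompt.png")  # hypothetical file; could be "P", "L", or "RGBA"
if ipa_image.mode != "RGB":
    ipa_image = ipa_image.convert("RGB")
assert ipa_image.mode == "RGB"
```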
@@ -133,6 +133,7 @@ class FieldDescriptions:
    clip_embed_model = "CLIP Embed loader"
    unet = "UNet (scheduler, LoRAs)"
    transformer = "Transformer"
    mmditx = "MMDiTX"
    vae = "VAE"
    cond = "Conditioning tensor"
    controlnet_model = "ControlNet model to load"
@@ -140,6 +141,7 @@ class FieldDescriptions:
    lora_model = "LoRA model to load"
    main_model = "Main model (UNet, VAE, CLIP) to load"
    flux_model = "Flux model (Transformer) to load"
    sd3_model = "SD3 model (MMDiTX) to load"
    sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
    sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
    onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -1,15 +1,19 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple

import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    FluxConditioningField,
    ImageField,
    Input,
    InputField,
    LatentsField,
@@ -17,6 +21,7 @@ from invokeai.app.invocations.fields import (
    WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -26,6 +31,8 @@ from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
    clip_timestep_schedule_fractional,
@@ -49,7 +56,7 @@ from invokeai.backend.util.devices import TorchDevice
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="3.1.0",
    version="3.2.0",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -82,6 +89,24 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_text_conditioning: FluxConditioningField | None = InputField(
        default=None,
        description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
        input=Input.Connection,
    )
    cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    cfg_scale_start_step: int = InputField(
        default=0,
        title="CFG Scale Start Step",
        description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
        + "the last step (e.g. a value of -1 refers to the final step).",
    )
    cfg_scale_end_step: int = InputField(
        default=-1,
        title="CFG Scale End Step",
        description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
        + "last step (e.g. a value of -1 refers to the final step).",
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
@@ -96,10 +121,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        default=None, input=Input.Connection, description="ControlNet models."
    )
    controlnet_vae: VAEField | None = InputField(
        default=None,
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
        description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
@@ -108,6 +138,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _load_text_conditioning(
        self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load the conditioning data.
        cond_data = context.conditioning.load(conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds
        return t5_embeddings, clip_embeddings

    def _run_diffusion(
        self,
        context: InvocationContext,
@@ -115,13 +158,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds
        pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
            context, self.positive_text_conditioning.conditioning_name, inference_dtype
        )
        neg_t5_embeddings: torch.Tensor | None = None
        neg_clip_embeddings: torch.Tensor | None = None
        if self.negative_text_conditioning is not None:
            neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
                context, self.negative_text_conditioning.conditioning_name, inference_dtype
            )

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
@@ -182,8 +227,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        b, _c, latent_h, latent_w = x.shape
        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
        pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
        pos_txt_ids = torch.zeros(
            pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
        )
        neg_txt_ids: torch.Tensor | None = None
        if neg_t5_embeddings is not None:
            neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
            neg_txt_ids = torch.zeros(
                neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
            )

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
@@ -204,6 +257,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
            noise=noise,
        )

        # Compute the IP-Adapter image prompt clip embeddings.
        # We do this before loading other models to minimize peak memory.
        # TODO(ryand): We should really do this in a separate invocation to benefit from caching.
        ip_adapter_fields = self._normalize_ip_adapter_fields()
        pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
            ip_adapter_fields, context
        )

        cfg_scale = self.prep_cfg_scale(
            cfg_scale=self.cfg_scale,
            timesteps=timesteps,
            cfg_scale_start_step=self.cfg_scale_start_step,
            cfg_scale_end_step=self.cfg_scale_end_step,
        )

        with ExitStack() as exit_stack:
            # Prepare ControlNet extensions.
            # Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -252,23 +320,88 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
            else:
                raise ValueError(f"Unsupported model format: {config.format}")

            # Prepare IP-Adapter extensions.
            pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
                pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
                neg_image_prompt_clip_embeds=neg_image_prompt_clip_embeds,
                ip_adapter_fields=ip_adapter_fields,
                context=context,
                exit_stack=exit_stack,
                dtype=inference_dtype,
            )

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                txt=pos_t5_embeddings,
                txt_ids=pos_txt_ids,
                vec=pos_clip_embeddings,
                neg_txt=neg_t5_embeddings,
                neg_txt_ids=neg_txt_ids,
                neg_vec=neg_clip_embeddings,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                cfg_scale=cfg_scale,
                inpaint_extension=inpaint_extension,
                controlnet_extensions=controlnet_extensions,
                pos_ip_adapter_extensions=pos_ip_adapter_extensions,
                neg_ip_adapter_extensions=neg_ip_adapter_extensions,
            )

        x = unpack(x.float(), self.height, self.width)
        return x

    @classmethod
    def prep_cfg_scale(
        cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
    ) -> list[float]:
        """Prepare the cfg_scale schedule.

        - Clips the cfg_scale schedule based on cfg_scale_start_step and cfg_scale_end_step.
        - If cfg_scale is a list, then it is assumed to be a schedule and is returned as-is.
        - If cfg_scale is a scalar, then a linear schedule is created from cfg_scale_start_step to cfg_scale_end_step.
        """
        # num_steps is the number of denoising steps, which is one less than the number of timesteps.
        num_steps = len(timesteps) - 1

        # Normalize cfg_scale to a list if it is a scalar.
        cfg_scale_list: list[float]
        if isinstance(cfg_scale, float):
            cfg_scale_list = [cfg_scale] * num_steps
        elif isinstance(cfg_scale, list):
            cfg_scale_list = cfg_scale
        else:
            raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")
        assert len(cfg_scale_list) == num_steps

        # Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
        start_step_index = cfg_scale_start_step
        if start_step_index < 0:
            start_step_index = num_steps + start_step_index
        end_step_index = cfg_scale_end_step
        if end_step_index < 0:
            end_step_index = num_steps + end_step_index

        # Validate the start and end step indices.
        if not (0 <= start_step_index < num_steps):
            raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
        if not (0 <= end_step_index < num_steps):
            raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
        if start_step_index > end_step_index:
            raise ValueError(
                f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
                + f"({cfg_scale_end_step})."
            )

        # Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
        # steps.
        clipped_cfg_scale = [1.0] * num_steps
        clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]

        return clipped_cfg_scale

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

@@ -408,6 +541,112 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

        return controlnet_extensions

    def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
        if self.ip_adapter is None:
            return []
        elif isinstance(self.ip_adapter, IPAdapterField):
            return [self.ip_adapter]
        elif isinstance(self.ip_adapter, list):
            return self.ip_adapter
        else:
            raise ValueError(f"Unsupported IP-Adapter type: {type(self.ip_adapter)}")

    def _prep_ip_adapter_image_prompt_clip_embeds(
        self,
        ip_adapter_fields: list[IPAdapterField],
        context: InvocationContext,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
        clip_image_processor = CLIPImageProcessor()

        pos_image_prompt_clip_embeds: list[torch.Tensor] = []
        neg_image_prompt_clip_embeds: list[torch.Tensor] = []
        for ip_adapter_field in ip_adapter_fields:
            # `ip_adapter_field.image` could be a list or a single ImageField. Normalize to a list here.
            ipa_image_fields: list[ImageField]
            if isinstance(ip_adapter_field.image, ImageField):
                ipa_image_fields = [ip_adapter_field.image]
            elif isinstance(ip_adapter_field.image, list):
                ipa_image_fields = ip_adapter_field.image
            else:
                raise ValueError(f"Unsupported IP-Adapter image type: {type(ip_adapter_field.image)}")

            if len(ipa_image_fields) != 1:
                raise ValueError(
                    f"FLUX IP-Adapter only supports a single image prompt (received {len(ipa_image_fields)})."
                )

            ipa_images = [context.images.get_pil(image.image_name, mode="RGB") for image in ipa_image_fields]

            pos_images: list[npt.NDArray[np.uint8]] = []
            neg_images: list[npt.NDArray[np.uint8]] = []
            for ipa_image in ipa_images:
                assert ipa_image.mode == "RGB"
                pos_image = np.array(ipa_image)
                # We use a black image as the negative image prompt for parity with
                # https://github.com/XLabs-AI/x-flux-comfyui/blob/45c834727dd2141aebc505ae4b01f193a8414e38/nodes.py#L592-L593
                # An alternative scheme would be to apply zeros_like() after calling the clip_image_processor.
                neg_image = np.zeros_like(pos_image)
                pos_images.append(pos_image)
                neg_images.append(neg_image)

            with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)

                clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
                clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
                pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds

                clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
                clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
                neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds

            pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
            neg_image_prompt_clip_embeds.append(neg_clip_image_embeds)

        return pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds

    def _prep_ip_adapter_extensions(
        self,
        ip_adapter_fields: list[IPAdapterField],
        pos_image_prompt_clip_embeds: list[torch.Tensor],
        neg_image_prompt_clip_embeds: list[torch.Tensor],
        context: InvocationContext,
        exit_stack: ExitStack,
        dtype: torch.dtype,
    ) -> tuple[list[XLabsIPAdapterExtension], list[XLabsIPAdapterExtension]]:
        pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
        neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
        for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
            ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
        ):
            ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
            assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
            ip_adapter_model = ip_adapter_model.to(dtype=dtype)
            if ip_adapter_field.mask is not None:
                raise ValueError("IP-Adapter masks are not yet supported in Flux.")
            ip_adapter_extension = XLabsIPAdapterExtension(
                model=ip_adapter_model,
                image_prompt_clip_embed=pos_image_prompt_clip_embed,
                weight=ip_adapter_field.weight,
                begin_step_percent=ip_adapter_field.begin_step_percent,
                end_step_percent=ip_adapter_field.end_step_percent,
            )
            ip_adapter_extension.run_image_proj(dtype=dtype)
            pos_ip_adapter_extensions.append(ip_adapter_extension)

            ip_adapter_extension = XLabsIPAdapterExtension(
                model=ip_adapter_model,
                image_prompt_clip_embed=neg_image_prompt_clip_embed,
                weight=ip_adapter_field.weight,
                begin_step_percent=ip_adapter_field.begin_step_percent,
                end_step_percent=ip_adapter_field.end_step_percent,
            )
            ip_adapter_extension.run_image_proj(dtype=dtype)
            neg_ip_adapter_extensions.append(ip_adapter_extension)

        return pos_ip_adapter_extensions, neg_ip_adapter_extensions

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
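The cfg_scale clipping added above is easiest to see with concrete numbers. The sketch below re-implements the same start/end clipping logic in isolation; it mirrors the behavior of prep_cfg_scale for a scalar cfg_scale but is not a call into InvokeAI:

```python
# Standalone sketch of the cfg_scale schedule clipping (assumed to match the semantics of
# FluxDenoiseInvocation.prep_cfg_scale when cfg_scale is a scalar).
def clip_cfg_schedule(cfg_scale: float, num_steps: int, start: int, end: int) -> list[float]:
    # Resolve negative indices relative to the last step.
    start = start if start >= 0 else num_steps + start
    end = end if end >= 0 else num_steps + end
    schedule = [1.0] * num_steps  # 1.0 effectively disables CFG for that step
    schedule[start : end + 1] = [cfg_scale] * (end - start + 1)
    return schedule


# 6 denoising steps, CFG of 3.5 applied only from step 2 through the second-to-last step:
print(clip_cfg_schedule(3.5, num_steps=6, start=2, end=-2))
# [1.0, 1.0, 3.5, 3.5, 3.5, 1.0]
```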
invokeai/app/invocations/flux_ip_adapter.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from builtins import float
from typing import List, Literal, Union

from pydantic import field_validator, model_validator
from typing_extensions import Self

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.ip_adapter import (
    CLIP_VISION_MODEL_MAP,
    IPAdapterField,
    IPAdapterInvocation,
    IPAdapterOutput,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
    IPAdapterCheckpointConfig,
    IPAdapterInvokeAIConfig,
)


@invocation(
    "flux_ip_adapter",
    title="FLUX IP-Adapter",
    tags=["ip_adapter", "control"],
    category="ip_adapter",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxIPAdapterInvocation(BaseInvocation):
    """Collects FLUX IP-Adapter info to pass to other nodes."""

    # FLUXIPAdapterInvocation is based closely on IPAdapterInvocation, but with some unsupported features removed.

    image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
    ip_adapter_model: ModelIdentifierField = InputField(
        description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
    )
    # Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
    clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")
    weight: Union[float, List[float]] = InputField(
        default=1, description="The weight given to the IP-Adapter", title="Weight"
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
    )

    @field_validator("weight")
    @classmethod
    def validate_ip_adapter_weight(cls, v: float) -> float:
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self) -> Self:
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
        # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
        ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
        assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))

        # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
        image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
        image_encoder_model_id = image_encoder_starter_model.source
        image_encoder_model_name = image_encoder_starter_model.name
        image_encoder_model = IPAdapterInvocation.get_clip_image_encoder(
            context, image_encoder_model_id, image_encoder_model_name
        )

        return IPAdapterOutput(
            ip_adapter=IPAdapterField(
                image=self.image,
                ip_adapter_model=self.ip_adapter_model,
                image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
                weight=self.weight,
                target_blocks=[],  # target_blocks is currently unused for FLUX IP-Adapters.
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                mask=None,  # mask is currently unused for FLUX IP-Adapters.
            ),
        )
invokeai/app/invocations/flux_model_loader.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from typing import Literal

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
    """Flux base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
        title="Max Seq Length",
    )


@invocation(
    "flux_model_loader",
    title="Flux Main Model",
    tags=["model", "flux"],
    category="model",
    version="1.0.4",
    classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
    """Loads a flux base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        ui_type=UIType.FluxMainModel,
        input=Input.Direct,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP Embed",
    )

    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

        transformer_config = context.models.get_config(transformer)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=max_seq_lengths[transformer_config.config_path],
        )
@@ -9,6 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Outpu
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
@@ -17,6 +18,12 @@ from invokeai.backend.model_manager.config import (
    IPAdapterInvokeAIConfig,
    ModelType,
)
from invokeai.backend.model_manager.starter_models import (
    StarterModel,
    clip_vit_l_image_encoder,
    ip_adapter_sd_image_encoder,
    ip_adapter_sdxl_image_encoder,
)


class IPAdapterField(BaseModel):
@@ -55,10 +62,14 @@ class IPAdapterOutput(BaseInvocationOutput):
    ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] = {
    "ViT-L": clip_vit_l_image_encoder,
    "ViT-H": ip_adapter_sd_image_encoder,
    "ViT-G": ip_adapter_sdxl_image_encoder,
}


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.5.0")
class IPAdapterInvocation(BaseInvocation):
    """Collects IP-Adapter info to pass to other nodes."""

@@ -70,7 +81,7 @@ class IPAdapterInvocation(BaseInvocation):
        ui_order=-1,
        ui_type=UIType.IPAdapterModel,
    )
    clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
    clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
        description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
        default="ViT-H",
        ui_order=2,
@@ -111,9 +122,11 @@ class IPAdapterInvocation(BaseInvocation):
            image_encoder_model_id = ip_adapter_info.image_encoder_model_id
            image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
        else:
            image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
            image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
            image_encoder_model_id = image_encoder_starter_model.source
            image_encoder_model_name = image_encoder_starter_model.name

        image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
        image_encoder_model = self.get_clip_image_encoder(context, image_encoder_model_id, image_encoder_model_name)

        if self.method == "style":
            if ip_adapter_info.base == "sd-1":
@@ -147,7 +160,10 @@ class IPAdapterInvocation(BaseInvocation):
            ),
        )

    def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
    @classmethod
    def get_clip_image_encoder(
        cls, context: InvocationContext, image_encoder_model_id: str, image_encoder_model_name: str
    ) -> AnyModelConfig:
        image_encoder_models = context.models.search_by_attrs(
            name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
        )
@@ -159,7 +175,11 @@ class IPAdapterInvocation(BaseInvocation):
        )

        installer = context._services.model_manager.install
        job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
        # Note: We hard-code the type to CLIPVision here because if the model contains both a CLIPVision and a
        # CLIPText model, the probe may treat it as a CLIPText model.
        job = installer.heuristic_import(
            image_encoder_model_id, ModelRecordChanges(name=image_encoder_model_name, type=ModelType.CLIPVision)
        )
        installer.wait_for_job(job, timeout=600)  # Wait for up to 10 minutes
        image_encoder_models = context.models.search_by_attrs(
            name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
@@ -5,6 +5,7 @@ from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
from invokeai.backend.image_util.util import pil_to_np


@invocation(
@@ -148,3 +149,51 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        mask_pil = Image.fromarray(mask_np, mode="L")
        image_dto = context.images.save(image=mask_pil)
        return ImageOutput.build(image_dto)


@invocation(
    "apply_tensor_mask_to_image",
    title="Apply Tensor Mask to Image",
    tags=["mask"],
    category="mask",
    version="1.0.0",
)
class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Applies a tensor mask to an image.

    The image is converted to RGBA and the mask is applied to the alpha channel."""

    mask: TensorField = InputField(description="The mask tensor to apply.")
    image: ImageField = InputField(description="The image to apply the mask to.")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, mode="RGBA")
        mask = context.tensors.load(self.mask.tensor_name)

        # Squeeze the channel dimension if it exists.
        if mask.dim() == 3:
            mask = mask.squeeze(0)

        # Ensure that the mask is binary.
        if mask.dtype != torch.bool:
            mask = mask > 0.5
        mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)

        # Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
        # image's transparency - else the transparent regions would end up as opaque black.

        # Separate the image into R, G, B, and A channels
        image_np = pil_to_np(image)
        r, g, b, a = np.split(image_np, 4, axis=-1)

        # Apply the mask to the alpha channel
        new_alpha = np.where(a.squeeze() > 0, mask_np, a.squeeze())

        # Stack the RGB channels with the modified alpha
        masked_image_np = np.dstack([r.squeeze(), g.squeeze(), b.squeeze(), new_alpha])

        # Convert back to an image (RGBA)
        masked_image = Image.fromarray(masked_image_np.astype(np.uint8), "RGBA")
        image_dto = context.images.save(image=masked_image)

        return ImageOutput.build(image_dto)
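A tiny worked example of the alpha-channel logic above: np.where keeps the alpha at 0 wherever the original pixel was fully transparent, so masked-out transparent regions stay transparent instead of becoming opaque black. The values below are made up for illustration:

```python
import numpy as np

orig_alpha = np.array([0, 255, 255, 0], dtype=np.uint8)  # pixels 0 and 3 are transparent
mask_np = np.array([255, 255, 0, 0], dtype=np.uint8)     # mask keeps pixels 0 and 1
new_alpha = np.where(orig_alpha > 0, mask_np, orig_alpha)
print(new_alpha)  # [  0 255   0   0] -> transparency preserved, pixel 2 masked out
```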
@@ -40,7 +40,7 @@ class IPAdapterMetadataField(BaseModel):

    image: ImageField = Field(description="The IP-Adapter image prompt.")
    ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
    clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
    clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
    method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
    weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
    begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional

from pydantic import BaseModel, Field

@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    CheckpointConfigBase,
    ModelType,
    SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
        return ModelIdentifierOutput(model=self.model)


[The FluxModelLoaderOutput and FluxModelLoaderInvocation definitions removed from this file in this hunk are identical, line for line, to the definitions added in invokeai/app/invocations/flux_model_loader.py above.]


@invocation(
    "main_model_loader",
    title="Main Model",
invokeai/app/invocations/sd3_model_loader.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType


@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
    """SD3 base model loader output."""

    mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
    clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
    clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "sd3_model_loader",
    title="SD3 Main Model",
    tags=["model", "sd3"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
    """Loads a SD3 base model, outputting its submodels."""

    # TODO(ryand): Create a UIType.Sd3MainModelField to use here.
    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.sd3_model,
        ui_type=UIType.MainModel,
        input=Input.Direct,
    )

    # TODO(ryand): Make the text encoders optional.
    # Note: The text encoders are optional for SD3. The model was trained with dropout, so any can be left out at
    # inference time. Typically, only the T5 encoder is omitted, since it is the largest by far.
    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_l_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP L Embed",
    )

    clip_g_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP G Embed",
    )

    # TODO(ryand): Create a UIType.Sd3VaModelField to use here.
    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
        for key in [
            self.model.key,
            self.t5_encoder_model.key,
            self.clip_l_embed_model.key,
            self.clip_g_embed_model.key,
            self.vae_model.key,
        ]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        # TODO(ryand): Figure out the sub-model types for SD3.
        mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer_t5 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

        transformer_config = context.models.get_config(mmditx)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return Sd3ModelLoaderOutput(
            mmditx=TransformerField(transformer=mmditx, loras=[]),
            clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
            clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
        )
invokeai/app/invocations/sd3_text_encoder.py (new empty file, 0 lines)
@@ -1,9 +1,11 @@
from enum import Enum
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -23,12 +25,31 @@ SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
}


class SAMPointLabel(Enum):
    negative = -1
    neutral = 0
    positive = 1


class SAMPoint(BaseModel):
    x: int = Field(..., description="The x-coordinate of the point")
    y: int = Field(..., description="The y-coordinate of the point")
    label: SAMPointLabel = Field(..., description="The label of the point")


class SAMPointsField(BaseModel):
    points: list[SAMPoint] = Field(..., description="The points of the object")

    def to_list(self) -> list[list[int]]:
        return [[point.x, point.y, point.label.value] for point in self.points]


@invocation(
    "segment_anything",
    title="Segment Anything",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
    version="1.1.0",
)
class SegmentAnythingInvocation(BaseInvocation):
    """Runs a Segment Anything Model."""
@@ -40,7 +61,13 @@ class SegmentAnythingInvocation(BaseInvocation):

    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
    image: ImageField = InputField(description="The image to segment.")
    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
    bounding_boxes: list[BoundingBoxField] | None = InputField(
        default=None, description="The bounding boxes to prompt the SAM model with."
    )
    point_lists: list[SAMPointsField] | None = InputField(
        default=None,
        description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
    )
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
        default=True,
@@ -50,12 +77,22 @@ class SegmentAnythingInvocation(BaseInvocation):
        default="all",
    )

    @model_validator(mode="after")
    def check_point_lists_or_bounding_box(self):
        if self.point_lists is None and self.bounding_boxes is None:
            raise ValueError("Either point_lists or bounding_box must be provided.")
        elif self.point_lists is not None and self.bounding_boxes is not None:
            raise ValueError("Only one of point_lists or bounding_box can be provided.")
        return self

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        if len(self.bounding_boxes) == 0:
        if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
            not self.point_lists or len(self.point_lists) == 0
        ):
            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
        else:
            masks = self._segment(context=context, image=image_pil)
@@ -83,14 +120,13 @@ class SegmentAnythingInvocation(BaseInvocation):
        assert isinstance(sam_processor, SamProcessor)
        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

    def _segment(
        self,
        context: InvocationContext,
        image: Image.Image,
    ) -> list[torch.Tensor]:
    def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
        # Convert the bounding boxes to the SAM input format.
        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
        sam_bounding_boxes = (
            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
        )
        sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None

        with (
            context.models.load_remote_model(
@@ -98,7 +134,7 @@ class SegmentAnythingInvocation(BaseInvocation):
            ) as sam_pipeline,
        ):
            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)

            masks = self._process_masks(masks)
            if self.apply_polygon_refinement:
@@ -141,9 +177,10 @@ class SegmentAnythingInvocation(BaseInvocation):

        return masks

    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
    def _filter_masks(
        self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField] | None
    ) -> list[torch.Tensor]:
        """Filter the detected masks based on the specified mask filter."""
        assert len(masks) == len(bounding_boxes)

        if self.mask_filter == "all":
            return masks
@@ -151,6 +188,10 @@ class SegmentAnythingInvocation(BaseInvocation):
            # Find the largest mask.
            return [max(masks, key=lambda x: float(x.sum()))]
        elif self.mask_filter == "highest_box_score":
            assert (
                bounding_boxes is not None
            ), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
            assert len(masks) == len(bounding_boxes)
            # Find the index of the bounding box with the highest score.
            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
@@ -110,15 +110,26 @@ class DiskImageFileStorage(ImageFileStorageBase):
        except Exception as e:
            raise ImageFileDeleteException from e

    # TODO: make this a bit more flexible for e.g. cloud storage
    def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
        path = self.__output_folder / image_name
        base_folder = self.__thumbnails_folder if thumbnail else self.__output_folder
        filename = get_thumbnail_name(image_name) if thumbnail else image_name

        if thumbnail:
            thumbnail_name = get_thumbnail_name(image_name)
            path = self.__thumbnails_folder / thumbnail_name
        # Strip any path information from the filename
        basename = Path(filename).name

        return path
        if basename != filename:
            raise ValueError("Invalid image name, potential directory traversal detected")

        image_path = base_folder / basename

        # Ensure the image path is within the base folder to prevent directory traversal
        resolved_base = base_folder.resolve()
        resolved_image_path = image_path.resolve()

        if not resolved_image_path.is_relative_to(resolved_base):
            raise ValueError("Image path outside outputs folder, potential directory traversal detected")

        return resolved_image_path

    def validate_path(self, path: Union[str, Path]) -> bool:
        """Validates the path given for an image or thumbnail."""
||||
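The get_path change above defends against directory traversal in two stages: it strips any directory components from the requested name, then resolves the final path and verifies it is still inside the outputs (or thumbnails) folder. A minimal standalone sketch of the same pattern, using a hypothetical safe_output_path helper rather than the class method itself:

from pathlib import Path

def safe_output_path(base_folder: Path, image_name: str) -> Path:
    # Strip any path information from the requested name.
    basename = Path(image_name).name
    if basename != image_name:
        raise ValueError("Invalid image name, potential directory traversal detected")

    # Resolve and confirm the result is still inside the base folder (also catches symlink escapes).
    candidate = (base_folder / basename).resolve()
    if not candidate.is_relative_to(base_folder.resolve()):
        raise ValueError("Image path outside outputs folder, potential directory traversal detected")
    return candidate

# e.g. safe_output_path(Path("outputs"), "../../etc/passwd") raises ValueError.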
invokeai/backend/flux/custom_block_processor.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.math import attention
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
|
||||
|
||||
class CustomDoubleStreamBlockProcessor:
|
||||
"""A class containing a custom implementation of DoubleStreamBlock.forward() with additional features
|
||||
(IP-Adapter, etc.).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _double_stream_block_forward(
|
||||
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
|
||||
values.
|
||||
"""
|
||||
img_mod1, img_mod2 = block.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = block.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = block.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = block.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
|
||||
img_q, img_k = block.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = block.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = block.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
|
||||
txt_q, txt_k = block.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
img = img + img_mod1.gate * block.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * block.img_mlp((1 + img_mod2.scale) * block.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt blocks
|
||||
txt = txt + txt_mod1.gate * block.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * block.txt_mlp((1 + txt_mod2.scale) * block.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt, img_q
|
||||
|
||||
@staticmethod
|
||||
def custom_double_block_forward(
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
block_index: int,
|
||||
block: DoubleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
|
||||
- IP-Adapter support
|
||||
"""
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
for ip_adapter_extension in ip_adapter_extensions:
|
||||
img = ip_adapter_extension.run_ip_adapter(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
block_index=block_index,
|
||||
block=block,
|
||||
img_q=img_q,
|
||||
img=img,
|
||||
)
|
||||
|
||||
return img, txt
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@@ -7,6 +8,7 @@ from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFl
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
@@ -16,15 +18,23 @@ def denoise(
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
# positive text conditioning
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# negative text conditioning
|
||||
neg_txt: torch.Tensor | None,
|
||||
neg_txt_ids: torch.Tensor | None,
|
||||
neg_vec: torch.Tensor | None,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
cfg_scale: list[float],
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -37,10 +47,9 @@ def denoise(
|
||||
latents=img,
|
||||
),
|
||||
)
|
||||
step = 1
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
# Run ControlNet models.
|
||||
@@ -48,7 +57,7 @@ def denoise(
|
||||
for controlnet_extension in controlnet_extensions:
|
||||
controlnet_residuals.append(
|
||||
controlnet_extension.run_controlnet(
|
||||
timestep_index=step - 1,
|
||||
timestep_index=step_index,
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -61,7 +70,7 @@ def denoise(
|
||||
)
|
||||
|
||||
# Merge the ControlNet residuals from multiple ControlNets.
|
||||
# TODO(ryand): We may want to alculate the sum just-in-time to keep peak memory low. Keep in mind, that the
|
||||
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
|
||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
@@ -74,10 +83,39 @@ def denoise(
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
total_num_timesteps=total_steps,
|
||||
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
)
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
|
||||
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
|
||||
if not math.isclose(step_cfg_scale, 1.0):
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
|
||||
# on systems with sufficient VRAM.
|
||||
|
||||
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
|
||||
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
|
||||
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
total_num_timesteps=total_steps,
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
)
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
preview_img = img - t_curr * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
@@ -87,13 +125,12 @@ def denoise(
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step,
|
||||
step=step_index + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(t_curr),
|
||||
latents=preview_img,
|
||||
),
|
||||
)
|
||||
step += 1
|
||||
|
||||
return img
|
||||
|
||||
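The cfg_scale handling above is the standard classifier-free guidance blend, applied with a per-step scale looked up from the list: when the scale is 1.0 the negative pass is skipped entirely, otherwise the positive prediction is pushed away from the negative one. A minimal sketch with hypothetical tensors (not the model call itself):

import math
import torch

def apply_cfg(pos_pred: torch.Tensor, neg_pred: torch.Tensor, scale: float) -> torch.Tensor:
    # scale == 1.0 means "positive prediction only"; the extra model call can be skipped in that case.
    if math.isclose(scale, 1.0):
        return pos_pred
    return neg_pred + scale * (pos_pred - neg_pred)

pos = torch.randn(1, 4)  # hypothetical positive-conditioned prediction
neg = torch.randn(1, 4)  # hypothetical negative-conditioned prediction
print(apply_cfg(pos, neg, 3.5))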
@@ -0,0 +1,89 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
|
||||
|
||||
class XLabsIPAdapterExtension:
|
||||
def __init__(
|
||||
self,
|
||||
model: XlabsIpAdapterFlux,
|
||||
image_prompt_clip_embed: torch.Tensor,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
self._model = model
|
||||
self._image_prompt_clip_embed = image_prompt_clip_embed
|
||||
self._weight = weight
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
|
||||
self._image_proj: torch.Tensor | None = None
|
||||
|
||||
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
|
||||
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
|
||||
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
|
||||
|
||||
if timestep_index < first_step or timestep_index > last_step:
|
||||
return 0.0
|
||||
|
||||
if isinstance(self._weight, list):
|
||||
return self._weight[timestep_index]
|
||||
|
||||
return self._weight
|
||||
|
||||
@staticmethod
|
||||
def run_clip_image_encoder(
|
||||
pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection
|
||||
) -> torch.Tensor:
|
||||
clip_image_processor = CLIPImageProcessor()
|
||||
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
|
||||
clip_image_embeds = image_encoder(clip_image).image_embeds
|
||||
return clip_image_embeds
|
||||
|
||||
def run_image_proj(self, dtype: torch.dtype):
|
||||
image_prompt_clip_embed = self._image_prompt_clip_embed.to(dtype=dtype)
|
||||
self._image_proj = self._model.image_proj(image_prompt_clip_embed)
|
||||
|
||||
def run_ip_adapter(
|
||||
self,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
block_index: int,
|
||||
block: DoubleStreamBlock,
|
||||
img_q: torch.Tensor,
|
||||
img: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""The logic in this function is based on:
|
||||
https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/modules/layers.py#L245-L301
|
||||
"""
|
||||
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
|
||||
if weight < 1e-6:
|
||||
return img
|
||||
|
||||
ip_adapter_block = self._model.ip_adapter_double_blocks.double_blocks[block_index]
|
||||
|
||||
ip_key = ip_adapter_block.ip_adapter_double_stream_k_proj(self._image_proj)
|
||||
ip_value = ip_adapter_block.ip_adapter_double_stream_v_proj(self._image_proj)
|
||||
|
||||
# Reshape projections for multi-head attention.
|
||||
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=block.num_heads)
|
||||
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=block.num_heads)
|
||||
|
||||
# Compute attention between IP projections and the latent query.
|
||||
ip_attn = torch.nn.functional.scaled_dot_product_attention(
|
||||
img_q, ip_key, ip_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_attn = einops.rearrange(ip_attn, "B H L D -> B L (H D)", H=block.num_heads)
|
||||
|
||||
img = img + weight * ip_attn
|
||||
|
||||
return img
|
||||
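_get_weight above confines the IP-Adapter's influence to a window of denoising steps: begin_step_percent and end_step_percent are mapped to step indices with floor/ceil, steps outside the window get a weight of 0.0, and a list-valued weight is indexed per step. A small sketch of the same mapping with hypothetical step counts:

import math

def ip_adapter_weight(weight, timestep_index: int, total_num_timesteps: int,
                      begin_step_percent: float, end_step_percent: float) -> float:
    first_step = math.floor(begin_step_percent * total_num_timesteps)
    last_step = math.ceil(end_step_percent * total_num_timesteps)
    if timestep_index < first_step or timestep_index > last_step:
        return 0.0
    # A list of weights is indexed per step; a scalar applies to every active step.
    return weight[timestep_index] if isinstance(weight, list) else weight

# With 30 steps and a 0.2-0.8 window, steps 6..24 are active (hypothetical values).
print([ip_adapter_weight(0.6, i, 30, 0.2, 0.8) for i in (0, 6, 24, 29)])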
invokeai/backend/flux/ip_adapter/__init__.py (new file, 0 lines)
@@ -0,0 +1,93 @@
|
||||
# This file is based on:
|
||||
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/modules/layers.py#L221
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.math import attention
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
|
||||
|
||||
class IPDoubleStreamBlockProcessor(torch.nn.Module):
|
||||
"""Attention processor for handling IP-adapter with double stream block."""
|
||||
|
||||
def __init__(self, context_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
|
||||
# Ensure context_dim matches the dimension of image_proj
|
||||
self.context_dim = context_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
# Initialize projections for IP-adapter
|
||||
self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
|
||||
self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
|
||||
|
||||
torch.nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
|
||||
torch.nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
|
||||
|
||||
torch.nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
|
||||
torch.nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: DoubleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
image_proj: torch.Tensor,
|
||||
ip_scale: float = 1.0,
|
||||
):
|
||||
# Prepare image for attention
|
||||
img_mod1, img_mod2 = attn.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
||||
|
||||
img_modulated = attn.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = attn.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = einops.rearrange(
|
||||
img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
|
||||
)
|
||||
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
txt_modulated = attn.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = einops.rearrange(
|
||||
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
|
||||
)
|
||||
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn1 = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
||||
|
||||
# print(f"txt_attn shape: {txt_attn.size()}")
|
||||
# print(f"img_attn shape: {img_attn.size()}")
|
||||
|
||||
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
# IP-adapter processing
|
||||
ip_query = img_q # latent sample query
|
||||
ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
|
||||
ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
|
||||
|
||||
# Reshape projections for multi-head attention
|
||||
ip_key = einops.rearrange(ip_key, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
|
||||
ip_value = einops.rearrange(ip_value, "B L (H D) -> B H L D", H=attn.num_heads, D=attn.head_dim)
|
||||
|
||||
# Compute attention between IP projections and the latent query
|
||||
ip_attention = torch.nn.functional.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_attention = einops.rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
|
||||
|
||||
img = img + ip_scale * ip_attention
|
||||
|
||||
return img, txt
|
||||
invokeai/backend/flux/ip_adapter/state_dict_utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
|
||||
|
||||
|
||||
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an XLabs FLUX IP-Adapter model?
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
"""
|
||||
# If all of the expected keys are present, then this is very likely an XLabs IP-Adapter model.
|
||||
expected_keys = {
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias",
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight",
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias",
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight",
|
||||
"ip_adapter_proj_model.norm.bias",
|
||||
"ip_adapter_proj_model.norm.weight",
|
||||
"ip_adapter_proj_model.proj.bias",
|
||||
"ip_adapter_proj_model.proj.weight",
|
||||
}
|
||||
|
||||
if expected_keys.issubset(sd.keys()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
|
||||
num_double_blocks = 0
|
||||
context_dim = 0
|
||||
hidden_dim = 0
|
||||
|
||||
# Count the number of double blocks.
|
||||
double_block_index = 0
|
||||
while f"double_blocks.{double_block_index}.processor.ip_adapter_double_stream_k_proj.weight" in state_dict:
|
||||
double_block_index += 1
|
||||
num_double_blocks = double_block_index
|
||||
|
||||
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
|
||||
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
|
||||
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
|
||||
|
||||
return XlabsIpAdapterParams(
|
||||
num_double_blocks=num_double_blocks,
|
||||
context_dim=context_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
clip_embeddings_dim=clip_embeddings_dim,
|
||||
)
|
||||
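The two helpers above are meant to be used together when probing a checkpoint: first detect the XLabs layout by its characteristic keys, then derive the block count and projection dimensions from tensor shapes. A hedged usage sketch; the file path is hypothetical, and load_file is the standard safetensors loader already used elsewhere in this diff:

from pathlib import Path

from safetensors.torch import load_file

from invokeai.backend.flux.ip_adapter.state_dict_utils import (
    infer_xlabs_ip_adapter_params_from_state_dict,
    is_state_dict_xlabs_ip_adapter,
)

sd = load_file(Path("flux-ip-adapter.safetensors"))  # hypothetical local path
if is_state_dict_xlabs_ip_adapter(sd):
    params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
    print(params.num_double_blocks, params.context_dim, params.hidden_dim, params.clip_embeddings_dim)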
invokeai/backend/flux/ip_adapter/xlabs_ip_adapter_flux.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import ImageProjModel
|
||||
|
||||
|
||||
class IPDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, context_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
|
||||
self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
|
||||
|
||||
|
||||
class IPAdapterDoubleBlocks(torch.nn.Module):
|
||||
def __init__(self, num_double_blocks: int, context_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.double_blocks = torch.nn.ModuleList(
|
||||
[IPDoubleStreamBlock(context_dim, hidden_dim) for _ in range(num_double_blocks)]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XlabsIpAdapterParams:
|
||||
num_double_blocks: int
|
||||
context_dim: int
|
||||
hidden_dim: int
|
||||
|
||||
clip_embeddings_dim: int
|
||||
|
||||
|
||||
class XlabsIpAdapterFlux(torch.nn.Module):
|
||||
def __init__(self, params: XlabsIpAdapterParams):
|
||||
super().__init__()
|
||||
self.image_proj = ImageProjModel(
|
||||
cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
|
||||
)
|
||||
self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
|
||||
num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim
|
||||
)
|
||||
|
||||
def load_xlabs_state_dict(self, state_dict: dict[str, torch.Tensor], assign: bool = False):
|
||||
"""We need this custom function to load state dicts rather than using .load_state_dict(...) because the model
|
||||
structure does not match the state_dict structure.
|
||||
"""
|
||||
# Split the state_dict into the image projection model and the double blocks.
|
||||
image_proj_sd: dict[str, torch.Tensor] = {}
|
||||
double_blocks_sd: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("ip_adapter_proj_model."):
|
||||
image_proj_sd[k] = v
|
||||
elif k.startswith("double_blocks."):
|
||||
double_blocks_sd[k] = v
|
||||
else:
|
||||
raise ValueError(f"Unexpected key: {k}")
|
||||
|
||||
# Initialize the image projection model.
|
||||
image_proj_sd = {k.replace("ip_adapter_proj_model.", ""): v for k, v in image_proj_sd.items()}
|
||||
self.image_proj.load_state_dict(image_proj_sd, assign=assign)
|
||||
|
||||
# Initialize the double blocks.
|
||||
double_blocks_sd = {k.replace("processor.", ""): v for k, v in double_blocks_sd.items()}
|
||||
self.ip_adapter_double_blocks.load_state_dict(double_blocks_sd, assign=assign)
|
||||
@@ -5,6 +5,8 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
@@ -88,8 +90,11 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
controlnet_double_block_residuals: list[Tensor] | None,
|
||||
controlnet_single_block_residuals: list[Tensor] | None,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -111,7 +116,19 @@ class Flux(nn.Module):
|
||||
if controlnet_double_block_residuals is not None:
|
||||
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
assert isinstance(block, DoubleStreamBlock)
|
||||
|
||||
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
block_index=block_index,
|
||||
block=block,
|
||||
img=img,
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
ip_adapter_extensions=ip_adapter_extensions,
|
||||
)
|
||||
|
||||
if controlnet_double_block_residuals is not None:
|
||||
img += controlnet_double_block_residuals[block_index]
|
||||
|
||||
@@ -168,8 +168,17 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
    Returns:
        torch.Tensor: Image position ids.
    """

    if device.type == "mps":
        orig_dtype = dtype
        dtype = torch.float16

    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)

    if device.type == "mps":
        img_ids = img_ids.to(orig_dtype)

    return img_ids
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, TypeAlias
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -7,6 +7,14 @@ from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
# Type aliases for the inputs to the SAM model.
|
||||
ListOfBoundingBoxes: TypeAlias = list[list[int]]
|
||||
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
|
||||
ListOfPoints: TypeAlias = list[list[int]]
|
||||
"""A list of points. Each point is in the format [x, y]."""
|
||||
ListOfPointLabels: TypeAlias = list[int]
|
||||
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
|
||||
|
||||
|
||||
class SegmentAnythingPipeline(RawModel):
|
||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
||||
@@ -27,20 +35,53 @@ class SegmentAnythingPipeline(RawModel):
|
||||
|
||||
return calc_module_size(self._sam_model)
|
||||
|
||||
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
|
||||
def segment(
|
||||
self,
|
||||
image: Image.Image,
|
||||
bounding_boxes: list[list[int]] | None = None,
|
||||
point_lists: list[list[list[int]]] | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run the SAM model.
|
||||
|
||||
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
|
||||
point_lists will be ignored.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to segment.
|
||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
||||
[xmin, ymin, xmax, ymax].
|
||||
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
|
||||
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
||||
"""
|
||||
# Add batch dimension of 1 to the bounding boxes.
|
||||
boxes = [bounding_boxes]
|
||||
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
|
||||
|
||||
# Prep the inputs:
|
||||
# - Create a list of bounding boxes or points and labels.
|
||||
# - Add a batch dimension of 1 to the inputs.
|
||||
if bounding_boxes:
|
||||
input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
|
||||
input_points: list[ListOfPoints] | None = None
|
||||
input_labels: list[ListOfPointLabels] | None = None
|
||||
elif point_lists:
|
||||
input_boxes: list[ListOfBoundingBoxes] | None = None
|
||||
input_points: list[ListOfPoints] | None = []
|
||||
input_labels: list[ListOfPointLabels] | None = []
|
||||
for point_list in point_lists:
|
||||
input_points.append([[p[0], p[1]] for p in point_list])
|
||||
input_labels.append([p[2] for p in point_list])
|
||||
|
||||
else:
|
||||
raise ValueError("Either bounding_boxes or points and labels must be provided.")
|
||||
|
||||
inputs = self._sam_processor(
|
||||
images=image,
|
||||
input_boxes=input_boxes,
|
||||
input_points=input_points,
|
||||
input_labels=input_labels,
|
||||
return_tensors="pt",
|
||||
).to(self._sam_model.device)
|
||||
outputs = self._sam_model(**inputs)
|
||||
masks = self._sam_processor.post_process_masks(
|
||||
masks=outputs.pred_masks,
|
||||
|
||||
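The extended segment signature above accepts point prompts as [x, y, label] triples and splits them into coordinate and label lists before handing them to the SamProcessor. A hedged sketch of how a caller shapes those prompts; pipeline is assumed to be an already-loaded SegmentAnythingPipeline, and the image path and coordinates are hypothetical:

from PIL import Image

image = Image.open("input.png").convert("RGB")  # hypothetical image path

# Two prompt groups: label 1 = foreground, 0 = neutral, -1 = background.
point_lists = [
    [[120, 200, 1], [130, 210, 1]],
    [[400, 80, -1]],
]

# Either bounding_boxes or point_lists may be passed; bounding boxes win if both are given.
masks = pipeline.segment(image=image, point_lists=point_lists)  # pipeline: SegmentAnythingPipeline (assumed loaded)
print(masks.shape, masks.dtype)  # [num_masks, channels, height, width], torch.bool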
@@ -53,6 +53,8 @@ class BaseModelType(str, Enum):
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
# TODO(ryand): Should this just be StableDiffusion3?
|
||||
StableDiffusion35 = "sd-3.5"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
@@ -394,6 +396,8 @@ class IPAdapterBaseConfig(ModelConfigBase):
|
||||
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
|
||||
"""Model config for IP Adapter diffusers format models."""
|
||||
|
||||
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
|
||||
# time. Need to go through the history to make sure I'm understanding this fully.
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
DiffusersConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
class ClipVisionLoader(ModelLoader):
|
||||
"""Class to load CLIPVision models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, DiffusersConfigBase):
|
||||
raise ValueError("Only DiffusersConfigBase models are currently supported here.")
|
||||
|
||||
if submodel_type is not None:
|
||||
raise Exception("There are no submodels in CLIP Vision models.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(
|
||||
model_path, torch_dtype=self._torch_dtype, local_files_only=True
|
||||
)
|
||||
assert isinstance(model, CLIPVisionModelWithProjection)
|
||||
|
||||
return model
|
||||
@@ -19,6 +19,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import infer_xlabs_ip_adapter_params_from_state_dict
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
|
||||
XlabsIpAdapterFlux,
|
||||
)
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
@@ -35,6 +39,7 @@ from invokeai.backend.model_manager.config import (
|
||||
CLIPEmbedDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
IPAdapterCheckpointConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
MainCheckpointConfig,
|
||||
MainGGUFCheckpointConfig,
|
||||
@@ -170,7 +175,7 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
@@ -352,3 +357,26 @@ class FluxControlnetModel(ModelLoader):
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
|
||||
class FluxIpAdapterModel(ModelLoader):
|
||||
"""Class to load FLUX IP-Adapter models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, IPAdapterCheckpointConfig):
|
||||
raise ValueError(f"Unexpected model config type: {type(config)}.")
|
||||
|
||||
sd = load_file(Path(config.path))
|
||||
|
||||
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = XlabsIpAdapterFlux(params=params)
|
||||
|
||||
model.load_xlabs_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
@@ -22,7 +22,6 @@ from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
invokeai/backend/model_manager/load/model_loaders/sd3.py (new file, 55 lines)
@@ -0,0 +1,55 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion35, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class FluxCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
# model = Flux(params[config.config_path])
|
||||
# sd = load_file(model_path)
|
||||
# if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
# sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
# new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
|
||||
# self._ram_cache.make_room(new_sd_size)
|
||||
# for k in sd.keys():
|
||||
# # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
|
||||
# sd[k] = sd[k].to(torch.bfloat16)
|
||||
# model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
@@ -36,6 +37,7 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.sd3.sd3_state_dict_utils import is_sd3_checkpoint
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -119,6 +121,7 @@ class ModelProbe(object):
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
}
|
||||
@@ -240,11 +243,14 @@ class ModelProbe(object):
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
# The following prefixes appear when multiple models have been bundled together in a single file (I
|
||||
# believe the format originated in ComfyUI).
|
||||
# first_stage_model = VAE
|
||||
# cond_stage_model = Text Encoder
|
||||
# model.diffusion_model = UNet / Transformer
|
||||
"cond_stage_model.",
|
||||
"first_stage_model.",
|
||||
"model.diffusion_model.",
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
|
||||
"double_blocks.",
|
||||
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
|
||||
# This prefix is typically used to distinguish between multiple models bundled in a single file.
|
||||
"model.diffusion_model.double_blocks.",
|
||||
@@ -252,6 +258,10 @@ class ModelProbe(object):
|
||||
):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
return ModelType.Main
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
|
||||
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
|
||||
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
@@ -274,7 +284,14 @@ class ModelProbe(object):
|
||||
)
|
||||
):
|
||||
return ModelType.ControlNet
|
||||
elif key.startswith(("image_proj.", "ip_adapter.")):
|
||||
elif key.startswith(
|
||||
(
|
||||
"image_proj.",
|
||||
"ip_adapter.",
|
||||
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
|
||||
"ip_adapter_proj_model.",
|
||||
)
|
||||
):
|
||||
return ModelType.IPAdapter
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
@@ -387,6 +404,9 @@ class ModelProbe(object):
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
elif base_type == BaseModelType.StableDiffusion35:
|
||||
# TODO(ryand): Think about what to do here.
|
||||
config_file = "sd3.5-large"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
@@ -506,7 +526,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
if model_type != ModelType.Main or base_type in (BaseModelType.Flux, BaseModelType.StableDiffusion35):
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
@@ -531,6 +551,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
|
||||
if is_sd3_checkpoint(state_dict):
|
||||
return BaseModelType.StableDiffusion35
|
||||
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
@@ -672,6 +696,10 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
|
||||
if is_state_dict_xlabs_ip_adapter(checkpoint):
|
||||
return BaseModelType.Flux
|
||||
|
||||
for key in checkpoint.keys():
|
||||
if not key.startswith(("image_proj.", "ip_adapter.")):
|
||||
continue
|
||||
|
||||
@@ -25,22 +25,6 @@ class StarterModelBundles(BaseModel):
|
||||
models: list[StarterModel]
|
||||
|
||||
|
||||
ip_adapter_sd_image_encoder = StarterModel(
|
||||
name="IP Adapter SD1.5 Image Encoder",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_sd_image_encoder",
|
||||
description="IP Adapter SD Image Encoder",
|
||||
type=ModelType.CLIPVision,
|
||||
)
|
||||
|
||||
ip_adapter_sdxl_image_encoder = StarterModel(
|
||||
name="IP Adapter SDXL Image Encoder",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/ip_adapter_sdxl_image_encoder",
|
||||
description="IP Adapter SDXL Image Encoder",
|
||||
type=ModelType.CLIPVision,
|
||||
)
|
||||
|
||||
cyberrealistic_negative = StarterModel(
|
||||
name="CyberRealistic Negative v3",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -49,6 +33,32 @@ cyberrealistic_negative = StarterModel(
|
||||
type=ModelType.TextualInversion,
|
||||
)
|
||||
|
||||
# region CLIP Image Encoders
|
||||
ip_adapter_sd_image_encoder = StarterModel(
|
||||
name="IP Adapter SD1.5 Image Encoder",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_sd_image_encoder",
|
||||
description="IP Adapter SD Image Encoder",
|
||||
type=ModelType.CLIPVision,
|
||||
)
|
||||
ip_adapter_sdxl_image_encoder = StarterModel(
|
||||
name="IP Adapter SDXL Image Encoder",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/ip_adapter_sdxl_image_encoder",
|
||||
description="IP Adapter SDXL Image Encoder",
|
||||
type=ModelType.CLIPVision,
|
||||
)
|
||||
# Note: This model is installed from the same source as the CLIPEmbed model below. The model contains both the image
|
||||
# encoder and the text encoder, but we need separate model entries so that they get loaded correctly.
|
||||
clip_vit_l_image_encoder = StarterModel(
|
||||
name="clip-vit-large-patch14",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/clip-vit-large-patch14",
|
||||
description="CLIP ViT-L Image Encoder",
|
||||
type=ModelType.CLIPVision,
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region TextEncoders
|
||||
t5_base_encoder = StarterModel(
|
||||
name="t5_base_encoder",
|
||||
@@ -186,6 +196,16 @@ dreamshaper_sdxl = StarterModel(
|
||||
type=ModelType.Main,
|
||||
dependencies=[sdxl_fp16_vae_fix],
|
||||
)
|
||||
|
||||
archvis_sdxl = StarterModel(
|
||||
name="Architecture (RealVisXL5)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SG161222/RealVisXL_V5.0",
|
||||
description="A photorealistic model, with architecture among its many use cases",
|
||||
type=ModelType.Main,
|
||||
dependencies=[sdxl_fp16_vae_fix],
|
||||
)
|
||||
|
||||
sdxl_refiner = StarterModel(
|
||||
name="SDXL Refiner",
|
||||
base=BaseModelType.StableDiffusionXLRefiner,
|
||||
@@ -254,6 +274,14 @@ ip_adapter_sdxl = StarterModel(
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
)
|
||||
ip_adapter_flux = StarterModel(
|
||||
name="XLabs FLUX IP-Adapter",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
|
||||
description="FLUX IP-Adapter",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[clip_vit_l_image_encoder],
|
||||
)
|
||||
# endregion
|
||||
# region ControlNet
|
||||
qr_code_cnet_sd1 = StarterModel(
|
||||
@@ -545,6 +573,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
deliberate_inpainting_sd1,
|
||||
juggernaut_sdxl,
|
||||
dreamshaper_sdxl,
|
||||
archvis_sdxl,
|
||||
sdxl_refiner,
|
||||
sdxl_fp16_vae_fix,
|
||||
flux_vae,
|
||||
@@ -555,6 +584,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
ip_adapter_plus_sd1,
|
||||
ip_adapter_plus_face_sd1,
|
||||
ip_adapter_sdxl,
|
||||
ip_adapter_flux,
|
||||
qr_code_cnet_sd1,
|
||||
qr_code_cnet_sdxl,
|
||||
canny_sd1,
|
||||
@@ -642,6 +672,7 @@ flux_bundle: list[StarterModel] = [
|
||||
t5_8b_quantized_encoder,
|
||||
clip_l_encoder,
|
||||
union_cnet_flux,
|
||||
ip_adapter_flux,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
|
||||
|
||||
@@ -54,6 +54,11 @@ GGML_TENSOR_OP_TABLE = {
    torch.ops.aten.mul.Tensor: dequantize_and_run,  # pyright: ignore
}

if torch.backends.mps.is_available():
    GGML_TENSOR_OP_TABLE.update(
        {torch.ops.aten.linear.default: dequantize_and_run}  # pyright: ignore
    )


class GGMLTensor(torch.Tensor):
    """A torch.Tensor sub-class holding a quantized GGML tensor.
invokeai/backend/sd3/__init__.py (new file, 0 lines)
invokeai/backend/sd3/mmditx.py (new file, 891 lines)
@@ -0,0 +1,891 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/mmditx.py
|
||||
|
||||
|
||||
### This file contains impls for MM-DiT, the core model component of SD3
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from invokeai.backend.sd3.other_impls import Mlp, attention
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size, strict=False)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = torch.nn.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor:
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int,
|
||||
grid_size: int,
|
||||
cls_token: bool = False,
|
||||
extra_tokens: int = 0,
|
||||
scaling_factor: Optional[float] = None,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
"""Embeds scalar timesteps into vector representations."""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=t.device
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(torch.nn.Module):
|
||||
"""Embeds a flat vector of dimension input_dim"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
|
||||
class SelfAttention(torch.nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_mode: str = "xformers",
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = torch.nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = torch.nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = torch.nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = torch.nn.Identity()
|
||||
self.ln_k = torch.nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
(q, k, v) = self.pre_attention(x)
|
||||
x = attention(q, k, v, self.num_heads)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine: bool = False,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): If True, adds a learnable per-channel scale `weight`. Default is False.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (torch.nn.Parameter): Learnable scaling parameter (only registered when elementwise_affine is True).
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
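    # Added note (quick numeric check of the normalization above): for a row [3, 4, 0, 0],
    # mean(x**2) = 6.25, so rsqrt gives 1/2.5 and the output row is [1.2, 1.6, 0.0, 0.0] (up to eps);
    # with elementwise_affine=True the result is additionally scaled by `weight`.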
|
||||
|
||||
class SwiGLUFeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
            w1 (torch.nn.Linear): Linear transformation for the first layer.
            w2 (torch.nn.Linear): Linear transformation for the second layer.
            w3 (torch.nn.Linear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
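        # Added note (worked example of the sizing above): with hidden_dim=6144 and multiple_of=256,
        # int(2 * 6144 / 3) = 4096, which is already a multiple of 256, so w1/w3 map dim -> 4096
        # and w2 maps 4096 -> dim.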
|
||||
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(torch.nn.Module):
|
||||
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = torch.nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = torch.nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=torch.nn.GELU(approximate="tanh"),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert x is not None, "pre_attention called with None input"
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
|
||||
6, dim=1
|
||||
)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(
|
||||
self,
|
||||
attn: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
gate_msa: torch.Tensor,
|
||||
shift_mlp: torch.Tensor,
|
||||
scale_mlp: torch.Tensor,
|
||||
gate_mlp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
def pre_attention_x(
|
||||
self, x: torch.Tensor, c: torch.Tensor
|
||||
) -> tuple[
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
]:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return (
|
||||
qkv,
|
||||
qkv2,
|
||||
(
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
),
|
||||
)
|
||||
|
||||
def post_attention_x(
|
||||
self,
|
||||
attn: torch.Tensor,
|
||||
attn2: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
gate_msa: torch.Tensor,
|
||||
shift_mlp: torch.Tensor,
|
||||
scale_mlp: torch.Tensor,
|
||||
gate_mlp: torch.Tensor,
|
||||
gate_msa2: torch.Tensor,
|
||||
attn1_dropout: float = 0.0,
|
||||
):
|
||||
assert not self.pre_only
|
||||
if attn1_dropout > 0.0:
|
||||
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
||||
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
||||
else:
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + attn_
|
||||
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
||||
x = x + attn2_
|
||||
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
x = x + mlp_
|
||||
return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert not self.pre_only
|
||||
if self.x_block_self_attn:
|
||||
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(
    context: torch.Tensor, x: torch.Tensor, context_block: DismantledBlock, x_block: DismantledBlock, c: torch.Tensor
):
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)

    if x_block.x_block_self_attn:
        x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
    else:
        x_qkv, x_intermediates = x_block.pre_attention(x, c)

    o: list[torch.Tensor] = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    q, k, v = tuple(o)

    attn = attention(q, k, v, x_block.attn.num_heads)
    context_attn, x_attn = (
        attn[:, : context_qkv[0].shape[1]],
        attn[:, context_qkv[0].shape[1] :],
    )

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None

    if x_block.x_block_self_attn:
        x_q2, x_k2, x_v2 = x_qkv2
        attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
        # The x stream must also be updated in this branch (attn2 is otherwise unused);
        # post_attention_x returns (x, auxiliary gate/attention tensors) in this implementation,
        # so keep only the updated x.
        x, _ = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
    else:
        x = x_block.post_attention(x_attn, *x_intermediates)

    return context, x
|
||||
|
||||
|
||||
class JointBlock(torch.nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(
|
||||
*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
total_out_channels: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.linear = (
|
||||
torch.nn.Linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if (total_out_channels is None)
|
||||
else torch.nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MMDiTX(torch.nn.Module):
|
||||
"""Diffusion model with a Transformer backbone."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int | None = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
register_length: int = 0,
|
||||
attn_mode: str = "torch",
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches: Optional[int] = None,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn_layers: Optional[List[int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if verbose:
|
||||
print(
|
||||
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers or []
|
||||
|
||||
        # The model width is tied to depth: hidden_size = 64 * depth and num_heads = depth,
        # which fixes the per-head size at 64.
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=self.pos_embed_max_size is None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.context_embedder = torch.nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = torch.nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = torch.nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.joint_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=i == depth - 1,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
def cropped_pos_embed(self, hw: torch.Size) -> torch.Tensor:
|
||||
assert self.pos_embed_max_size is not None
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed: torch.Tensor = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
) # type: ignore Type checking does not correctly infer the type of the self.pos_embed buffer.
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, hw: Optional[torch.Size] = None) -> torch.Tensor:
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
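    # Added note (shape sketch for unpatchify above): with patch_size p=2, out_channels c=16 and a
    # 64x64 latent, T = (64 // 2) ** 2 = 1024 tokens of p * p * c = 64 values each, so
    # x: (N, 1024, 64) -> imgs: (N, 16, 64, 64).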
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||
context if context is not None else torch.Tensor([]).type_as(x),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
for block in self.joint_blocks:
|
||||
context, x = block(context, x, c=c_mod)
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
        y: (N, adm_in_channels) tensor of pooled conditioning vectors (projected by the vector embedder)
|
||||
"""
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x
|
||||
795
invokeai/backend/sd3/other_impls.py
Normal file
@@ -0,0 +1,795 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/other_impls.py
|
||||
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
#################################################################################################
|
||||
### Core/Utility
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
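# Added note (shape sketch for the wrapper above): q, k and v arrive as (B, L, heads * dim_head),
# are reshaped to (B, heads, L, dim_head) for scaled_dot_product_attention, and the result is
# flattened back to (B, L, heads * dim_head); the optional `mask` is forwarded unchanged as
# scaled_dot_product_attention's attn_mask.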
|
||||
|
||||
class Mlp(torch.nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
||||
bias: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
if act_layer is None:
|
||||
act_layer = torch.nn.functional.gelu
|
||||
|
||||
self.fc1 = torch.nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer
|
||||
self.fc2 = torch.nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### CLIP
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
out = attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
ACTIVATIONS = {
|
||||
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(
|
||||
embed_dim,
|
||||
intermediate_size,
|
||||
embed_dim,
|
||||
act_layer=ACTIVATIONS[intermediate_activation],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
CLIPLayer(
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.encoder = CLIPEncoder(
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
pooled_output = x[
|
||||
torch.arange(x.shape[0], device=x.device),
|
||||
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
|
||||
]
|
||||
return x, i, pooled_output
|
||||
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
nesting_level = 0
|
||||
for char in string:
|
||||
if char == "(":
|
||||
if nesting_level == 0:
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
current_item = "("
|
||||
else:
|
||||
current_item = "("
|
||||
else:
|
||||
current_item += char
|
||||
nesting_level += 1
|
||||
elif char == ")":
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
result.append(current_item + ")")
|
||||
current_item = ""
|
||||
else:
|
||||
current_item += char
|
||||
else:
|
||||
current_item += char
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
for x in a:
|
||||
weight = current_weight
|
||||
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
|
||||
x = x[1:-1]
|
||||
xx = x.rfind(":")
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx + 1 :])
|
||||
x = x[:xx]
|
||||
                except ValueError:  # not a parsable float weight; keep the raw text
|
||||
pass
|
||||
out += token_weights(x, weight)
|
||||
else:
|
||||
out += [(x, current_weight)]
|
||||
return out
|
||||
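# Added note (worked example of the prompt-weight parsing above):
#   token_weights("a (cat:1.3) runs", 1.0)
# yields [("a ", 1.0), ("cat", 1.3), (" runs", 1.0)]; a parenthesised chunk without an explicit
# ":<weight>" suffix instead gets the default 1.1x boost applied on entry.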
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
max_length=77,
|
||||
pad_with_end=True,
|
||||
tokenizer=None,
|
||||
has_start_token=True,
|
||||
pad_to_max_length=True,
|
||||
min_length=None,
|
||||
extra_padding_token=None,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
|
||||
empty = self.tokenizer("")["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
self.extra_padding_token = extra_padding_token
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
"""
|
||||
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||
        The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3.
|
||||
"""
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(" ")
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
# parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
|
||||
|
||||
# reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||
t_group = []
|
||||
|
||||
        # pad with the extra padding token first, before adding the end token
|
||||
if self.extra_padding_token is not None:
|
||||
batch.extend([(self.extra_padding_token, 1.0, 0)] * (self.min_length - len(batch) - 1))
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SDXLClipGTokenizer(SDTokenizer):
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text: str):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
|
||||
return out
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled
|
||||
output = [out[0:1]]
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
max_length=77,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
textmodel_json_config=None,
|
||||
dtype=None,
|
||||
model_class=CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
|
||||
layer_norm_hidden_state=True,
|
||||
return_projected_pooled=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
self.max_length = max_length
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (
|
||||
self.layer,
|
||||
self.layer_idx,
|
||||
self.return_projected_pooled,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
outputs = self.transformer(
|
||||
tokens,
|
||||
intermediate_output=self.layer_idx,
|
||||
final_layer_norm_intermediate=self.layer_norm_hidden_state,
|
||||
)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
z = outputs[1]
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
|
||||
class SDXLClipG(SDClipModel):
|
||||
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
)
|
||||
|
||||
|
||||
class T5XXLModel(SDClipModel):
|
||||
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"end": 1, "pad": 0},
|
||||
model_class=T5,
|
||||
)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
pad_with_end=False,
|
||||
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
|
||||
has_start_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=77,
|
||||
)
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = torch.nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(
|
||||
self.relative_attention_num_buckets, self.num_heads, device=device
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||
# now relative_position is in the range [0, inf)
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(
|
||||
relative_position_if_large,
|
||||
torch.full_like(relative_position_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
||||
return self.o(out), past_bias
|
||||
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(
|
||||
T5LayerSelfAttention(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
)
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
x, past_bias = self.layer[0](x, past_bias)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
vocab_size,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
self.block = torch.nn.ModuleList(
|
||||
[
|
||||
T5Block(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(
|
||||
self.num_layers,
|
||||
config_dict["d_model"],
|
||||
config_dict["d_model"],
|
||||
config_dict["d_ff"],
|
||||
config_dict["num_heads"],
|
||||
config_dict["vocab_size"],
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
609
invokeai/backend/sd3/sd3_impls.py
Normal file
@@ -0,0 +1,609 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py
|
||||
|
||||
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.sd3.mmditx import MMDiTX
|
||||
|
||||
#################################################################################################
|
||||
### MMDiT Model Wrapping
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||
|
||||
def __init__(self, shift: float = 1.0):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
timesteps = 1000
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
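    # Added note (worked example of the shifted schedule above): with shift=3.0 and discrete
    # timestep 500, timestep/1000 = 0.5 and sigma = 3.0 * 0.5 / (1 + 2.0 * 0.5) = 0.75,
    # so shift > 1 biases the schedule toward higher-noise sigmas.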
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shift=1.0,
|
||||
device=None,
|
||||
dtype=torch.float32,
|
||||
file=None,
|
||||
prefix="",
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
|
||||
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
|
||||
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
|
||||
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in file.keys() else None
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, file.keys()))
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=16,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_embedder_config=context_embedder_config,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(
|
||||
self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor | None = None, y: torch.Tensor | None = None
|
||||
):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||
# Run cond and uncond in a batch together
|
||||
batched = self.model.apply_model(
|
||||
torch.cat([x, x]),
|
||||
torch.cat([timestep, timestep]),
|
||||
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
|
||||
y=torch.cat([cond["y"], uncond["y"]]),
|
||||
)
|
||||
# Then split and apply CFG Scaling
|
||||
pos_out, neg_out = batched.chunk(2)
|
||||
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||
return scaled
|
||||
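    # Added note: the mix above is standard classifier-free guidance,
    #   scaled = uncond + (cond - uncond) * cond_scale
    # so cond_scale = 1.0 reproduces the conditional prediction and larger values extrapolate
    # further away from the unconditional one.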
|
||||
|
||||
class SD3LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor(
|
||||
[
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[0.0028, 0.0312, 0.0650],
|
||||
[0.1848, 0.0762, 0.0360],
|
||||
[0.0944, 0.0360, 0.0889],
|
||||
[0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259],
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Samplers
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_euler(model, x, sigmas, extra_args=None):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
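# t = -log(sigma) is the transformed time variable used by DPM-Solver++; sigma_fn inverts it,
# and h below is the step size measured in that variable.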
|
||||
old_denoised = None
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_denoised is None or sigmas[i + 1] == 0:
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### VAE
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = None
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = x
|
||||
hidden = self.norm1(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv1(hidden)
|
||||
hidden = self.norm2(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv2(hidden)
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class AttnBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.norm(x)
|
||||
q = self.q(hidden)
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(
|
||||
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
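# Pad only on the right/bottom so the stride-2 convolution halves the spatial size exactly.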
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAEEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
in_channels=3,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = torch.nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = torch.nn.ModuleList()
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||
self.down.append(down)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VAEDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
resolution=256,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# upsampling
|
||||
self.up = torch.nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
up.block = block
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
hidden = self.conv_in(z)
|
||||
# middle
|
||||
hidden = self.mid.block_1(hidden)
|
||||
hidden = self.mid.attn_1(hidden)
|
||||
hidden = self.mid.block_2(hidden)
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hidden = self.up[i_level].block[i_block](hidden)
|
||||
if i_level != 0:
|
||||
hidden = self.up[i_level].upsample(hidden)
|
||||
# end
|
||||
hidden = self.norm_out(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv_out(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
class SDVAE(torch.nn.Module):
|
||||
def __init__(self, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def decode(self, latent):
|
||||
return self.decoder(latent)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def encode(self, image):
|
||||
hidden = self.encoder(image)
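# The encoder outputs per-location mean and log-variance; sample the latent from that
# Gaussian via the reparameterization trick.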
|
||||
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
invokeai/backend/sd3/sd3_infer.py (new file, 426 lines)
@@ -0,0 +1,426 @@
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_infer.py
|
||||
|
||||
# NOTE: Must have folder `models` with the following files:
|
||||
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
|
||||
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
|
||||
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
|
||||
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
|
||||
# Also can have
|
||||
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import sd3_impls
|
||||
import torch
|
||||
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
from sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
|
||||
from tqdm import tqdm
|
||||
|
||||
#################################################################################################
|
||||
### Wrappers for model parts
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def load_into(f, model, prefix, device, dtype=None):
|
||||
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
|
||||
for key in f.keys():
|
||||
if key.startswith(prefix) and not key.startswith("loss."):
|
||||
path = key[len(prefix) :].split(".")
|
||||
obj = model
|
||||
for p in path:
|
||||
if isinstance(obj, list):
|
||||
obj = obj[int(p)]
|
||||
else:
|
||||
obj = getattr(obj, p, None)
|
||||
if obj is None:
|
||||
print(f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model")
|
||||
break
|
||||
if obj is None:
|
||||
continue
|
||||
try:
|
||||
tensor = f.get_tensor(key).to(device=device)
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
obj.requires_grad_(False)
|
||||
obj.set_(tensor)
|
||||
except Exception as e:
|
||||
print(f"Failed to load key '{key}' in safetensors file: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
}
|
||||
|
||||
|
||||
class ClipG:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
|
||||
class ClipL:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDClipModel(
|
||||
layer="hidden",
|
||||
layer_idx=-2,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
layer_norm_hidden_state=False,
|
||||
return_projected_pooled=False,
|
||||
textmodel_json_config=CLIPL_CONFIG,
|
||||
)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class T5XXL:
|
||||
def __init__(self):
|
||||
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
class SD3:
|
||||
def __init__(self, model, shift, verbose=False):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = BaseModel(
|
||||
shift=shift,
|
||||
file=f,
|
||||
prefix="model.diffusion_model.",
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
verbose=verbose,
|
||||
).eval()
|
||||
load_into(f, self.model, "model.", "cpu", torch.float16)
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, model):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
|
||||
prefix = ""
|
||||
if any(k.startswith("first_stage_model.") for k in f.keys()):
|
||||
prefix = "first_stage_model."
|
||||
load_into(f, self.model, prefix, "cpu", torch.float16)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Main inference logic
|
||||
#################################################################################################
|
||||
|
||||
|
||||
# Note: Sigma shift value, publicly released models use 3.0
|
||||
SHIFT = 3.0
|
||||
# Naturally, adjust to the width/height of the model you have
|
||||
WIDTH = 1024
|
||||
HEIGHT = 1024
|
||||
# Pick your prompt
|
||||
PROMPT = "a photo of a cat"
|
||||
# Most models prefer the range of 4-5, but still work well around 7
|
||||
CFG_SCALE = 4.5
|
||||
# Different models want different step counts but most will be good at 50, albeit that's slow to run
|
||||
# sd3_medium is quite decent at 28 steps
|
||||
STEPS = 40
|
||||
# Seed
|
||||
SEED = 23
|
||||
# SEEDTYPE = "fixed"
|
||||
SEEDTYPE = "rand"
|
||||
# SEEDTYPE = "roll"
|
||||
# Actual model file path
|
||||
# MODEL = "models/sd3_medium.safetensors"
|
||||
# MODEL = "models/sd3.5_large_turbo.safetensors"
|
||||
MODEL = "models/sd3.5_large.safetensors"
|
||||
# VAE model file path, or set None to use the same model file
|
||||
VAEFile = None # "models/sd3_vae.safetensors"
|
||||
# Optional init image file path
|
||||
INIT_IMAGE = None
|
||||
# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
|
||||
DENOISE = 0.6
|
||||
# Output file path
|
||||
OUTDIR = "outputs"
|
||||
# SAMPLER
|
||||
# SAMPLER = "euler"
|
||||
SAMPLER = "dpmpp_2m"
|
||||
|
||||
|
||||
class SD3Inferencer:
|
||||
def print(self, txt):
|
||||
if self.verbose:
|
||||
print(txt)
|
||||
|
||||
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
|
||||
self.verbose = verbose
|
||||
print("Loading tokenizers...")
|
||||
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
|
||||
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
|
||||
# (T5 tokenizer is different though)
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
print("Loading OpenAI CLIP L...")
|
||||
self.clip_l = ClipL()
|
||||
print("Loading OpenCLIP bigG...")
|
||||
self.clip_g = ClipG()
|
||||
print("Loading Google T5-v1-XXL...")
|
||||
self.t5xxl = T5XXL()
|
||||
print(f"Loading SD3 model {os.path.basename(model)}...")
|
||||
self.sd3 = SD3(model, shift, verbose)
|
||||
print("Loading VAE model...")
|
||||
self.vae = VAE(vae or model)
|
||||
print("Models loaded.")
|
||||
|
||||
def get_empty_latent(self, width, height):
|
||||
self.print("Prep an empty latent...")
|
||||
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
|
||||
|
||||
def get_sigmas(self, sampling, steps):
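# Evaluate sigma(t) at `steps` linearly spaced timesteps from sigma_max down to sigma_min,
# then append a trailing 0.0 so the final sampler step denoises fully.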
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(sampling.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def get_noise(self, seed, latent):
|
||||
generator = torch.manual_seed(seed)
|
||||
self.print(f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}")
|
||||
return torch.randn(
|
||||
latent.size(),
|
||||
dtype=torch.float32,
|
||||
layout=latent.layout,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
).to(latent.dtype)
|
||||
|
||||
def get_cond(self, prompt):
|
||||
self.print("Encode prompt...")
|
||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
|
||||
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
|
||||
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
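# CLIP-L (768-dim) and CLIP-G (1280-dim) token states are concatenated to 2048 dims, zero-padded
# to T5's 4096-dim width, then stacked with the T5 sequence along the token axis; the pooled CLIP
# outputs are concatenated to form the pooled `y` conditioning vector.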
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
def max_denoise(self, sigmas):
|
||||
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
def fix_cond(self, cond):
|
||||
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
|
||||
return {"c_crossattn": cond, "y": pooled}
|
||||
|
||||
def do_sampling(
|
||||
self,
|
||||
latent,
|
||||
seed,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler="dpmpp_2m",
|
||||
denoise=1.0,
|
||||
) -> torch.Tensor:
|
||||
self.print("Sampling...")
|
||||
latent = latent.half().cuda()
|
||||
self.sd3.model = self.sd3.model.cuda()
|
||||
noise = self.get_noise(seed, latent).cuda()
|
||||
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
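# For image-to-image, drop the earliest (highest-noise) sigmas so only the final `denoise`
# fraction of the schedule is actually run.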
|
||||
sigmas = sigmas[int(steps * (1 - denoise)) :]
|
||||
conditioning = self.fix_cond(conditioning)
|
||||
neg_cond = self.fix_cond(neg_cond)
|
||||
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
|
||||
noise_scaled = self.sd3.model.model_sampling.noise_scaling(sigmas[0], noise, latent, self.max_denoise(sigmas))
|
||||
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
|
||||
latent = sample_fn(CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args)
|
||||
latent = SD3LatentFormat().process_out(latent)
|
||||
self.sd3.model = self.sd3.model.cpu()
|
||||
self.print("Sampling done")
|
||||
return latent
|
||||
|
||||
def vae_encode(self, image) -> torch.Tensor:
|
||||
self.print("Encoding image to latent...")
|
||||
image = image.convert("RGB")
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = np.moveaxis(image_np, 2, 0)
|
||||
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
|
||||
image_torch = torch.from_numpy(batch_images)
|
||||
image_torch = 2.0 * image_torch - 1.0
|
||||
image_torch = image_torch.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
latent = self.vae.model.encode(image_torch).cpu()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
self.print("Encoded")
|
||||
return latent
|
||||
|
||||
def vae_decode(self, latent) -> Image.Image:
|
||||
self.print("Decoding latent to image...")
|
||||
latent = latent.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
image = self.vae.model.decode(latent)
|
||||
image = image.float()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
out_image = Image.fromarray(decoded_np)
|
||||
self.print("Decoded")
|
||||
return out_image
|
||||
|
||||
def gen_image(
|
||||
self,
|
||||
prompts=[PROMPT],
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
steps=STEPS,
|
||||
cfg_scale=CFG_SCALE,
|
||||
sampler=SAMPLER,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
out_dir=OUTDIR,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
):
|
||||
latent = self.get_empty_latent(width, height)
|
||||
if init_image:
|
||||
image_data = Image.open(init_image)
|
||||
image_data = image_data.resize((width, height), Image.LANCZOS)
|
||||
latent = self.vae_encode(image_data)
|
||||
latent = SD3LatentFormat().process_in(latent)
|
||||
neg_cond = self.get_cond("")
|
||||
seed_num = None
|
||||
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
|
||||
for i, prompt in pbar:
|
||||
if seed_type == "roll":
|
||||
seed_num = seed if seed_num is None else seed_num + 1
|
||||
elif seed_type == "rand":
|
||||
seed_num = torch.randint(0, 100000, (1,)).item()
|
||||
else: # fixed
|
||||
seed_num = seed
|
||||
conditioning = self.get_cond(prompt)
|
||||
sampled_latent = self.do_sampling(
|
||||
latent,
|
||||
seed_num,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler,
|
||||
denoise if init_image else 1.0,
|
||||
)
|
||||
image = self.vae_decode(sampled_latent)
|
||||
save_path = os.path.join(out_dir, f"{i:06d}.png")
|
||||
self.print(f"Will save to {save_path}")
|
||||
image.save(save_path)
|
||||
self.print("Done")
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"sd3_medium": {
|
||||
"shift": 1.0,
|
||||
"cfg": 5.0,
|
||||
"steps": 50,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large": {
|
||||
"shift": 3.0,
|
||||
"cfg": 4.5,
|
||||
"steps": 40,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(
|
||||
prompt=PROMPT,
|
||||
model=MODEL,
|
||||
out_dir=OUTDIR,
|
||||
postfix=None,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
sampler=None,
|
||||
steps=None,
|
||||
cfg=None,
|
||||
shift=None,
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
vae=VAEFile,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
verbose=False,
|
||||
):
|
||||
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
|
||||
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
|
||||
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
|
||||
sampler = sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
|
||||
|
||||
inferencer = SD3Inferencer()
|
||||
inferencer.load(model, vae, shift, verbose)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if os.path.splitext(prompt)[-1] == ".txt":
|
||||
with open(prompt, "r") as f:
|
||||
prompts = [l.strip() for l in f.readlines()]
|
||||
else:
|
||||
prompts = [prompt]
|
||||
|
||||
out_dir = os.path.join(
|
||||
out_dir,
|
||||
os.path.splitext(os.path.basename(model))[0],
|
||||
os.path.splitext(os.path.basename(prompt))[0][:50]
|
||||
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
|
||||
)
|
||||
print(f"Saving images to {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=False)
|
||||
|
||||
inferencer.gen_image(
|
||||
prompts,
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
cfg,
|
||||
sampler,
|
||||
seed,
|
||||
seed_type,
|
||||
out_dir,
|
||||
init_image,
|
||||
denoise,
|
||||
)
|
||||
|
||||
|
||||
fire.Fire(main)
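
# Example invocation (hypothetical paths; python-fire maps CLI flags onto main()'s keyword arguments):
#
#   python sd3_infer.py --prompt "a photo of a cat" --model models/sd3.5_large.safetensors --steps 40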
|
||||
invokeai/backend/sd3/sd3_mmditx.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from dataclasses import dataclass
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.sd3.mmditx import MMDiTX
|
||||
from invokeai.backend.sd3.sd3_impls import ModelSamplingDiscreteFlow
|
||||
|
||||
|
||||
class ContextEmbedderConfig(TypedDict):
|
||||
target: Literal["torch.nn.Linear"]
|
||||
params: dict[str, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sd3MMDiTXParams:
|
||||
patch_size: int
|
||||
depth: int
|
||||
num_patches: int
|
||||
pos_embed_max_size: int
|
||||
adm_in_channels: int
|
||||
context_shape: tuple[int, int]
|
||||
qk_norm: Literal["rms", None]
|
||||
x_block_self_attn_layers: list[int]
|
||||
context_embedder_config: ContextEmbedderConfig
|
||||
|
||||
|
||||
class Sd3MMDiTX(torch.nn.Module):
|
||||
"""This class is based closely on
|
||||
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L53
|
||||
but has more standard model loading semantics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Sd3MMDiTXParams,
|
||||
shift: float = 1.0,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=params.pos_embed_max_size,
|
||||
patch_size=params.patch_size,
|
||||
in_channels=16,
|
||||
depth=params.depth,
|
||||
num_patches=params.num_patches,
|
||||
adm_in_channels=params.adm_in_channels,
|
||||
context_embedder_config=params.context_embedder_config,
|
||||
qk_norm=params.qk_norm,
|
||||
x_block_self_attn_layers=params.x_block_self_attn_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(self, x: torch.Tensor, sigma: torch.Tensor, c_crossattn: torch.Tensor, y: torch.Tensor):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor, y: torch.Tensor):
|
||||
return self.apply_model(x=x, sigma=sigma, c_crossattn=c_crossattn, y=y)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
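
# Construction sketch (not part of the original file; `sd` is assumed to be a loaded SD3 checkpoint
# state dict and infer_sd3_mmditx_params comes from invokeai.backend.sd3.sd3_state_dict_utils):
#
#   params = infer_sd3_mmditx_params(sd)
#   model = Sd3MMDiTX(params, shift=3.0, dtype=torch.float16)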
|
||||
invokeai/backend/sd3/sd3_state_dict_utils.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import math
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from invokeai.backend.sd3.sd3_mmditx import ContextEmbedderConfig, Sd3MMDiTXParams
|
||||
|
||||
|
||||
def is_sd3_checkpoint(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an SD3 checkpoint like this one?:
|
||||
https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
|
||||
|
||||
Note that this checkpoint format contains both the VAE and the MMDiTX model.
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
"""
|
||||
# If all of the expected keys are present, then this is very likely an SD3 checkpoint.
|
||||
expected_keys = {
|
||||
# VAE decoder and encoder keys.
|
||||
"first_stage_model.decoder.conv_in.bias",
|
||||
"first_stage_model.decoder.conv_in.weight",
|
||||
"first_stage_model.encoder.conv_in.bias",
|
||||
"first_stage_model.encoder.conv_in.weight",
|
||||
# MMDiTX keys.
|
||||
"model.diffusion_model.final_layer.linear.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_k.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_q.weight",
|
||||
}
|
||||
|
||||
return expected_keys.issubset(sd.keys())
|
||||
|
||||
|
||||
def infer_sd3_mmditx_params(sd: Dict[str, Any], prefix: str = "model.diffusion_model.") -> Sd3MMDiTXParams:
|
||||
"""Infer the MMDiTX model parameters from the state dict.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L68-L88
|
||||
"""
|
||||
patch_size = sd[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = sd[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = sd[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = sd[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = sd[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in sd else None
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, sd.keys()))
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config: ContextEmbedderConfig = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
return Sd3MMDiTXParams(
|
||||
patch_size=patch_size,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_shape=context_shape,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
context_embedder_config=context_embedder_config,
|
||||
)
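
# Minimal usage sketch (not part of the original file; assumes a locally downloaded checkpoint path):
#
#   from safetensors.torch import load_file
#   sd = load_file("sd3.5_large.safetensors")  # hypothetical local path
#   if is_sd3_checkpoint(sd):
#       params = infer_sd3_mmditx_params(sd)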
|
||||
@@ -114,8 +114,7 @@
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"ts-toolbelt": "^9.6.0"
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@invoke-ai/eslint-config-react": "^0.0.14",
|
||||
@@ -149,8 +148,8 @@
|
||||
"prettier": "^3.3.3",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"storybook": "^8.3.4",
|
||||
"ts-toolbelt": "^9.6.0",
|
||||
"tsafe": "^1.7.5",
|
||||
"type-fest": "^4.26.1",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^5.4.8",
|
||||
"vite-plugin-css-injected-by-js": "^3.5.2",
|
||||
|
||||
invokeai/frontend/web/pnpm-lock.yaml (generated, 10 lines changed)
@@ -277,12 +277,12 @@ devDependencies:
storybook:
|
||||
specifier: ^8.3.4
|
||||
version: 8.3.4
|
||||
ts-toolbelt:
|
||||
specifier: ^9.6.0
|
||||
version: 9.6.0
|
||||
tsafe:
|
||||
specifier: ^1.7.5
|
||||
version: 1.7.5
|
||||
type-fest:
|
||||
specifier: ^4.26.1
|
||||
version: 4.26.1
|
||||
typescript:
|
||||
specifier: ^5.6.2
|
||||
version: 5.6.2
|
||||
@@ -8830,10 +8830,6 @@ packages:
|
||||
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
|
||||
dev: false
|
||||
|
||||
/ts-toolbelt@9.6.0:
|
||||
resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==}
|
||||
dev: true
|
||||
|
||||
/tsafe@1.7.5:
|
||||
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
|
||||
dev: true
|
||||
|
||||
@@ -93,7 +93,9 @@
|
||||
"placeholderSelectAModel": "Modell auswählen",
|
||||
"reset": "Zurücksetzen",
|
||||
"none": "Keine",
|
||||
"new": "Neu"
|
||||
"new": "Neu",
|
||||
"ok": "OK",
|
||||
"close": "Schließen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -156,7 +158,11 @@
|
||||
"displayBoardSearch": "Board durchsuchen",
|
||||
"displaySearch": "Bild suchen",
|
||||
"go": "Los",
|
||||
"jump": "Springen"
|
||||
"jump": "Springen",
|
||||
"assetsTab": "Dateien, die Sie zur Verwendung in Ihren Projekten hochgeladen haben.",
|
||||
"imagesTab": "Bilder, die Sie in Invoke erstellt und gespeichert haben.",
|
||||
"boardsSettings": "Ordnereinstellungen",
|
||||
"imagesSettings": "Galeriebildereinstellungen"
|
||||
},
|
||||
"hotkeys": {
|
||||
"noHotkeysFound": "Kein Hotkey gefunden",
|
||||
@@ -267,6 +273,18 @@
|
||||
"applyFilter": {
|
||||
"title": "Filter anwenden",
|
||||
"desc": "Wende den ausstehenden Filter auf die ausgewählte Ebene an."
|
||||
},
|
||||
"cancelFilter": {
|
||||
"title": "Filter abbrechen",
|
||||
"desc": "Den ausstehenden Filter abbrechen."
|
||||
},
|
||||
"applyTransform": {
|
||||
"desc": "Die ausstehende Transformation auf die ausgewählte Ebene anwenden.",
|
||||
"title": "Transformation anwenden"
|
||||
},
|
||||
"cancelTransform": {
|
||||
"title": "Transformation abbrechen",
|
||||
"desc": "Die ausstehende Transformation abbrechen."
|
||||
}
|
||||
},
|
||||
"viewer": {
|
||||
@@ -563,7 +581,18 @@
|
||||
"scanResults": "Ergebnisse des Scans",
|
||||
"urlOrLocalPathHelper": "URLs sollten auf eine einzelne Datei deuten. Lokale Pfade können zusätzlich auch auf einen Ordner für ein einzelnes Diffusers-Modell hinweisen.",
|
||||
"inplaceInstallDesc": "Installieren Sie Modelle, ohne die Dateien zu kopieren. Wenn Sie das Modell verwenden, wird es direkt von seinem Speicherort geladen. Wenn deaktiviert, werden die Dateien während der Installation in das von Invoke verwaltete Modellverzeichnis kopiert.",
|
||||
"scanFolderHelper": "Der Ordner wird rekursiv nach Modellen durchsucht. Dies kann bei sehr großen Ordnern etwas dauern."
|
||||
"scanFolderHelper": "Der Ordner wird rekursiv nach Modellen durchsucht. Dies kann bei sehr großen Ordnern etwas dauern.",
|
||||
"includesNModels": "Enthält {{n}} Modelle und deren Abhängigkeiten",
|
||||
"starterBundles": "Starterpakete",
|
||||
"installingXModels_one": "{{count}} Modell wird installiert",
|
||||
"installingXModels_other": "{{count}} Modelle werden installiert",
|
||||
"skippingXDuplicates_one": ", überspringe {{count}} Duplikat",
|
||||
"skippingXDuplicates_other": ", überspringe {{count}} Duplikate",
|
||||
"installingModel": "Modell wird installiert",
|
||||
"loraTriggerPhrases": "LoRA-Auslösephrasen",
|
||||
"installingBundle": "Bündel wird installiert",
|
||||
"triggerPhrases": "Auslösephrasen",
|
||||
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -667,7 +696,8 @@
|
||||
"about": "Über",
|
||||
"submitSupportTicket": "Support-Ticket senden",
|
||||
"toggleRightPanel": "Rechtes Bedienfeld umschalten (G)",
|
||||
"toggleLeftPanel": "Linkes Bedienfeld umschalten (T)"
|
||||
"toggleLeftPanel": "Linkes Bedienfeld umschalten (T)",
|
||||
"uploadImages": "Bild(er) hochladen"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Board automatisch erstellen",
|
||||
@@ -702,7 +732,7 @@
|
||||
"shared": "Geteilte Ordner",
|
||||
"archiveBoard": "Ordner archivieren",
|
||||
"archived": "Archiviert",
|
||||
"noBoards": "Kein {boardType}} Ordner",
|
||||
"noBoards": "Kein {{boardType}} Ordner",
|
||||
"hideBoards": "Ordner verstecken",
|
||||
"viewBoards": "Ordner ansehen",
|
||||
"deletedPrivateBoardsCannotbeRestored": "Gelöschte Boards können nicht wiederhergestellt werden. Wenn Sie „Nur Board löschen“ wählen, werden die Bilder in einen privaten, nicht kategorisierten Status für den Ersteller des Bildes versetzt.",
|
||||
@@ -811,7 +841,8 @@
|
||||
"parameterSet": "Parameter {{parameter}} setzen",
|
||||
"recallParameter": "{{label}} Abrufen",
|
||||
"parsingFailed": "Parsing Fehlgeschlagen",
|
||||
"canvasV2Metadata": "Leinwand"
|
||||
"canvasV2Metadata": "Leinwand",
|
||||
"guidance": "Führung"
|
||||
},
|
||||
"popovers": {
|
||||
"noiseUseCPU": {
|
||||
@@ -1137,7 +1168,9 @@
|
||||
"workflowNotes": "Notizen",
|
||||
"workflowTags": "Tags",
|
||||
"workflowVersion": "Version",
|
||||
"saveToGallery": "In Galerie speichern"
|
||||
"saveToGallery": "In Galerie speichern",
|
||||
"noWorkflows": "Keine Arbeitsabläufe",
|
||||
"noMatchingWorkflows": "Keine passenden Arbeitsabläufe"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
|
||||
@@ -1842,6 +1842,17 @@
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"segment": {
|
||||
"autoMask": "Auto Mask",
|
||||
"pointType": "Point Type",
|
||||
"foreground": "Foreground",
|
||||
"background": "Background",
|
||||
"neutral": "Neutral",
|
||||
"reset": "Reset",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel",
|
||||
"process": "Process"
|
||||
},
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
"label": "Snap to Grid",
|
||||
@@ -1852,10 +1863,10 @@
|
||||
"label": "Preserve Masked Region",
|
||||
"alert": "Preserving Masked Region"
|
||||
},
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
"isolatedStagingPreview": "Isolated Staging Preview",
|
||||
"isolatedFilteringPreview": "Isolated Filtering Preview",
|
||||
"isolatedTransformingPreview": "Isolated Transforming Preview",
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
"isolatedLayerPreview": "Isolated Layer Preview",
|
||||
"isolatedLayerPreviewDesc": "Whether to show only this layer when performing operations like filtering or transforming.",
|
||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||
"pressureSensitivity": "Pressure Sensitivity"
|
||||
},
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"settingsLabel": "Paramètres",
|
||||
"img2img": "Image vers Image",
|
||||
"nodes": "Processus",
|
||||
"upload": "Télécharger",
|
||||
"upload": "Importer",
|
||||
"load": "Charger",
|
||||
"back": "Retour",
|
||||
"statusDisconnected": "Hors ligne",
|
||||
@@ -51,7 +51,7 @@
|
||||
"green": "Vert",
|
||||
"delete": "Supprimer",
|
||||
"simple": "Simple",
|
||||
"template": "Modèle",
|
||||
"template": "Template",
|
||||
"advanced": "Avancé",
|
||||
"copy": "Copier",
|
||||
"saveAs": "Enregistrer sous",
|
||||
@@ -117,8 +117,8 @@
|
||||
"bulkDownloadRequestFailed": "Problème lors de la préparation du téléchargement",
|
||||
"copy": "Copier",
|
||||
"autoAssignBoardOnClick": "Assigner automatiquement une Planche lors du clic",
|
||||
"dropToUpload": "$t(gallery.drop) pour Charger",
|
||||
"dropOrUpload": "$t(gallery.drop) ou Séléctioner",
|
||||
"dropToUpload": "$t(gallery.drop) pour Importer",
|
||||
"dropOrUpload": "$t(gallery.drop) ou Importer",
|
||||
"oldestFirst": "Plus Ancien en premier",
|
||||
"deleteImagePermanent": "Les Images supprimées ne peuvent pas être restorées.",
|
||||
"displaySearch": "Recherche d'Image",
|
||||
@@ -161,7 +161,7 @@
|
||||
"unstarImage": "Retirer le marquage de l'Image",
|
||||
"viewerImage": "Visualisation de l'Image",
|
||||
"imagesSettings": "Paramètres des images de la galerie",
|
||||
"assetsTab": "Fichiers que vous avez chargé pour vos projets.",
|
||||
"assetsTab": "Fichiers que vous avez importé pour vos projets.",
|
||||
"imagesTab": "Images que vous avez créées et enregistrées dans Invoke.",
|
||||
"boardsSettings": "Paramètres des planches"
|
||||
},
|
||||
@@ -243,7 +243,7 @@
|
||||
"noModelsInstalled": "Aucun modèle installé",
|
||||
"urlOrLocalPath": "URL ou chemin local",
|
||||
"prune": "Vider",
|
||||
"uploadImage": "Charger une image",
|
||||
"uploadImage": "Importer une image",
|
||||
"addModels": "Ajouter des modèles",
|
||||
"install": "Installer",
|
||||
"localOnly": "local uniquement",
|
||||
@@ -273,7 +273,18 @@
|
||||
"spandrelImageToImage": "Image vers Image (Spandrel)",
|
||||
"starterModelsInModelManager": "Les modèles de démarrage peuvent être trouvés dans le gestionnaire de modèles",
|
||||
"t5Encoder": "Encodeur T5",
|
||||
"learnMoreAboutSupportedModels": "En savoir plus sur les modèles que nous prenons en charge"
|
||||
"learnMoreAboutSupportedModels": "En savoir plus sur les modèles que nous prenons en charge",
|
||||
"includesNModels": "Contient {{n}} modèles et leurs dépendances",
|
||||
"starterBundles": "Packs de démarrages",
|
||||
"starterBundleHelpText": "Installe facilement tous les modèles nécessaire pour démarrer avec un modèle de base, incluant un modèle principal, ControlNets, IP Adapters et plus encore. Choisir un pack igniorera tous les modèles déjà installés.",
|
||||
"installingXModels_one": "En cours d'installation de {{count}} modèle",
|
||||
"installingXModels_many": "En cours d'installation de {{count}} modèles",
|
||||
"installingXModels_other": "En cours d'installation de {{count}} modèles",
|
||||
"skippingXDuplicates_one": ", en ignorant {{count}} doublon",
|
||||
"skippingXDuplicates_many": ", en ignorant {{count}} doublons",
|
||||
"skippingXDuplicates_other": ", en ignorant {{count}} doublons",
|
||||
"installingModel": "Modèle en cours d'installation",
|
||||
"installingBundle": "Pack en cours d'installation"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Images",
|
||||
@@ -414,16 +425,16 @@
|
||||
"confirmOnNewSession": "Confirmer lors d'une nouvelle session"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Téléchargement échoué",
|
||||
"uploadFailed": "Importation échouée",
|
||||
"imageCopied": "Image copiée",
|
||||
"parametersNotSet": "Paramètres non rappelés",
|
||||
"serverError": "Erreur du serveur",
|
||||
"uploadFailedInvalidUploadDesc": "Doit être une unique image PNG ou JPEG",
|
||||
"uploadFailedInvalidUploadDesc": "Doit être des images au format PNG ou JPEG.",
|
||||
"problemCopyingImage": "Impossible de copier l'image",
|
||||
"parameterSet": "Paramètre Rappelé",
|
||||
"parameterNotSet": "Paramètre non Rappelé",
|
||||
"canceled": "Traitement annulé",
|
||||
"addedToBoard": "Ajouté à la planche",
|
||||
"addedToBoard": "Ajouté aux ressources de la planche {{name}}",
|
||||
"workflowLoaded": "Processus chargé",
|
||||
"connected": "Connecté au serveur",
|
||||
"setNodeField": "Définir comme champ de nœud",
|
||||
@@ -436,7 +447,7 @@
|
||||
"baseModelChangedCleared_one": "Effacé ou désactivé {{count}} sous-modèle incompatible",
|
||||
"baseModelChangedCleared_many": "Effacé ou désactivé {{count}} sous-modèles incompatibles",
|
||||
"baseModelChangedCleared_other": "Effacé ou désactivé {{count}} sous-modèles incompatibles",
|
||||
"invalidUpload": "Téléchargement invalide",
|
||||
"invalidUpload": "Importation invalide",
|
||||
"problemDownloadingImage": "Impossible de télécharger l'image",
|
||||
"problemRetrievingWorkflow": "Problème de récupération du processus",
|
||||
"problemDeletingWorkflow": "Problème de suppression du processus",
|
||||
@@ -468,10 +479,15 @@
|
||||
"baseModelChanged": "Modèle de base changé",
|
||||
"problemSavingLayer": "Impossible d'enregistrer la couche",
|
||||
"imageNotLoadedDesc": "Image introuvable",
|
||||
"linkCopied": "Lien copié"
|
||||
"linkCopied": "Lien copié",
|
||||
"imagesWillBeAddedTo": "Les images Importées seront ajoutées au ressources de la Planche {{boardName}}.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Doit être au maximum une image PNG ou JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_many": "Doit être au maximum {{count}} images PNG ou JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_other": "Doit être au maximum {{count}} images PNG ou JPEG.",
|
||||
"addedToUncategorized": "Ajouté aux ressources de la planche $t(boards.uncategorized)"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Charger une image",
|
||||
"uploadImage": "Importer une image",
|
||||
"reset": "Réinitialiser",
|
||||
"nextImage": "Image suivante",
|
||||
"previousImage": "Image précédente",
|
||||
@@ -483,7 +499,8 @@
|
||||
"submitSupportTicket": "Envoyer un ticket de support",
|
||||
"resetUI": "$t(accessibility.reset) l'Interface Utilisateur",
|
||||
"toggleRightPanel": "Afficher/Masquer le panneau de droite (G)",
|
||||
"toggleLeftPanel": "Afficher/Masquer le panneau de gauche (T)"
|
||||
"toggleLeftPanel": "Afficher/Masquer le panneau de gauche (T)",
|
||||
"uploadImages": "Importer Image(s)"
|
||||
},
|
||||
"boards": {
|
||||
"move": "Déplacer",
|
||||
@@ -1400,13 +1417,14 @@
|
||||
"parameterSet": "Paramètre {{parameter}} défini",
|
||||
"parsingFailed": "L'analyse a échoué",
|
||||
"recallParameter": "Rappeler {{label}}",
|
||||
"canvasV2Metadata": "Toile"
|
||||
"canvasV2Metadata": "Toile",
|
||||
"guidance": "Guide"
|
||||
},
|
||||
"sdxl": {
|
||||
"freePromptStyle": "Écriture de Prompt manuelle",
|
||||
"concatPromptStyle": "Lier Prompt & Style",
|
||||
"negStylePrompt": "Prompt Négatif",
|
||||
"posStylePrompt": "Prompt Positif",
|
||||
"negStylePrompt": "Style Prompt Négatif",
|
||||
"posStylePrompt": "Style Prompt Positif",
|
||||
"refinerStart": "Démarrer le Refiner",
|
||||
"denoisingStrength": "Force de débruitage",
|
||||
"steps": "Étapes",
|
||||
@@ -1582,7 +1600,7 @@
|
||||
"noDescription": "Aucune description",
|
||||
"deleteWorkflow": "Supprimer le processus",
|
||||
"openWorkflow": "Ouvrir le processus",
|
||||
"uploadWorkflow": "Charger à partir du fichier",
|
||||
"uploadWorkflow": "Charger à partir d'un fichier",
|
||||
"workflowName": "Nom du processus",
|
||||
"unnamedWorkflow": "Processus sans nom",
|
||||
"saveWorkflowAs": "Enregistrer le processus sous",
|
||||
@@ -1613,7 +1631,7 @@
|
||||
"projectWorkflows": "Processus du projet",
|
||||
"copyShareLink": "Copier le lien de partage",
|
||||
"chooseWorkflowFromLibrary": "Choisir le Processus dans la Bibliothèque",
|
||||
"uploadAndSaveWorkflow": "Charger dans la bibliothèque",
|
||||
"uploadAndSaveWorkflow": "Importer dans la bibliothèque",
|
||||
"edit": "Modifer",
|
||||
"deleteWorkflow2": "Êtes-vous sûr de vouloir supprimer ce processus ? Ceci ne peut pas être annulé.",
|
||||
"download": "Télécharger",
|
||||
@@ -1980,50 +1998,50 @@
|
||||
"missingTileControlNetModel": "Aucun modèle ControlNet valide installé"
|
||||
},
|
||||
"stylePresets": {
|
||||
"deleteTemplate": "Supprimer le modèle",
|
||||
"editTemplate": "Modifier le modèle",
|
||||
"deleteTemplate": "Supprimer le template",
|
||||
"editTemplate": "Modifier le template",
|
||||
"exportFailed": "Impossible de générer et de télécharger le CSV",
|
||||
"name": "Nom",
|
||||
"acceptedColumnsKeys": "Colonnes/clés acceptées :",
|
||||
"promptTemplatesDesc1": "Les modèles de prompt ajoutent du texte aux prompts que vous écrivez dans la zone de saisie des prompts.",
|
||||
"promptTemplatesDesc1": "Les templates de prompt ajoutent du texte aux prompts que vous écrivez dans la zone de saisie.",
|
||||
"private": "Privé",
|
||||
"searchByName": "Rechercher par nom",
|
||||
"viewList": "Afficher la liste des modèles",
|
||||
"noTemplates": "Aucun modèle",
|
||||
"viewList": "Afficher la liste des templates",
|
||||
"noTemplates": "Aucun templates",
|
||||
"insertPlaceholder": "Insérer un placeholder",
|
||||
"defaultTemplates": "Modèles par défaut",
|
||||
"defaultTemplates": "Template pré-défini",
|
||||
"deleteImage": "Supprimer l'image",
|
||||
"createPromptTemplate": "Créer un modèle de prompt",
|
||||
"createPromptTemplate": "Créer un template de prompt",
|
||||
"negativePrompt": "Prompt négatif",
|
||||
"promptTemplatesDesc3": "Si vous omettez le placeholder, le modèle sera ajouté à la fin de votre prompt.",
|
||||
"promptTemplatesDesc3": "Si vous omettez le placeholder, le template sera ajouté à la fin de votre prompt.",
|
||||
"positivePrompt": "Prompt positif",
|
||||
"choosePromptTemplate": "Choisir un modèle de prompt",
|
||||
"choosePromptTemplate": "Choisir un template de prompt",
|
||||
"toggleViewMode": "Basculer le mode d'affichage",
|
||||
"updatePromptTemplate": "Mettre à jour le modèle de prompt",
|
||||
"flatten": "Intégrer le modèle sélectionné dans le prompt actuel",
|
||||
"myTemplates": "Mes modèles",
|
||||
"updatePromptTemplate": "Mettre à jour le template de prompt",
|
||||
"flatten": "Intégrer le template sélectionné dans le prompt actuel",
|
||||
"myTemplates": "Mes Templates",
|
||||
"type": "Type",
|
||||
"exportDownloaded": "Exportation téléchargée",
|
||||
"clearTemplateSelection": "Supprimer la sélection de modèle",
|
||||
"promptTemplateCleared": "Modèle de prompt effacé",
|
||||
"templateDeleted": "Modèle de prompt supprimé",
|
||||
"exportPromptTemplates": "Exporter mes modèles de prompt (CSV)",
|
||||
"clearTemplateSelection": "Supprimer la sélection de template",
|
||||
"promptTemplateCleared": "Template de prompt effacé",
|
||||
"templateDeleted": "Template de prompt supprimé",
|
||||
"exportPromptTemplates": "Exporter mes templates de prompt (CSV)",
|
||||
"nameColumn": "'nom'",
|
||||
"positivePromptColumn": "\"prompt\" ou \"prompt_positif\"",
|
||||
"useForTemplate": "Utiliser pour le modèle de prompt",
|
||||
"uploadImage": "Charger une image",
|
||||
"importTemplates": "Importer des modèles de prompt (CSV/JSON)",
|
||||
"useForTemplate": "Utiliser pour le template de prompt",
|
||||
"uploadImage": "Importer une image",
|
||||
"importTemplates": "Importer des templates de prompt (CSV/JSON)",
|
||||
"negativePromptColumn": "'prompt_négatif'",
|
||||
"deleteTemplate2": "Êtes-vous sûr de vouloir supprimer ce modèle ? Cette action ne peut pas être annulée.",
|
||||
"deleteTemplate2": "Êtes-vous sûr de vouloir supprimer ce template ? Cette action ne peut pas être annulée.",
|
||||
"preview": "Aperçu",
|
||||
"shared": "Partagé",
|
||||
"noMatchingTemplates": "Aucun modèle correspondant",
|
||||
"sharedTemplates": "Modèles partagés",
|
||||
"unableToDeleteTemplate": "Impossible de supprimer le modèle de prompt",
|
||||
"noMatchingTemplates": "Aucun templates correspondant",
|
||||
"sharedTemplates": "Template partagés",
|
||||
"unableToDeleteTemplate": "Impossible de supprimer le template de prompt",
|
||||
"active": "Actif",
|
||||
"copyTemplate": "Copier le modèle",
|
||||
"viewModeTooltip": "Voici à quoi ressemblera votre prompt avec le modèle actuellement sélectionné. Pour modifier votre prompt, cliquez n'importe où dans la zone de texte.",
|
||||
"promptTemplatesDesc2": "Utilisez la chaîne de remplacement <Pre>{{placeholder}}</Pre> pour spécifier où votre prompt doit être inclus dans le modèle."
|
||||
"copyTemplate": "Copier le template",
|
||||
"viewModeTooltip": "Voici à quoi ressemblera votre prompt avec le template actuellement sélectionné. Pour modifier votre prompt, cliquez n'importe où dans la zone de texte.",
|
||||
"promptTemplatesDesc2": "Utilisez la chaîne de remplacement <Pre>{{placeholder}}</Pre> pour spécifier où votre prompt doit être inclus dans le template."
|
||||
},
|
||||
"system": {
|
||||
"logNamespaces": {
|
||||
@@ -2051,8 +2069,12 @@
|
||||
"enableLogging": "Activer la journalisation"
|
||||
},
|
||||
"newUserExperience": {
|
||||
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un modèle de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
|
||||
"gettingStartedSeries": "Vous souhaitez plus de conseils ? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio."
|
||||
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
|
||||
"gettingStartedSeries": "Vous souhaitez plus de conseils ? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio.",
|
||||
"noModelsInstalled": "Il semblerait qu'aucun modèle ne soit installé",
|
||||
"downloadStarterModels": "Télécharger les modèles de démarrage",
|
||||
"importModels": "Importer Modèles",
|
||||
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>."
|
||||
},
|
||||
"upsell": {
|
||||
"shareAccess": "Partager l'accès",
|
||||
|
||||
@@ -577,7 +577,18 @@
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
"starterModelsInModelManager": "I modelli iniziali possono essere trovati in Gestione Modelli",
|
||||
"spandrelImageToImage": "Immagine a immagine (Spandrel)",
|
||||
"learnMoreAboutSupportedModels": "Scopri di più sui modelli che supportiamo"
|
||||
"learnMoreAboutSupportedModels": "Scopri di più sui modelli che supportiamo",
|
||||
"starterBundles": "Pacchetti per iniziare",
|
||||
"installingBundle": "Installazione del pacchetto",
|
||||
"skippingXDuplicates_one": ", saltando {{count}} duplicato",
|
||||
"skippingXDuplicates_many": ", saltando {{count}} duplicati",
|
||||
"skippingXDuplicates_other": ", saltando {{count}} duplicati",
|
||||
"installingModel": "Installazione del modello",
|
||||
"installingXModels_one": "Installazione di {{count}} modello",
|
||||
"installingXModels_many": "Installazione di {{count}} modelli",
|
||||
"installingXModels_other": "Installazione di {{count}} modelli",
|
||||
"includesNModels": "Include {{n}} modelli e le loro dipendenze",
|
||||
"starterBundleHelpText": "Installa facilmente tutti i modelli necessari per iniziare con un modello base, tra cui un modello principale, controlnet, adattatori IP e altro. Selezionando un pacchetto salterai tutti i modelli che hai già installato."
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -722,7 +733,7 @@
|
||||
"serverError": "Errore del Server",
|
||||
"connected": "Connesso al server",
|
||||
"canceled": "Elaborazione annullata",
|
||||
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
|
||||
"uploadFailedInvalidUploadDesc": "Devono essere immagini PNG o JPEG.",
|
||||
"parameterSet": "Parametro richiamato",
|
||||
"parameterNotSet": "Parametro non richiamato",
|
||||
"problemCopyingImage": "Impossibile copiare l'immagine",
|
||||
@@ -731,7 +742,7 @@
|
||||
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
|
||||
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
|
||||
"imageUploaded": "Immagine caricata",
|
||||
"addedToBoard": "Aggiunto alla bacheca",
|
||||
"addedToBoard": "Aggiunto alle risorse della bacheca {{name}}",
|
||||
"modelAddedSimple": "Modello aggiunto alla Coda",
|
||||
"imageUploadFailed": "Caricamento immagine non riuscito",
|
||||
"setControlImage": "Imposta come immagine di controllo",
|
||||
@@ -770,7 +781,12 @@
|
||||
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
|
||||
"layerCopiedToClipboard": "Livello copiato negli appunti",
|
||||
"imageNotLoadedDesc": "Impossibile trovare l'immagine",
|
||||
"linkCopied": "Collegamento copiato"
|
||||
"linkCopied": "Collegamento copiato",
|
||||
"addedToUncategorized": "Aggiunto alle risorse della bacheca $t(boards.uncategorized)",
|
||||
"imagesWillBeAddedTo": "Le immagini caricate verranno aggiunte alle risorse della bacheca {{boardName}}.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG o JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG."
|
||||
},
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Barra di avanzamento generazione",
|
||||
@@ -785,7 +801,8 @@
|
||||
"about": "Informazioni",
|
||||
"submitSupportTicket": "Invia ticket di supporto",
|
||||
"toggleLeftPanel": "Attiva/disattiva il pannello sinistro (T)",
|
||||
"toggleRightPanel": "Attiva/disattiva il pannello destro (G)"
|
||||
"toggleRightPanel": "Attiva/disattiva il pannello destro (G)",
|
||||
"uploadImages": "Carica immagine(i)"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomOutNodes": "Rimpicciolire",
|
||||
@@ -2006,7 +2023,11 @@
|
||||
},
|
||||
"newUserExperience": {
|
||||
"gettingStartedSeries": "Desideri maggiori informazioni? Consulta la nostra <LinkComponent>Getting Started Series</LinkComponent> per suggerimenti su come sfruttare appieno il potenziale di Invoke Studio.",
|
||||
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
|
||||
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
|
||||
"importModels": "Importa modelli",
|
||||
"downloadStarterModels": "Scarica i modelli per iniziare",
|
||||
"noModelsInstalled": "Sembra che tu non abbia installato alcun modello",
|
||||
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
|
||||
},
|
||||
"whatsNew": {
|
||||
"canvasV2Announcement": {
|
||||
|
||||
@@ -94,7 +94,8 @@
|
||||
"reset": "Сброс",
|
||||
"none": "Ничего",
|
||||
"new": "Новый",
|
||||
"ok": "Ok"
|
||||
"ok": "Ok",
|
||||
"close": "Закрыть"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@@ -160,7 +161,9 @@
|
||||
"openViewer": "Открыть просмотрщик",
|
||||
"closeViewer": "Закрыть просмотрщик",
|
||||
"imagesTab": "Изображения, созданные и сохраненные в Invoke.",
|
||||
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах."
|
||||
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах.",
|
||||
"boardsSettings": "Настройки доски",
|
||||
"imagesSettings": "Настройки галереи изображений"
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Поиск горячих клавиш",
|
||||
@@ -583,7 +586,18 @@
|
||||
"learnMoreAboutSupportedModels": "Подробнее о поддерживаемых моделях",
|
||||
"t5Encoder": "T5 энкодер",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"clipEmbed": "CLIP Embed"
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"installingXModels_one": "Установка {{count}} модели",
|
||||
"installingXModels_few": "Установка {{count}} моделей",
|
||||
"installingXModels_many": "Установка {{count}} моделей",
|
||||
"installingBundle": "Установка пакета",
|
||||
"installingModel": "Установка модели",
|
||||
"starterBundles": "Стартовые пакеты",
|
||||
"skippingXDuplicates_one": ", пропуская {{count}} дубликат",
|
||||
"skippingXDuplicates_few": ", пропуская {{count}} дубликата",
|
||||
"skippingXDuplicates_many": ", пропуская {{count}} дубликатов",
|
||||
"includesNModels": "Включает в себя {{n}} моделей и их зависимостей",
|
||||
"starterBundleHelpText": "Легко установите все модели, необходимые для начала работы с базовой моделью, включая основную модель, сети управления, IP-адаптеры и многое другое. При выборе комплекта все уже установленные модели будут пропущены."
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Изображения",
|
||||
@@ -730,7 +744,7 @@
|
||||
"serverError": "Ошибка сервера",
|
||||
"connected": "Подключено к серверу",
|
||||
"canceled": "Обработка отменена",
|
||||
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
|
||||
"uploadFailedInvalidUploadDesc": "Это должны быть изображения PNG или JPEG.",
|
||||
"parameterNotSet": "Параметр не задан",
|
||||
"parameterSet": "Параметр задан",
|
||||
"problemCopyingImage": "Не удается скопировать изображение",
|
||||
@@ -742,7 +756,7 @@
|
||||
"setNodeField": "Установить как поле узла",
|
||||
"invalidUpload": "Неверная загрузка",
|
||||
"imageUploaded": "Изображение загружено",
|
||||
"addedToBoard": "Добавлено на доску",
|
||||
"addedToBoard": "Добавлено в активы доски {{name}}",
|
||||
"workflowLoaded": "Рабочий процесс загружен",
|
||||
"problemDeletingWorkflow": "Проблема с удалением рабочего процесса",
|
||||
"modelAddedSimple": "Модель добавлена в очередь",
|
||||
@@ -777,7 +791,13 @@
|
||||
"unableToLoadStylePreset": "Невозможно загрузить предустановку стиля",
|
||||
"layerCopiedToClipboard": "Слой скопирован в буфер обмена",
|
||||
"sentToUpscale": "Отправить на увеличение",
|
||||
"layerSavedToAssets": "Слой сохранен в активах"
|
||||
"layerSavedToAssets": "Слой сохранен в активах",
|
||||
"linkCopied": "Ссылка скопирована",
|
||||
"addedToUncategorized": "Добавлено в активы доски $t(boards.uncategorized)",
|
||||
"imagesWillBeAddedTo": "Загруженные изображения будут добавлены в активы доски {{boardName}}.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Должно быть не более {{count}} изображения в формате PNG или JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_few": "Должно быть не более {{count}} изображений в формате PNG или JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_many": "Должно быть не более {{count}} изображений в формате PNG или JPEG."
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Загрузить изображение",
|
||||
@@ -792,7 +812,8 @@
|
||||
"about": "Об этом",
|
||||
"submitSupportTicket": "Отправить тикет в службу поддержки",
|
||||
"toggleRightPanel": "Переключить правую панель (G)",
|
||||
"toggleLeftPanel": "Переключить левую панель (T)"
|
||||
"toggleLeftPanel": "Переключить левую панель (T)",
|
||||
"uploadImages": "Загрузить изображения"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomInNodes": "Увеличьте масштаб",
|
||||
@@ -933,7 +954,7 @@
|
||||
"saveToGallery": "Сохранить в галерею",
|
||||
"noWorkflows": "Нет рабочих процессов",
|
||||
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
|
||||
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>"
|
||||
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>."
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Авто добавление Доски",
|
||||
@@ -1409,7 +1430,8 @@
|
||||
"recallParameter": "Отозвать {{label}}",
|
||||
"allPrompts": "Все запросы",
|
||||
"imageDimensions": "Размеры изображения",
|
||||
"canvasV2Metadata": "Холст"
|
||||
"canvasV2Metadata": "Холст",
|
||||
"guidance": "Точность"
|
||||
},
|
||||
"queue": {
|
||||
"status": "Статус",
|
||||
@@ -1561,7 +1583,12 @@
|
||||
"defaultWorkflows": "Стандартные рабочие процессы",
|
||||
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
|
||||
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
|
||||
"uploadAndSaveWorkflow": "Загрузить в библиотеку"
|
||||
"uploadAndSaveWorkflow": "Загрузить в библиотеку",
|
||||
"edit": "Редактировать",
|
||||
"download": "Скачать",
|
||||
"copyShareLink": "Скопировать ссылку на общий доступ",
|
||||
"copyShareLinkForWorkflow": "Скопировать ссылку на общий доступ для рабочего процесса",
|
||||
"delete": "Удалить"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Включить исправление высокого разрешения",
|
||||
@@ -1890,7 +1917,10 @@
|
||||
"fitToBbox": "Вместить в рамку",
|
||||
"reset": "Сбросить",
|
||||
"apply": "Применить",
|
||||
"cancel": "Отменить"
|
||||
"cancel": "Отменить",
|
||||
"fitModeContain": "Уместить",
|
||||
"fitMode": "Режим подгонки",
|
||||
"fitModeFill": "Заполнить"
|
||||
},
|
||||
"disableAutoNegative": "Отключить авто негатив",
|
||||
"deleteReferenceImage": "Удалить эталонное изображение",
|
||||
@@ -1920,7 +1950,8 @@
|
||||
"globalReferenceImage": "Глобальное эталонное изображение",
|
||||
"sendToGallery": "Отправить в галерею",
|
||||
"referenceImage": "Эталонное изображение",
|
||||
"addGlobalReferenceImage": "Добавить $t(controlLayers.globalReferenceImage)"
|
||||
"addGlobalReferenceImage": "Добавить $t(controlLayers.globalReferenceImage)",
|
||||
"newImg2ImgCanvasFromImage": "Новое img2img из изображения"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
|
||||
@@ -4,6 +4,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
@@ -59,6 +60,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
  useGlobalModifiersInit();
  useGlobalHotkeys();
  useGetOpenAPISchemaQuery();
  useSyncLoggingConfig();

  const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();

@@ -2,6 +2,8 @@ import 'i18n';

import type { Middleware } from '@reduxjs/toolkit';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -20,7 +22,7 @@ import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
@@ -46,6 +48,7 @@ interface Props extends PropsWithChildren {
  isDebugging?: boolean;
  logo?: ReactNode;
  workflowCategories?: WorkflowCategory[];
  loggingOverrides?: LoggingOverrides;
}

const InvokeAIUI = ({
@@ -65,7 +68,26 @@ const InvokeAIUI = ({
  isDebugging = false,
  logo,
  workflowCategories,
  loggingOverrides,
}: Props) => {
  useLayoutEffect(() => {
    /*
     * We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
     * possible opportunity.
     *
     * Once redux initializes, we will check the user's settings and update the logging config accordingly. See
     * `useSyncLoggingConfig`.
     */
    $loggingOverrides.set(loggingOverrides);

    // Until we get the user's settings, we will use the overrides OR default values.
    configureLogging(
      loggingOverrides?.logIsEnabled ?? true,
      loggingOverrides?.logLevel ?? 'debug',
      loggingOverrides?.logNamespaces ?? '*'
    );
  }, [loggingOverrides]);

  useEffect(() => {
    // configure API client token
    if (token) {

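// Illustrative sketch (not from the commit): how a consuming app might pass the new
// `loggingOverrides` prop shown above. The import path and default export of InvokeAIUI are
// assumptions; the prop shape is the `LoggingOverrides` type added in app/logging/logger.ts,
// and overrides take precedence over the user's stored settings (see useSyncLoggingConfig below).
import React from 'react';
import type { LoggingOverrides } from 'app/logging/logger';
import InvokeAIUI from 'app/components/InvokeAIUI'; // assumed path and export

const overrides: LoggingOverrides = {
  logIsEnabled: true,
  logLevel: 'info',          // filter console output to info and above
  logNamespaces: ['canvas'], // only emit canvas-namespaced logs
};

export const EmbeddedStudio = () => <InvokeAIUI loggingOverrides={overrides} />;
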
@@ -9,11 +9,10 @@ const serializeMessage: MessageSerializer = (message) => {
};

ROARR.serializeMessage = serializeMessage;
ROARR.write = createLogWriter();

export const BASE_CONTEXT = {};
const BASE_CONTEXT = {};

export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));

export const zLogNamespace = z.enum([
  'canvas',
@@ -35,8 +34,22 @@ export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fat
export type LogLevel = z.infer<typeof zLogLevel>;
export const isLogLevel = (v: unknown): v is LogLevel => zLogLevel.safeParse(v).success;

/**
 * Override logging settings.
 * @property logIsEnabled Override the enabled log state. Omit to use the user's settings.
 * @property logNamespaces Override the enabled log namespaces. Use `"*"` for all namespaces. Omit to use the user's settings.
 * @property logLevel Override the log level. Omit to use the user's settings.
 */
export type LoggingOverrides = {
  logIsEnabled?: boolean;
  logNamespaces?: LogNamespace[] | '*';
  logLevel?: LogLevel;
};

export const $loggingOverrides = atom<LoggingOverrides | undefined>();

// Translate human-readable log levels to numbers, used for log filtering
export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
const LOG_LEVEL_MAP: Record<LogLevel, number> = {
  trace: 10,
  debug: 20,
  info: 30,
@@ -44,3 +57,40 @@ export const LOG_LEVEL_MAP: Record<LogLevel, number> = {
  error: 50,
  fatal: 60,
};

/**
 * Configure logging, pushing settings to local storage.
 *
 * @param logIsEnabled Whether logging is enabled
 * @param logLevel The log level
 * @param logNamespaces A list of log namespaces to enable, or '*' to enable all
 */
export const configureLogging = (
  logIsEnabled: boolean = true,
  logLevel: LogLevel = 'warn',
  logNamespaces: LogNamespace[] | '*'
): void => {
  if (!logIsEnabled) {
    // Disable console log output
    localStorage.setItem('ROARR_LOG', 'false');
  } else {
    // Enable console log output
    localStorage.setItem('ROARR_LOG', 'true');

    // Use a filter to show only logs of the given level
    let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;

    const namespaces = logNamespaces === '*' ? zLogNamespace.options : logNamespaces;

    if (namespaces.length > 0) {
      filter += ` AND (${namespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
    } else {
      // This effectively hides all logs because we use namespaces for all logs
      filter += ' AND context.namespace:undefined';
    }

    localStorage.setItem('ROARR_FILTER', filter);
  }

  ROARR.write = createLogWriter();
};

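// Worked example of the values configureLogging writes to localStorage, derived from the
// implementation above ('info' maps to 30 in LOG_LEVEL_MAP):
import { configureLogging } from 'app/logging/logger';

configureLogging(true, 'info', ['canvas']);
// localStorage['ROARR_LOG']    -> 'true'
// localStorage['ROARR_FILTER'] -> 'context.logLevel:>=30 AND (context.namespace:canvas)'
// With logIsEnabled === false, ROARR_LOG is set to 'false' and ROARR_FILTER is left untouched.
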
@@ -1,53 +1,9 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
  selectSystemLogIsEnabled,
  selectSystemLogLevel,
  selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import { useMemo } from 'react';

import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
import { logger } from './logger';

export const useLogger = (namespace: LogNamespace) => {
  const logLevel = useAppSelector(selectSystemLogLevel);
  const logNamespaces = useAppSelector(selectSystemLogNamespaces);
  const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);

  // The provided Roarr browser log writer uses localStorage to config logging to console
  useEffect(() => {
    if (logIsEnabled) {
      // Enable console log output
      localStorage.setItem('ROARR_LOG', 'true');

      // Use a filter to show only logs of the given level
      let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
      if (logNamespaces.length > 0) {
        filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
      } else {
        filter += ' AND context.namespace:undefined';
      }
      localStorage.setItem('ROARR_FILTER', filter);
    } else {
      // Disable console log output
      localStorage.setItem('ROARR_LOG', 'false');
    }
    ROARR.write = createLogWriter();
  }, [logLevel, logIsEnabled, logNamespaces]);

  // Update the module-scoped logger context as needed
  useEffect(() => {
    // TODO: type this properly
    //eslint-disable-next-line @typescript-eslint/no-explicit-any
    const newContext: Record<string, any> = {
      ...BASE_CONTEXT,
    };

    $logger.set(Roarr.child(newContext));
  }, []);

  const log = useMemo(() => logger(namespace), [namespace]);

  return log;

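// Illustrative call site for the slimmed-down useLogger hook above. The component is hypothetical;
// the assumption is that `logger(namespace)` still returns a Roarr child logger, so the usual
// level methods (debug, info, warn, ...) are available. Console output is now controlled solely by
// the ROARR_* localStorage values written by configureLogging / useSyncLoggingConfig.
import { useLogger } from 'app/logging/useLogger';
import { memo, useEffect } from 'react';

export const CanvasDebugBadge = memo(() => {
  const log = useLogger('canvas');

  useEffect(() => {
    log.debug('canvas debug badge mounted');
  }, [log]);

  return null;
});
CanvasDebugBadge.displayName = 'CanvasDebugBadge';
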
@@ -0,0 +1,43 @@
import { useStore } from '@nanostores/react';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
  selectSystemLogIsEnabled,
  selectSystemLogLevel,
  selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useLayoutEffect } from 'react';

/**
 * This hook synchronizes the logging configuration stored in Redux with the logging system, which uses localstorage.
 *
 * The sync is one-way: from Redux to localstorage. This means that changes made in the UI will be reflected in the
 * logging system, but changes made directly to localstorage will not be reflected in the UI.
 *
 * See {@link configureLogging}
 */
export const useSyncLoggingConfig = () => {
  useAssertSingleton('useSyncLoggingConfig');

  const loggingOverrides = useStore($loggingOverrides);

  const logLevel = useAppSelector(selectSystemLogLevel);
  const logNamespaces = useAppSelector(selectSystemLogNamespaces);
  const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);

  useLayoutEffect(() => {
    configureLogging(
      loggingOverrides?.logIsEnabled ?? logIsEnabled,
      loggingOverrides?.logLevel ?? logLevel,
      loggingOverrides?.logNamespaces ?? logNamespaces
    );
  }, [
    logIsEnabled,
    logLevel,
    logNamespaces,
    loggingOverrides?.logIsEnabled,
    loggingOverrides?.logLevel,
    loggingOverrides?.logNamespaces,
  ]);
};
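
// Precedence sketch for the hook above (the values are hypothetical): each override, when present,
// wins over the corresponding Redux setting; everything else falls back to the user's settings.
//   $loggingOverrides            -> { logLevel: 'warn' }
//   Redux (systemSlice) settings -> logIsEnabled: true, logLevel: 'debug', logNamespaces: ['canvas']
//   resulting call               -> configureLogging(true, 'warn', ['canvas'])
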
@@ -1,7 +1,7 @@
import type { FilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import type { O } from 'ts-toolbelt';
import type { PartialDeep } from 'type-fest';

/**
 * A disable-able application feature
@@ -119,4 +119,4 @@ export type AppConfig = {
  };
};

export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;
export type PartialAppConfig = PartialDeep<AppConfig>;

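// Sketch of the behavior the swap relies on: type-fest's PartialDeep makes every level of nesting
// optional, matching what ts-toolbelt's O.Partial<..., 'deep'> did. The demo type below is
// illustrative and self-contained, not the real AppConfig.
import type { PartialDeep } from 'type-fest';

type DemoConfig = { canvas: { maxPixels: number; snapToGrid: boolean } };

// Both of these compile: nested objects and their fields are all optional.
export const emptyConfig: PartialDeep<DemoConfig> = {};
export const partialConfig: PartialDeep<DemoConfig> = { canvas: { maxPixels: 1024 } };
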
@@ -1,4 +1,12 @@
type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject;
type SerializableValue =
  | string
  | number
  | boolean
  | null
  | undefined
  | SerializableValue[]
  | readonly SerializableValue[]
  | SerializableObject;
export type SerializableObject = {
  [k: string | number]: SerializableValue;
};

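// Why `readonly SerializableValue[]` joins the union above: `as const` data can now be used in a
// serializable log context without widening or casting. Minimal sketch (import omitted because the
// diff does not show this file's path; SerializableObject is the type defined directly above):
const STAGES = ['filter', 'transform', 'segment'] as const; // readonly tuple of string literals

const logContext: SerializableObject = {
  stages: STAGES,       // accepted now; a mutable SerializableValue[] member alone would reject a readonly array
  count: STAGES.length,
};
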
@@ -34,7 +34,6 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
justifyContent="flex-start"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addGlobalReferenceImage}
|
||||
isDisabled={isFLUX}
|
||||
>
|
||||
{t('controlLayers.globalReferenceImage')}
|
||||
</Button>
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoProcess, settingsAutoProcessToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasAutoProcessSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoProcess = useAppSelector(selectAutoProcess);
|
||||
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsAutoProcessToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
|
||||
<Switch size="sm" isChecked={autoProcess} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasAutoProcessSwitch.displayName = 'CanvasAutoProcessSwitch';
|
||||
@@ -5,6 +5,7 @@ import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/componen
|
||||
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
|
||||
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
|
||||
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
|
||||
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
|
||||
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
|
||||
import {
|
||||
EntityIdentifierContext,
|
||||
@@ -15,6 +16,7 @@ import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/sel
|
||||
import {
|
||||
isFilterableEntityIdentifier,
|
||||
isSaveableEntityIdentifier,
|
||||
isSegmentableEntityIdentifier,
|
||||
isTransformableEntityIdentifier,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
@@ -27,6 +29,7 @@ const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
|
||||
<MenuGroup title={title}>
|
||||
{isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
|
||||
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
|
||||
{isSegmentableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSegment />}
|
||||
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCopyToClipboard />}
|
||||
{isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSave />}
|
||||
{isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCropToBbox />}
|
||||
|
||||
@@ -40,7 +40,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
/>
|
||||
<MenuList>
|
||||
<MenuGroup title={t('controlLayers.global')}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isFLUX}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
|
||||
{t('controlLayers.globalReferenceImage')}
|
||||
</MenuItem>
|
||||
</MenuGroup>
|
||||
|
||||
@@ -10,6 +10,7 @@ import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea
|
||||
import { Filter } from 'features/controlLayers/components/Filters/Filter';
|
||||
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
|
||||
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
|
||||
import { SegmentAnything } from 'features/controlLayers/components/SegmentAnything/SegmentAnything';
|
||||
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
|
||||
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
|
||||
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
|
||||
@@ -101,6 +102,7 @@ export const CanvasMainPanelContent = memo(() => {
|
||||
<CanvasManagerProviderGate>
|
||||
<Filter />
|
||||
<Transform />
|
||||
<SegmentAnything />
|
||||
</CanvasManagerProviderGate>
|
||||
</Flex>
|
||||
<CanvasDropArea />
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import { FormControl, FormLabel, Switch, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectIsolatedLayerPreview,
|
||||
settingsIsolatedLayerPreviewToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasOperationIsolatedLayerPreviewSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isolatedLayerPreview = useAppSelector(selectIsolatedLayerPreview);
|
||||
const onChangeIsolatedPreview = useCallback(() => {
|
||||
dispatch(settingsIsolatedLayerPreviewToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Tooltip label={t('controlLayers.settings.isolatedLayerPreviewDesc')}>
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
|
||||
<Switch size="sm" isChecked={isolatedLayerPreview} onChange={onChangeIsolatedPreview} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasOperationIsolatedLayerPreviewSwitch.displayName = 'CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
@@ -7,6 +7,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
|
||||
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
|
||||
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
|
||||
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
|
||||
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
|
||||
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
|
||||
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
|
||||
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
|
||||
@@ -23,6 +24,7 @@ export const ControlLayerMenuItems = memo(() => {
|
||||
<MenuDivider />
|
||||
<CanvasEntityMenuItemsTransform />
|
||||
<CanvasEntityMenuItemsFilter />
|
||||
<CanvasEntityMenuItemsSegment />
|
||||
<ControlLayerMenuItemsConvertControlToRaster />
|
||||
<ControlLayerMenuItemsTransparencyEffect />
|
||||
<MenuDivider />
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
|
||||
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
|
||||
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import {
|
||||
selectAutoProcessFilter,
|
||||
selectIsolatedFilteringPreview,
|
||||
settingsAutoProcessFilterToggled,
|
||||
settingsIsolatedFilteringPreviewToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS } from 'features/controlLayers/store/filters';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
@@ -23,19 +20,13 @@ import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold
|
||||
const FilterContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasProcessed = useStore(adapter.filterer.$hasProcessed);
|
||||
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
|
||||
const isolatedFilteringPreview = useAppSelector(selectIsolatedFilteringPreview);
|
||||
const onChangeIsolatedPreview = useCallback(() => {
|
||||
dispatch(settingsIsolatedFilteringPreviewToggled());
|
||||
}, [dispatch]);
|
||||
const autoProcess = useAppSelector(selectAutoProcess);
|
||||
|
||||
const onChangeFilterConfig = useCallback(
|
||||
(filterConfig: FilterConfig) => {
|
||||
@@ -51,10 +42,6 @@ const FilterContent = memo(
|
||||
[adapter.filterer.$filterConfig]
|
||||
);
|
||||
|
||||
const onChangeAutoProcessFilter = useCallback(() => {
|
||||
dispatch(settingsAutoProcessFilterToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
|
||||
}, [config]);
|
||||
@@ -94,14 +81,8 @@ const FilterContent = memo(
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
|
||||
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
|
||||
</FormControl>
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
|
||||
<Switch size="sm" isChecked={isolatedFilteringPreview} onChange={onChangeIsolatedPreview} />
|
||||
</FormControl>
|
||||
<CanvasAutoProcessSwitch />
|
||||
<CanvasOperationIsolatedLayerPreviewSwitch />
|
||||
</Flex>
|
||||
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />
|
||||
<FilterSettings filterConfig={config} onChange={onChangeFilterConfig} />
|
||||
@@ -112,7 +93,7 @@ const FilterContent = memo(
|
||||
onClick={adapter.filterer.processImmediate}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.filter.process')}
|
||||
isDisabled={!isValid || autoProcessFilter}
|
||||
isDisabled={!isValid || autoProcess}
|
||||
>
|
||||
{t('controlLayers.filter.process')}
|
||||
</Button>
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -11,9 +11,13 @@ import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
|
||||
const FLUX_CLIP_VISION = 'ViT-L';
|
||||
|
||||
const CLIP_VISION_OPTIONS = [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
|
||||
];
|
||||
|
||||
type Props = {
|
||||
@@ -47,6 +51,8 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
[onChangeCLIPVisionModel]
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
@@ -64,10 +70,16 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionModelValue = useMemo(
|
||||
() => CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel),
|
||||
[clipVisionModel]
|
||||
);
|
||||
const clipVisionOptions = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.map((option) => ({
|
||||
...option,
|
||||
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
|
||||
}));
|
||||
}, [isFLUX]);
|
||||
|
||||
const clipVisionModelValue = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
|
||||
}, [clipVisionModel]);
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
@@ -85,7 +97,7 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
{selectedModel?.format === 'checkpoint' && (
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
|
||||
<Combobox
|
||||
options={CLIP_VISION_OPTIONS}
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={clipVisionModelValue}
|
||||
onChange={_onChangeCLIPVisionModel}
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
referenceImageIPAdapterModelChanged,
|
||||
referenceImageIPAdapterWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
|
||||
import type { IPAImageDropData } from 'features/dnd/types';
|
||||
@@ -90,6 +91,8 @@ export const IPAdapterSettings = memo(() => {
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
return (
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
@@ -113,7 +116,7 @@ export const IPAdapterSettings = memo(() => {
|
||||
</Flex>
|
||||
<Flex gap={2} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
|
||||
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
|
||||
@@ -7,6 +7,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
|
||||
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
|
||||
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
|
||||
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
|
||||
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
|
||||
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
|
||||
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
|
||||
import { memo } from 'react';
|
||||
@@ -22,6 +23,7 @@ export const RasterLayerMenuItems = memo(() => {
|
||||
<MenuDivider />
|
||||
<CanvasEntityMenuItemsTransform />
|
||||
<CanvasEntityMenuItemsFilter />
|
||||
<CanvasEntityMenuItemsSegment />
|
||||
<RasterLayerMenuItemsConvertRasterToControl />
|
||||
<MenuDivider />
|
||||
<CanvasEntityMenuItemsCropToBbox />
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { CanvasAutoProcessSwitch } from 'features/controlLayers/components/CanvasAutoProcessSwitch';
|
||||
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
|
||||
import { SegmentAnythingPointType } from 'features/controlLayers/components/SegmentAnything/SegmentAnythingPointType';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { memo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
const SegmentAnythingContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
|
||||
const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
|
||||
const autoProcess = useAppSelector(selectAutoProcess);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applySegmentAnything',
|
||||
category: 'canvas',
|
||||
callback: adapter.segmentAnything.apply,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'cancelSegmentAnything',
|
||||
category: 'canvas',
|
||||
callback: adapter.segmentAnything.cancel,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.segmentAnything, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
p={4}
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
minW={420}
|
||||
h="auto"
|
||||
shadow="dark-lg"
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
<Flex w="full" gap={4}>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.segment.autoMask')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<CanvasAutoProcessSwitch />
|
||||
<CanvasOperationIsolatedLayerPreviewSwitch />
|
||||
</Flex>
|
||||
|
||||
<SegmentAnythingPointType adapter={adapter} />
|
||||
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button
|
||||
leftIcon={<PiStarBold />}
|
||||
onClick={adapter.segmentAnything.processImmediate}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.segment.process')}
|
||||
variant="ghost"
|
||||
isDisabled={!hasPoints || autoProcess}
|
||||
>
|
||||
{t('controlLayers.segment.process')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
leftIcon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={adapter.segmentAnything.reset}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.segment.reset')}
|
||||
variant="ghost"
|
||||
>
|
||||
{t('controlLayers.segment.reset')}
|
||||
</Button>
|
||||
<Button
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={adapter.segmentAnything.apply}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.segment.apply')}
|
||||
variant="ghost"
|
||||
>
|
||||
{t('controlLayers.segment.apply')}
|
||||
</Button>
|
||||
<Button
|
||||
leftIcon={<PiXBold />}
|
||||
onClick={adapter.segmentAnything.cancel}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('common.cancel')}
|
||||
variant="ghost"
|
||||
>
|
||||
{t('controlLayers.segment.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SegmentAnythingContent.displayName = 'SegmentAnythingContent';
|
||||
|
||||
export const SegmentAnything = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const adapter = useStore(canvasManager.stateApi.$segmentingAdapter);
|
||||
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <SegmentAnythingContent adapter={adapter} />;
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
import { Flex, FormControl, FormLabel, Radio, RadioGroup, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import { SAM_POINT_LABEL_STRING_TO_NUMBER, zSAMPointLabelString } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const SegmentAnythingPointType = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const pointType = useStore(adapter.segmentAnything.$pointTypeString);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: string) => {
|
||||
const labelAsString = zSAMPointLabelString.parse(v);
|
||||
const labelAsNumber = SAM_POINT_LABEL_STRING_TO_NUMBER[labelAsString];
|
||||
adapter.segmentAnything.$pointType.set(labelAsNumber);
|
||||
},
|
||||
[adapter.segmentAnything.$pointType]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('controlLayers.segment.pointType')}</FormLabel>
|
||||
<RadioGroup value={pointType} onChange={onChange} w="full" size="md">
|
||||
<Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
|
||||
<Radio value="foreground">
|
||||
<Text>{t('controlLayers.segment.foreground')}</Text>
|
||||
</Radio>
|
||||
<Radio value="background">
|
||||
<Text>{t('controlLayers.segment.background')}</Text>
|
||||
</Radio>
|
||||
<Radio value="neutral">
|
||||
<Text>{t('controlLayers.segment.neutral')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SegmentAnythingPointType.displayName = 'SegmentAnythingPointType';
|
||||
@@ -1,28 +1,28 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectIsolatedFilteringPreview,
|
||||
settingsIsolatedFilteringPreviewToggled,
|
||||
selectIsolatedLayerPreview,
|
||||
settingsIsolatedLayerPreviewToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsIsolatedFilteringPreviewSwitch = memo(() => {
|
||||
export const CanvasSettingsIsolatedLayerPreviewSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isolatedFilteringPreview = useAppSelector(selectIsolatedFilteringPreview);
|
||||
const isolatedLayerPreview = useAppSelector(selectIsolatedLayerPreview);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsIsolatedFilteringPreviewToggled());
|
||||
dispatch(settingsIsolatedLayerPreviewToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel m={0} flexGrow={1}>
|
||||
{t('controlLayers.settings.isolatedFilteringPreview')}
|
||||
{t('controlLayers.settings.isolatedLayerPreview')}
|
||||
</FormLabel>
|
||||
<Switch size="sm" isChecked={isolatedFilteringPreview} onChange={onChange} />
|
||||
<Switch size="sm" isChecked={isolatedLayerPreview} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsIsolatedFilteringPreviewSwitch.displayName = 'CanvasSettingsIsolatedFilteringPreviewSwitch';
|
||||
CanvasSettingsIsolatedLayerPreviewSwitch.displayName = 'CanvasSettingsIsolatedLayerPreviewSwitch';
|
||||
@@ -1,28 +0,0 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectIsolatedTransformingPreview,
|
||||
settingsIsolatedTransformingPreviewToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsIsolatedTransformingPreviewSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isolatedTransformingPreview = useAppSelector(selectIsolatedTransformingPreview);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsIsolatedTransformingPreviewToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel m={0} flexGrow={1}>
|
||||
{t('controlLayers.settings.isolatedTransformingPreview')}
|
||||
</FormLabel>
|
||||
<Switch size="sm" isChecked={isolatedTransformingPreview} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsIsolatedTransformingPreviewSwitch.displayName = 'CanvasSettingsIsolatedTransformingPreviewSwitch';
|
||||
@@ -16,9 +16,8 @@ import { CanvasSettingsClipToBboxCheckbox } from 'features/controlLayers/compone
|
||||
import { CanvasSettingsDynamicGridSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsDynamicGridSwitch';
|
||||
import { CanvasSettingsSnapToGridCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsGridSize';
|
||||
import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsInvertScrollCheckbox';
|
||||
import { CanvasSettingsIsolatedFilteringPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedFilteringPreviewSwitch';
|
||||
import { CanvasSettingsIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedLayerPreviewSwitch';
|
||||
import { CanvasSettingsIsolatedStagingPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedStagingPreviewSwitch';
|
||||
import { CanvasSettingsIsolatedTransformingPreviewSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsIsolatedTransformingPreviewSwitch';
|
||||
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
|
||||
import { CanvasSettingsOutputOnlyMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsOutputOnlyMaskedRegionsCheckbox';
|
||||
import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPreserveMaskCheckbox';
|
||||
@@ -54,8 +53,7 @@ export const CanvasSettingsPopover = memo(() => {
|
||||
<CanvasSettingsPressureSensitivityCheckbox />
|
||||
<CanvasSettingsShowProgressOnCanvas />
|
||||
<CanvasSettingsIsolatedStagingPreviewSwitch />
|
||||
<CanvasSettingsIsolatedFilteringPreviewSwitch />
|
||||
<CanvasSettingsIsolatedTransformingPreviewSwitch />
|
||||
<CanvasSettingsIsolatedLayerPreviewSwitch />
|
||||
<CanvasSettingsDynamicGridSwitch />
|
||||
<CanvasSettingsBboxOverlaySwitch />
|
||||
<CanvasSettingsShowHUDSwitch />
|
||||
|
||||
@@ -10,8 +10,8 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const onClick = useCallback(() => {
|
||||
canvasManager.bbox.fitToLayers();
|
||||
}, [canvasManager.bbox]);
|
||||
canvasManager.tool.tools.bbox.fitToLayers();
|
||||
}, [canvasManager.tool.tools.bbox]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
|
||||
@@ -1,30 +1,21 @@
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { CanvasOperationIsolatedLayerPreviewSwitch } from 'features/controlLayers/components/CanvasOperationIsolatedLayerPreviewSwitch';
import { TransformFitToBboxButtons } from 'features/controlLayers/components/Transform/TransformFitToBboxButtons';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import {
selectIsolatedTransformingPreview,
settingsIsolatedTransformingPreviewToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useRef } from 'react';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiXBold } from 'react-icons/pi';

const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.transformer.$isProcessing);
const isolatedTransformingPreview = useAppSelector(selectIsolatedTransformingPreview);
const onChangeIsolatedPreview = useCallback(() => {
dispatch(settingsIsolatedTransformingPreviewToggled());
}, [dispatch]);
const silentTransform = useStore(adapter.transformer.$silentTransform);

useRegisteredHotkeys({
@@ -66,10 +57,7 @@ const TransformContent = memo(({ adapter }: { adapter: CanvasEntityAdapter }) =>
{t('controlLayers.transform.transform')}
</Heading>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>{t('controlLayers.settings.isolatedPreview')}</FormLabel>
<Switch size="sm" isChecked={isolatedTransformingPreview} onChange={onChangeIsolatedPreview} />
</FormControl>
<CanvasOperationIsolatedLayerPreviewSwitch />
</Flex>

<TransformFitToBboxButtons adapter={adapter} />

@@ -0,0 +1,20 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';

export const CanvasEntityMenuItemsSegment = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const segmentAnything = useEntitySegmentAnything(entityIdentifier);

return (
<MenuItem onClick={segmentAnything.start} icon={<PiMaskHappyBold />} isDisabled={segmentAnything.isDisabled}>
{t('controlLayers.segment.autoMask')}
</MenuItem>
);
});

CanvasEntityMenuItemsSegment.displayName = 'CanvasEntityMenuItemsSegment';

@@ -0,0 +1,57 @@
import { useStore } from '@nanostores/react';
import { $false } from 'app/store/nanostores/util';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { useCallback, useMemo } from 'react';

export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const isBusy = useCanvasIsBusy();
const isInteractable = useStore(adapter?.$isInteractable ?? $false);
const isEmpty = useStore(adapter?.$isEmpty ?? $false);

const isDisabled = useMemo(() => {
if (!entityIdentifier) {
return true;
}
if (!isSegmentableEntityIdentifier(entityIdentifier)) {
return true;
}
if (!adapter) {
return true;
}
if (isBusy) {
return true;
}
if (!isInteractable) {
return true;
}
if (isEmpty) {
return true;
}
return false;
}, [entityIdentifier, adapter, isBusy, isInteractable, isEmpty]);

const start = useCallback(() => {
if (isDisabled) {
return;
}
if (!entityIdentifier) {
return;
}
if (!isSegmentableEntityIdentifier(entityIdentifier)) {
return;
}
const adapter = canvasManager.getAdapter(entityIdentifier);
if (!adapter) {
return;
}
adapter.segmentAnything.start();
}, [isDisabled, entityIdentifier, canvasManager]);

return { isDisabled, start } as const;
};

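The new hook above deliberately exposes only { isDisabled, start }, so it can back entry points other than the context-menu item shown earlier in this diff. A minimal sketch, assuming the same context providers are mounted and that IconButton is exported from '@invoke-ai/ui-library' as it is used elsewhere in the app; the component name and aria-label wiring are illustrative only and not part of this changeset:

import { IconButton } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';

// Hypothetical toolbar counterpart to CanvasEntityMenuItemsSegment (sketch only).
export const CanvasEntitySegmentButton = memo(() => {
  const { t } = useTranslation();
  const entityIdentifier = useEntityIdentifierContext();
  const segmentAnything = useEntitySegmentAnything(entityIdentifier);

  return (
    <IconButton
      aria-label={t('controlLayers.segment.autoMask')}
      icon={<PiMaskHappyBold />}
      onClick={segmentAnything.start}
      isDisabled={segmentAnything.isDisabled}
    />
  );
});

CanvasEntitySegmentButton.displayName = 'CanvasEntitySegmentButton';
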
@@ -10,11 +10,9 @@ import type { CanvasEntityTransformer } from 'features/controlLayers/konva/Canva
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import type { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
|
||||
import { getKonvaNodeDebugAttrs, getRectIntersection } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
selectIsolatedFilteringPreview,
|
||||
selectIsolatedTransformingPreview,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectIsolatedLayerPreview } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import {
|
||||
buildSelectIsHidden,
|
||||
buildSelectIsSelected,
|
||||
@@ -72,6 +70,15 @@ export abstract class CanvasEntityAdapterBase<
|
||||
// without requiring all adapters to implement this property and their own `destroy`?
|
||||
abstract filterer?: CanvasEntityFilterer;
|
||||
|
||||
/**
|
||||
* The segment anything module for this entity adapter. Entities that support segment anything should implement
|
||||
* this property.
|
||||
*/
|
||||
// TODO(psyche): This is in the ABC and not in the concrete classes to allow all adapters to share the `destroy`
|
||||
// method. If it wasn't in this ABC, we'd get a TS error in `destroy`. Maybe there's a better way to handle this
|
||||
// without requiring all adapters to implement this property and their own `destroy`?
|
||||
abstract segmentAnything?: CanvasSegmentAnythingModule;
|
||||
|
||||
/**
|
||||
* Synchronizes the entity state with the canvas. This includes rendering the entity's objects, handling visibility,
|
||||
* positioning, opacity, locked state, and any other properties.
|
||||
@@ -264,13 +271,11 @@ export abstract class CanvasEntityAdapterBase<
|
||||
*/
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsHidden, this.syncVisibility));
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectIsolatedFilteringPreview, this.syncVisibility)
|
||||
this.manager.stateApi.createStoreSubscription(selectIsolatedLayerPreview, this.syncVisibility)
|
||||
);
|
||||
this.subscriptions.add(this.manager.stateApi.$filteringAdapter.listen(this.syncVisibility));
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectIsolatedTransformingPreview, this.syncVisibility)
|
||||
);
|
||||
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
|
||||
this.subscriptions.add(this.manager.stateApi.$segmentingAdapter.listen(this.syncVisibility));
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsSelected, this.syncVisibility));
|
||||
|
||||
/**
|
||||
@@ -435,8 +440,10 @@ export abstract class CanvasEntityAdapterBase<
|
||||
return;
|
||||
}
|
||||
|
||||
const isolatedLayerPreview = this.manager.stateApi.runSelector(selectIsolatedLayerPreview);
|
||||
|
||||
// Handle isolated preview modes - if another entity is filtering or transforming, we may need to hide this entity.
|
||||
if (this.manager.stateApi.runSelector(selectIsolatedFilteringPreview)) {
|
||||
if (isolatedLayerPreview) {
|
||||
const filteringEntityIdentifier = this.manager.stateApi.$filteringAdapter.get()?.entityIdentifier;
|
||||
if (filteringEntityIdentifier && filteringEntityIdentifier.id !== this.id) {
|
||||
this.setVisibility(false);
|
||||
@@ -444,7 +451,7 @@ export abstract class CanvasEntityAdapterBase<
|
||||
}
|
||||
}
|
||||
|
||||
if (this.manager.stateApi.runSelector(selectIsolatedTransformingPreview)) {
|
||||
if (isolatedLayerPreview) {
|
||||
const transformingEntity = this.manager.stateApi.$transformingAdapter.get();
|
||||
if (
|
||||
transformingEntity &&
|
||||
@@ -457,6 +464,14 @@ export abstract class CanvasEntityAdapterBase<
|
||||
}
|
||||
}
|
||||
|
||||
if (isolatedLayerPreview) {
|
||||
const segmentingEntity = this.manager.stateApi.$segmentingAdapter.get();
|
||||
if (segmentingEntity && segmentingEntity.entityIdentifier.id !== this.id) {
|
||||
this.setVisibility(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If the entity is not selected and offscreen, we can hide it
|
||||
if (!this.$isOnScreen.get() && !this.manager.stateApi.getIsSelected(this.entityIdentifier.id)) {
|
||||
this.setVisibility(false);
|
||||
@@ -517,8 +532,17 @@ export abstract class CanvasEntityAdapterBase<
|
||||
this.transformer.stopTransform();
|
||||
}
|
||||
this.transformer.destroy();
|
||||
if (this.filterer?.$isFiltering.get()) {
|
||||
this.filterer.cancel();
|
||||
if (this.filterer) {
|
||||
if (this.filterer.$isFiltering.get()) {
|
||||
this.filterer.cancel();
|
||||
}
|
||||
this.filterer?.destroy();
|
||||
}
|
||||
if (this.segmentAnything) {
|
||||
if (this.segmentAnything.$isSegmenting.get()) {
|
||||
this.segmentAnything.cancel();
|
||||
}
|
||||
this.segmentAnything.destroy();
|
||||
}
|
||||
this.konva.layer.destroy();
|
||||
this.manager.deleteAdapter(this.entityIdentifier);
|
||||
@@ -534,6 +558,7 @@ export abstract class CanvasEntityAdapterBase<
|
||||
transformer: this.transformer.repr(),
|
||||
renderer: this.renderer.repr(),
|
||||
bufferRenderer: this.bufferRenderer.repr(),
|
||||
segmentAnything: this.segmentAnything?.repr(),
|
||||
filterer: this.filterer?.repr(),
|
||||
hasCache: this.$canvasCache.get() !== null,
|
||||
isLocked: this.$isLocked.get(),
|
||||
|
||||
@@ -5,6 +5,7 @@ import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/
|
||||
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
|
||||
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
|
||||
import type { CanvasControlLayerState, CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
|
||||
import type { GroupConfig } from 'konva/lib/Group';
|
||||
import { omit } from 'lodash-es';
|
||||
@@ -17,6 +18,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
|
||||
bufferRenderer: CanvasEntityBufferObjectRenderer;
|
||||
transformer: CanvasEntityTransformer;
|
||||
filterer: CanvasEntityFilterer;
|
||||
segmentAnything: CanvasSegmentAnythingModule;
|
||||
|
||||
constructor(entityIdentifier: CanvasEntityIdentifier<'control_layer'>, manager: CanvasManager) {
|
||||
super(entityIdentifier, manager, 'control_layer_adapter');
|
||||
@@ -25,6 +27,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
|
||||
this.bufferRenderer = new CanvasEntityBufferObjectRenderer(this);
|
||||
this.transformer = new CanvasEntityTransformer(this);
|
||||
this.filterer = new CanvasEntityFilterer(this);
|
||||
this.segmentAnything = new CanvasSegmentAnythingModule(this);
|
||||
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
|
||||
bufferRenderer: CanvasEntityBufferObjectRenderer;
|
||||
transformer: CanvasEntityTransformer;
|
||||
filterer = undefined;
|
||||
segmentAnything = undefined;
|
||||
|
||||
constructor(entityIdentifier: CanvasEntityIdentifier<'inpaint_mask'>, manager: CanvasManager) {
|
||||
super(entityIdentifier, manager, 'inpaint_mask_adapter');
|
||||
|
||||
@@ -5,6 +5,7 @@ import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/
|
||||
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
|
||||
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
|
||||
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
|
||||
import type { GroupConfig } from 'konva/lib/Group';
|
||||
import { omit } from 'lodash-es';
|
||||
@@ -17,6 +18,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
|
||||
bufferRenderer: CanvasEntityBufferObjectRenderer;
|
||||
transformer: CanvasEntityTransformer;
|
||||
filterer: CanvasEntityFilterer;
|
||||
segmentAnything: CanvasSegmentAnythingModule;
|
||||
|
||||
constructor(entityIdentifier: CanvasEntityIdentifier<'raster_layer'>, manager: CanvasManager) {
|
||||
super(entityIdentifier, manager, 'raster_layer_adapter');
|
||||
@@ -25,6 +27,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
|
||||
this.bufferRenderer = new CanvasEntityBufferObjectRenderer(this);
|
||||
this.transformer = new CanvasEntityTransformer(this);
|
||||
this.filterer = new CanvasEntityFilterer(this);
|
||||
this.segmentAnything = new CanvasSegmentAnythingModule(this);
|
||||
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectState, this.sync));
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ export class CanvasEntityAdapterRegionalGuidance extends CanvasEntityAdapterBase
|
||||
bufferRenderer: CanvasEntityBufferObjectRenderer;
|
||||
transformer: CanvasEntityTransformer;
|
||||
filterer = undefined;
|
||||
segmentAnything = undefined;
|
||||
|
||||
constructor(entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>, manager: CanvasManager) {
|
||||
super(entityIdentifier, manager, 'regional_guidance_adapter');
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectAutoProcessFilter } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/filters';
|
||||
import { getFilterForModel, IMAGE_FILTERS } from 'features/controlLayers/store/filters';
|
||||
import type { CanvasImageState } from 'features/controlLayers/store/types';
|
||||
@@ -15,7 +15,6 @@ import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
|
||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type CanvasEntityFiltererConfig = {
|
||||
processDebounceMs: number;
|
||||
@@ -56,30 +55,41 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
this.log = this.manager.buildLogger(this);
|
||||
|
||||
this.log.debug('Creating filter module');
|
||||
}
|
||||
|
||||
subscribe = () => {
|
||||
this.subscriptions.add(
|
||||
this.$filterConfig.listen(() => {
|
||||
if (this.manager.stateApi.getSettings().autoProcessFilter && this.$isFiltering.get()) {
|
||||
if (this.manager.stateApi.getSettings().autoProcess && this.$isFiltering.get()) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectAutoProcessFilter, (autoPreviewFilter) => {
|
||||
if (autoPreviewFilter && this.$isFiltering.get()) {
|
||||
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
|
||||
if (autoProcess && this.$isFiltering.get()) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
unsubscribe = () => {
|
||||
this.subscriptions.forEach((unsubscribe) => unsubscribe());
|
||||
this.subscriptions.clear();
|
||||
};
|
||||
|
||||
start = (config?: FilterConfig) => {
|
||||
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
|
||||
if (filteringAdapter) {
|
||||
assert(false, `Already filtering an entity: ${filteringAdapter.id}`);
|
||||
this.log.error(`Already filtering an entity: ${filteringAdapter.id}`);
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.trace('Initializing filter');
|
||||
|
||||
this.subscribe();
|
||||
|
||||
if (config) {
|
||||
this.$filterConfig.set(config);
|
||||
} else if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) {
|
||||
@@ -97,7 +107,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
}
|
||||
this.$isFiltering.set(true);
|
||||
this.manager.stateApi.$filteringAdapter.set(this.parent);
|
||||
if (this.manager.stateApi.getSettings().autoProcessFilter) {
|
||||
if (this.manager.stateApi.getSettings().autoProcess) {
|
||||
this.processImmediate();
|
||||
}
|
||||
};
|
||||
@@ -204,6 +214,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
replaceObjects: true,
|
||||
});
|
||||
this.imageState = null;
|
||||
this.unsubscribe();
|
||||
this.$isFiltering.set(false);
|
||||
this.$hasProcessed.set(false);
|
||||
this.manager.stateApi.$filteringAdapter.set(null);
|
||||
@@ -225,6 +236,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
this.log.trace('Cancelling filter');
|
||||
|
||||
this.reset();
|
||||
this.unsubscribe();
|
||||
this.$isProcessing.set(false);
|
||||
this.$isFiltering.set(false);
|
||||
this.$hasProcessed.set(false);
|
||||
@@ -243,4 +255,13 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
$filterConfig: this.$filterConfig.get(),
|
||||
};
|
||||
};
|
||||
|
||||
destroy = () => {
|
||||
this.log.debug('Destroying module');
|
||||
if (this.abortController && !this.abortController.signal.aborted) {
|
||||
this.abortController.abort();
|
||||
}
|
||||
this.abortController = null;
|
||||
this.unsubscribe();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -234,8 +234,25 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
|
||||
this.konva.transformer.on('transform', this.syncObjectGroupWithProxyRect);
|
||||
this.konva.transformer.on('transformend', this.snapProxyRectToPixelGrid);
|
||||
this.konva.transformer.on('pointerenter', () => {
|
||||
this.manager.stage.setCursor('move');
|
||||
});
|
||||
this.konva.transformer.on('pointerleave', () => {
|
||||
this.manager.stage.setCursor('default');
|
||||
});
|
||||
this.konva.proxyRect.on('dragmove', this.onDragMove);
|
||||
this.konva.proxyRect.on('dragend', this.onDragEnd);
|
||||
this.konva.proxyRect.on('pointerenter', () => {
|
||||
this.manager.stage.setCursor('move');
|
||||
});
|
||||
this.konva.proxyRect.on('pointerleave', () => {
|
||||
this.manager.stage.setCursor('default');
|
||||
});
|
||||
|
||||
this.subscriptions.add(() => {
|
||||
this.konva.transformer.off('transform transformend pointerenter pointerleave');
|
||||
this.konva.proxyRect.off('dragmove dragend pointerenter pointerleave');
|
||||
});
|
||||
|
||||
// When the stage scale changes, we may need to re-scale some of the transformer's components. For example,
|
||||
// the bbox outline should always be 1 screen pixel wide, so we need to update its stroke width.
|
||||
@@ -574,9 +591,9 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
syncInteractionState = () => {
|
||||
this.log.trace('Syncing interaction state');
|
||||
|
||||
if (this.manager.$isBusy.get() && !this.$isTransforming.get()) {
|
||||
// The canvas is busy, we can't interact with the transformer
|
||||
this.parent.konva.layer.listening(false);
|
||||
if (this.parent.segmentAnything?.$isSegmenting.get()) {
|
||||
// When segmenting, the layer should listen but the transformer should not be interactable
|
||||
this.parent.konva.layer.listening(true);
|
||||
this._setInteractionMode('off');
|
||||
return;
|
||||
}
|
||||
@@ -609,6 +626,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
const tool = this.manager.tool.$tool.get();
|
||||
const isSelected = this.manager.stateApi.getIsSelected(this.parent.id);
|
||||
|
||||
if (!isSelected) {
|
||||
// The layer is not selected
|
||||
this.parent.konva.layer.listening(false);
|
||||
this._setInteractionMode('off');
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.parent.$isEmpty.get()) {
|
||||
// The layer is totally empty, we can just disable the layer
|
||||
this.parent.konva.layer.listening(false);
|
||||
@@ -616,14 +640,14 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isSelected && !this.$isTransforming.get() && tool === 'move') {
|
||||
if (!this.$isTransforming.get() && tool === 'move') {
|
||||
// We are moving this layer, it must be listening
|
||||
this.parent.konva.layer.listening(true);
|
||||
this._setInteractionMode('drag');
|
||||
return;
|
||||
}
|
||||
|
||||
if (isSelected && this.$isTransforming.get()) {
|
||||
if (this.$isTransforming.get()) {
|
||||
// When transforming, we want the stage to still be movable if the view tool is selected. If the transformer is
|
||||
// active, it will interrupt the stage drag events. So we should disable listening when the view tool is selected.
|
||||
if (tool === 'view') {
|
||||
@@ -633,11 +657,12 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
this.parent.konva.layer.listening(true);
|
||||
this._setInteractionMode('all');
|
||||
}
|
||||
} else {
|
||||
// The layer is not selected, or we are using a tool that doesn't need the layer to be listening - disable interaction stuff
|
||||
this.parent.konva.layer.listening(false);
|
||||
this._setInteractionMode('off');
|
||||
return;
|
||||
}
|
||||
|
||||
// The layer is not selected
|
||||
this.parent.konva.layer.listening(false);
|
||||
this._setInteractionMode('off');
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
|
||||
import { CanvasBboxModule } from 'features/controlLayers/konva/CanvasBboxModule';
|
||||
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
|
||||
import { CanvasCompositorModule } from 'features/controlLayers/konva/CanvasCompositorModule';
|
||||
import { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
@@ -62,7 +61,6 @@ export class CanvasManager extends CanvasModuleBase {
|
||||
entityRenderer: CanvasEntityRendererModule;
|
||||
compositor: CanvasCompositorModule;
|
||||
tool: CanvasToolModule;
|
||||
bbox: CanvasBboxModule;
|
||||
stagingArea: CanvasStagingAreaModule;
|
||||
progressImage: CanvasProgressImageModule;
|
||||
|
||||
@@ -111,11 +109,12 @@ export class CanvasManager extends CanvasModuleBase {
|
||||
this.stateApi.$isFiltering,
|
||||
this.stateApi.$isTransforming,
|
||||
this.stateApi.$isRasterizing,
|
||||
this.stateApi.$isSegmenting,
|
||||
this.stagingArea.$isStaging,
|
||||
this.compositor.$isBusy,
|
||||
],
|
||||
(isFiltering, isTransforming, isRasterizing, isStaging, isCompositing) => {
|
||||
return isFiltering || isTransforming || isRasterizing || isStaging || isCompositing;
|
||||
(isFiltering, isTransforming, isRasterizing, isSegmenting, isStaging, isCompositing) => {
|
||||
return isFiltering || isTransforming || isRasterizing || isSegmenting || isStaging || isCompositing;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -123,18 +122,16 @@ export class CanvasManager extends CanvasModuleBase {
|
||||
this.stage.addLayer(this.background.konva.layer);
|
||||
|
||||
this.konva = {
|
||||
previewLayer: new Konva.Layer({ listening: false, imageSmoothingEnabled: false }),
|
||||
previewLayer: new Konva.Layer({ listening: true, imageSmoothingEnabled: false }),
|
||||
};
|
||||
this.stage.addLayer(this.konva.previewLayer);
|
||||
|
||||
this.tool = new CanvasToolModule(this);
|
||||
this.progressImage = new CanvasProgressImageModule(this);
|
||||
this.bbox = new CanvasBboxModule(this);
|
||||
|
||||
// Must add in this order for correct z-index
|
||||
this.konva.previewLayer.add(this.stagingArea.konva.group);
|
||||
this.konva.previewLayer.add(this.progressImage.konva.group);
|
||||
this.konva.previewLayer.add(this.bbox.konva.group);
|
||||
this.konva.previewLayer.add(this.tool.konva.group);
|
||||
}
|
||||
|
||||
@@ -232,7 +229,6 @@ export class CanvasManager extends CanvasModuleBase {
|
||||
|
||||
getAllModules = (): CanvasModuleBase[] => {
|
||||
return [
|
||||
this.bbox,
|
||||
this.stagingArea,
|
||||
this.tool,
|
||||
this.progressImage,
|
||||
@@ -280,7 +276,6 @@ export class CanvasManager extends CanvasModuleBase {
|
||||
inpaintMasks: Array.from(this.adapters.inpaintMasks.values()).map((adapter) => adapter.repr()),
|
||||
regionMasks: Array.from(this.adapters.regionMasks.values()).map((adapter) => adapter.repr()),
|
||||
stateApi: this.stateApi.repr(),
|
||||
bbox: this.bbox.repr(),
|
||||
stagingArea: this.stagingArea.repr(),
|
||||
tool: this.tool.repr(),
|
||||
progressImage: this.progressImage.repr(),
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Mutex } from 'async-mutex';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
|
||||
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
|
||||
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import type { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
|
||||
import type { CanvasStagingAreaModule } from 'features/controlLayers/konva/CanvasStagingAreaModule';
|
||||
import { loadImage } from 'features/controlLayers/konva/util';
|
||||
import type { CanvasImageState } from 'features/controlLayers/store/types';
|
||||
@@ -21,7 +21,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
|
||||
| CanvasEntityObjectRenderer
|
||||
| CanvasEntityBufferObjectRenderer
|
||||
| CanvasStagingAreaModule
|
||||
| CanvasEntityFilterer;
|
||||
| CanvasSegmentAnythingModule;
|
||||
readonly manager: CanvasManager;
|
||||
readonly log: Logger;
|
||||
|
||||
@@ -42,7 +42,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
|
||||
| CanvasEntityObjectRenderer
|
||||
| CanvasEntityBufferObjectRenderer
|
||||
| CanvasStagingAreaModule
|
||||
| CanvasEntityFilterer
|
||||
| CanvasSegmentAnythingModule
|
||||
) {
|
||||
super();
|
||||
this.id = state.id;
|
||||
|
||||
@@ -8,9 +8,9 @@ import { atom } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import type { SetNonNullable } from 'type-fest';
|
||||
|
||||
type ProgressEventWithImage = O.NonNullable<S['InvocationProgressEvent'], 'image'>;
|
||||
type ProgressEventWithImage = SetNonNullable<S['InvocationProgressEvent'], 'image'>;
|
||||
const isProgressEventWithImage = (val: S['InvocationProgressEvent']): val is ProgressEventWithImage =>
|
||||
Boolean(val.image);
|
||||
|
||||
|
||||
@@ -0,0 +1,789 @@
|
||||
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
|
||||
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type {
|
||||
CanvasImageState,
|
||||
Coordinate,
|
||||
RgbaColor,
|
||||
SAMPoint,
|
||||
SAMPointLabel,
|
||||
SAMPointLabelString,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
type CanvasSegmentAnythingModuleConfig = {
|
||||
/**
|
||||
* The radius of the SAM point Konva circle node.
|
||||
*/
|
||||
SAM_POINT_RADIUS: number;
|
||||
/**
|
||||
* The border width of the SAM point Konva circle node.
|
||||
*/
|
||||
SAM_POINT_BORDER_WIDTH: number;
|
||||
/**
|
||||
* The border color of the SAM point Konva circle node.
|
||||
*/
|
||||
SAM_POINT_BORDER_COLOR: RgbaColor;
|
||||
/**
|
||||
* The color of the SAM point Konva circle node when the label is 1.
|
||||
*/
|
||||
SAM_POINT_FOREGROUND_COLOR: RgbaColor;
|
||||
/**
|
||||
* The color of the SAM point Konva circle node when the label is -1.
|
||||
*/
|
||||
SAM_POINT_BACKGROUND_COLOR: RgbaColor;
|
||||
/**
|
||||
* The color of the SAM point Konva circle node when the label is 0.
|
||||
*/
|
||||
SAM_POINT_NEUTRAL_COLOR: RgbaColor;
|
||||
/**
|
||||
* The color to use for the mask preview overlay.
|
||||
*/
|
||||
MASK_COLOR: RgbaColor;
|
||||
/**
|
||||
* The debounce time in milliseconds for processing the points.
|
||||
*/
|
||||
PROCESS_DEBOUNCE_MS: number;
|
||||
};
|
||||
|
||||
const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
|
||||
SAM_POINT_RADIUS: 8,
|
||||
SAM_POINT_BORDER_WIDTH: 2,
|
||||
SAM_POINT_BORDER_COLOR: { r: 0, g: 0, b: 0, a: 1 },
|
||||
SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green
|
||||
SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
|
||||
SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
|
||||
MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
|
||||
PROCESS_DEBOUNCE_MS: 1000,
|
||||
};
|
||||
|
||||
/**
|
||||
* The state of a SAM point.
|
||||
* @property id - The unique identifier of the point.
|
||||
* @property label - The label of the point. -1 is background, 0 is neutral, 1 is foreground.
|
||||
* @property konva - The Konva node state of the point.
|
||||
* @property konva.circle - The Konva circle node of the point. The x and y coordinates for the point are derived from
|
||||
* this node.
|
||||
*/
|
||||
type SAMPointState = {
|
||||
id: string;
|
||||
label: SAMPointLabel;
|
||||
konva: {
|
||||
circle: Konva.Circle;
|
||||
};
|
||||
};
|
||||
|
||||
export class CanvasSegmentAnythingModule extends CanvasModuleBase {
|
||||
readonly type = 'canvas_segment_anything';
|
||||
readonly id: string;
|
||||
readonly path: string[];
|
||||
readonly parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer;
|
||||
readonly manager: CanvasManager;
|
||||
readonly log: Logger;
|
||||
|
||||
config: CanvasSegmentAnythingModuleConfig = DEFAULT_CONFIG;
|
||||
|
||||
subscriptions = new Set<() => void>();
|
||||
|
||||
/**
|
||||
* The AbortController used to cancel the filter processing.
|
||||
*/
|
||||
abortController: AbortController | null = null;
|
||||
|
||||
/**
|
||||
* Whether the module is currently segmenting an entity.
|
||||
*/
|
||||
$isSegmenting = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* Whether the current set of points has been processed.
|
||||
*/
|
||||
$hasProcessed = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* Whether the module is currently processing the points.
|
||||
*/
|
||||
$isProcessing = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* The type of point to create when segmenting. This is a number representation of the SAMPointLabel enum.
|
||||
*/
|
||||
$pointType = atom<SAMPointLabel>(1);
|
||||
|
||||
/**
|
||||
* The type of point to create when segmenting, as a string. This is a computed value based on $pointType.
|
||||
*/
|
||||
$pointTypeString = computed<SAMPointLabelString, Atom<SAMPointLabel>>(
|
||||
this.$pointType,
|
||||
(pointType) => SAM_POINT_LABEL_NUMBER_TO_STRING[pointType]
|
||||
);
|
||||
|
||||
/**
|
||||
* Whether a point is currently being dragged. This is used to prevent the point additions and deletions during
|
||||
* dragging.
|
||||
*/
|
||||
$isDraggingPoint = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* The ephemeral image state of the processed image. Only used while segmenting.
|
||||
*/
|
||||
imageState: CanvasImageState | null = null;
|
||||
|
||||
/**
|
||||
* The current input points.
|
||||
*/
|
||||
$points = atom<SAMPointState[]>([]);
|
||||
|
||||
/**
|
||||
* Whether the module has points. This is a computed value based on $points.
|
||||
*/
|
||||
$hasPoints = computed(this.$points, (points) => points.length > 0);
|
||||
|
||||
/**
|
||||
* The masked image object, if it exists.
|
||||
*/
|
||||
maskedImage: CanvasObjectImage | null = null;
|
||||
|
||||
/**
|
||||
* The Konva nodes for the module.
|
||||
*/
|
||||
konva: {
|
||||
/**
|
||||
* The main Konva group node for the module.
|
||||
*/
|
||||
group: Konva.Group;
|
||||
/**
|
||||
* The Konva group node for the SAM points.
|
||||
*
|
||||
* This is a child of the main group node, rendered above the mask group.
|
||||
*/
|
||||
pointGroup: Konva.Group;
|
||||
/**
|
||||
* The Konva group node for the mask image and compositing rect.
|
||||
*
|
||||
* This is a child of the main group node, rendered below the point group.
|
||||
*/
|
||||
maskGroup: Konva.Group;
|
||||
/**
|
||||
* The Konva rect node for compositing the mask image.
|
||||
*
|
||||
* It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay.
|
||||
*/
|
||||
compositingRect: Konva.Rect;
|
||||
};
|
||||
|
||||
KONVA_CIRCLE_NAME = `${this.type}:circle`;
|
||||
KONVA_GROUP_NAME = `${this.type}:group`;
|
||||
KONVA_POINT_GROUP_NAME = `${this.type}:point_group`;
|
||||
KONVA_MASK_GROUP_NAME = `${this.type}:mask_group`;
|
||||
KONVA_COMPOSITING_RECT_NAME = `${this.type}:compositing_rect`;
|
||||
|
||||
constructor(parent: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer) {
|
||||
super();
|
||||
this.id = getPrefixedId(this.type);
|
||||
this.parent = parent;
|
||||
this.manager = this.parent.manager;
|
||||
this.path = this.manager.buildPath(this);
|
||||
this.log = this.manager.buildLogger(this);
|
||||
|
||||
this.log.debug('Creating module');
|
||||
|
||||
// Create all konva nodes
|
||||
this.konva = {
|
||||
group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
|
||||
pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }),
|
||||
maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }),
|
||||
compositingRect: new Konva.Rect({
|
||||
name: this.KONVA_COMPOSITING_RECT_NAME,
|
||||
fill: rgbaColorToString(this.config.MASK_COLOR),
|
||||
globalCompositeOperation: 'source-atop',
|
||||
listening: false,
|
||||
strokeEnabled: false,
|
||||
perfectDrawEnabled: false,
|
||||
visible: false,
|
||||
}),
|
||||
};
|
||||
|
||||
// Points should always be rendered above the mask group
|
||||
this.konva.group.add(this.konva.maskGroup);
|
||||
this.konva.group.add(this.konva.pointGroup);
|
||||
|
||||
// Compositing rect is added to the mask group - will also be above the mask image, but that doesn't get created
|
||||
// until after processing
|
||||
this.konva.maskGroup.add(this.konva.compositingRect);
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronizes the cursor style to crosshair.
|
||||
*/
|
||||
syncCursorStyle = (): void => {
|
||||
if (this.$isProcessing.get()) {
|
||||
this.manager.stage.setCursor('wait');
|
||||
} else if (this.$isSegmenting.get()) {
|
||||
this.manager.stage.setCursor('crosshair');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a SAM point at the given coordinate with the given label. -1 is background, 0 is neutral, 1 is foreground.
|
||||
* @param coord The coordinate
|
||||
* @param label The label.
|
||||
* @returns The SAM point state.
|
||||
*/
|
||||
createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
|
||||
const id = getPrefixedId('sam_point');
|
||||
|
||||
const circle = new Konva.Circle({
|
||||
name: this.KONVA_CIRCLE_NAME,
|
||||
x: Math.round(coord.x),
|
||||
y: Math.round(coord.y),
|
||||
radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
|
||||
fill: rgbaColorToString(this.getSAMPointColor(label)),
|
||||
stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
|
||||
strokeWidth: this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH), // We will scale this as the stage scale changes
|
||||
draggable: true,
|
||||
perfectDrawEnabled: true, // Required for the stroke/fill to draw correctly w/ partial opacity
|
||||
opacity: 0.6,
|
||||
dragDistance: 3,
|
||||
});
|
||||
|
||||
// When the point is clicked, remove it
|
||||
circle.on('pointerup', (e) => {
|
||||
// Ignore if we are dragging
|
||||
if (this.$isDraggingPoint.get()) {
|
||||
return;
|
||||
}
|
||||
// This event should not bubble up to the parent, stage or any other nodes
|
||||
e.cancelBubble = true;
|
||||
circle.destroy();
|
||||
this.$points.set(this.$points.get().filter((point) => point.id !== id));
|
||||
if (this.$points.get().length === 0) {
|
||||
this.resetEphemeralState();
|
||||
} else {
|
||||
this.$hasProcessed.set(false);
|
||||
}
|
||||
});
|
||||
|
||||
circle.on('dragstart', () => {
|
||||
this.$isDraggingPoint.set(true);
|
||||
});
|
||||
|
||||
circle.on('dragend', () => {
|
||||
this.$isDraggingPoint.set(false);
|
||||
// Point has changed!
|
||||
this.$hasProcessed.set(false);
|
||||
this.$points.notify();
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Moved SAM point'
|
||||
);
|
||||
});
|
||||
|
||||
this.konva.pointGroup.add(circle);
|
||||
|
||||
this.log.trace(
|
||||
{ x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
|
||||
'Created SAM point'
|
||||
);
|
||||
|
||||
return {
|
||||
id,
|
||||
label,
|
||||
konva: { circle },
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronizes the scales of the SAM points to the stage scale.
|
||||
*
|
||||
* SAM points are always the same size, regardless of the stage scale.
|
||||
*/
|
||||
syncPointScales = () => {
|
||||
const radius = this.manager.stage.unscale(this.config.SAM_POINT_RADIUS);
|
||||
const borderWidth = this.manager.stage.unscale(this.config.SAM_POINT_BORDER_WIDTH);
|
||||
for (const point of this.$points.get()) {
|
||||
point.konva.circle.radius(radius);
|
||||
point.konva.circle.strokeWidth(borderWidth);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
|
||||
*/
|
||||
getSAMPoints = (): SAMPoint[] => {
|
||||
const points: SAMPoint[] = [];
|
||||
|
||||
for (const { konva, label } of this.$points.get()) {
|
||||
points.push({
|
||||
// Pull out and round the x and y values from Konva
|
||||
x: Math.round(konva.circle.x()),
|
||||
y: Math.round(konva.circle.y()),
|
||||
label,
|
||||
});
|
||||
}
|
||||
|
||||
return points;
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles the pointerup event on the stage. This is used to add a SAM point to the module.
|
||||
*/
|
||||
onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
|
||||
// Only handle left-clicks
|
||||
if (e.evt.button !== 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if the stage is dragging/panning
|
||||
if (this.manager.stage.getIsDragging()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if a point is being dragged
|
||||
if (this.$isDraggingPoint.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if we are already processing
|
||||
if (this.$isProcessing.get()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore if the cursor is not within the stage (should never happen)
|
||||
const cursorPos = this.manager.tool.$cursorPos.get();
|
||||
if (!cursorPos) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We need to offset the cursor position by the parent entity's position + pixel rect to get the correct position
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
const parentPosition = addCoords(this.parent.state.position, pixelRect);
|
||||
|
||||
// Normalize the cursor position to the parent entity's position
|
||||
const normalizedPoint = offsetCoord(cursorPos.relative, parentPosition);
|
||||
|
||||
// Create a SAM point at the normalized position
|
||||
const point = this.createPoint(normalizedPoint, this.$pointType.get());
|
||||
this.$points.set([...this.$points.get(), point]);
|
||||
|
||||
// Mark the module as having _not_ processed the points now that they have changed
|
||||
this.$hasProcessed.set(false);
|
||||
};
|
||||
|
||||
/**
|
||||
* Adds event listeners needed while segmenting the entity.
|
||||
*/
|
||||
subscribe = () => {
|
||||
this.manager.stage.konva.stage.on('pointerup', this.onStagePointerUp);
|
||||
this.subscriptions.add(() => {
|
||||
this.manager.stage.konva.stage.off('pointerup', this.onStagePointerUp);
|
||||
});
|
||||
|
||||
// When we change the processing status, we should update the cursor style and the layer's listening status. For
|
||||
// example, when processing, we should disable listening on the layer so the user can't add more points, else we
|
||||
// should enable listening.
|
||||
this.subscriptions.add(
|
||||
this.$isProcessing.listen((isProcessing) => {
|
||||
this.syncCursorStyle();
|
||||
this.parent.konva.layer.listening(!isProcessing);
|
||||
})
|
||||
);
|
||||
|
||||
// Scale the SAM points when the stage scale changes
|
||||
this.subscriptions.add(
|
||||
this.manager.stage.$stageAttrs.listen((stageAttrs, oldStageAttrs) => {
|
||||
if (stageAttrs.scale !== oldStageAttrs.scale) {
|
||||
this.syncPointScales();
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// When the points change, process them if autoProcess is enabled
|
||||
this.subscriptions.add(
|
||||
this.$points.listen((points) => {
|
||||
if (points.length === 0) {
|
||||
return;
|
||||
}
|
||||
if (this.manager.stateApi.getSettings().autoProcess) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// When auto-process is enabled, process the points if they have not been processed
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectAutoProcess, (autoProcess) => {
|
||||
if (this.$points.get().length === 0) {
|
||||
return;
|
||||
}
|
||||
if (autoProcess && !this.$hasProcessed.get()) {
|
||||
this.process();
|
||||
}
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Removes the event listeners that were added while segmenting the entity.
|
||||
*/
|
||||
unsubscribe = () => {
|
||||
this.subscriptions.forEach((unsubscribe) => unsubscribe());
|
||||
this.subscriptions.clear();
|
||||
};
|
||||
|
||||
/**
|
||||
* Starts the segmenting process.
|
||||
*/
|
||||
start = () => {
|
||||
const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();
|
||||
if (segmentingAdapter) {
|
||||
this.log.error(`Already segmenting an entity: ${segmentingAdapter.id}`);
|
||||
return;
|
||||
}
|
||||
this.log.trace('Starting segment anything');
|
||||
|
||||
// Reset the module's state
|
||||
this.resetEphemeralState();
|
||||
this.$isSegmenting.set(true);
|
||||
|
||||
// Update the konva group's position to match the parent entity
|
||||
const pixelRect = this.parent.transformer.$pixelRect.get();
|
||||
const position = addCoords(this.parent.state.position, pixelRect);
|
||||
this.konva.group.setAttrs(position);
|
||||
|
||||
// Add the module's Konva group to the parent adapter's layer so it is rendered
|
||||
this.parent.konva.layer.add(this.konva.group);
|
||||
|
||||
// Enable listening on the parent adapter's layer so the module can receive pointer events
|
||||
this.parent.konva.layer.listening(true);
|
||||
|
||||
// Subscribe all listeners needed for segmenting (e.g. window pointerup, state listeners)
|
||||
this.subscribe();
|
||||
|
||||
// Set the global segmenting adapter to this module
|
||||
this.manager.stateApi.$segmentingAdapter.set(this.parent);
|
||||
|
||||
// Sync the cursor style to crosshair
|
||||
this.syncCursorStyle();
|
||||
};
|
||||
|
||||
/**
|
||||
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
|
||||
*/
|
||||
processImmediate = async () => {
|
||||
if (this.$isProcessing.get()) {
|
||||
this.log.warn('Already processing');
|
||||
return;
|
||||
}
|
||||
|
||||
const points = this.getSAMPoints();
|
||||
|
||||
if (points.length === 0) {
|
||||
this.log.trace('No points to segment');
|
||||
return;
|
||||
}
|
||||
|
||||
this.$isProcessing.set(true);
|
||||
|
||||
this.log.trace({ points }, 'Segmenting');
|
||||
|
||||
// Rasterize the entity in its current state
|
||||
const rect = this.parent.transformer.getRelativeRect();
|
||||
const rasterizeResult = await withResultAsync(() =>
|
||||
this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } })
|
||||
);
|
||||
|
||||
if (rasterizeResult.isErr()) {
|
||||
this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity');
|
||||
this.$isProcessing.set(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create an AbortController for the segmenting process
|
||||
const controller = new AbortController();
|
||||
this.abortController = controller;
|
||||
|
||||
// Build the graph for segmenting the image, using the rasterized image DTO
|
||||
const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);
|
||||
|
||||
// Run the graph and get the segmented image output
|
||||
const segmentResult = await withResultAsync(() =>
|
||||
this.manager.stateApi.runGraphAndReturnImageOutput({
|
||||
graph,
|
||||
outputNodeId,
|
||||
prepend: true,
|
||||
signal: controller.signal,
|
||||
})
|
||||
);
|
||||
|
||||
// If there is an error, log it and bail out of this processing run
|
||||
if (segmentResult.isErr()) {
|
||||
this.log.error({ error: serializeError(segmentResult.error) }, 'Error segmenting');
|
||||
this.$isProcessing.set(false);
|
||||
// Clean up the abort controller as needed
|
||||
if (!this.abortController.signal.aborted) {
|
||||
this.abortController.abort();
|
||||
}
|
||||
this.abortController = null;
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.trace({ imageDTO: segmentResult.value }, 'Segmented');
|
||||
|
||||
// Prepare the ephemeral image state
|
||||
this.imageState = imageDTOToImageObject(segmentResult.value);
|
||||
|
||||
// Destroy any existing masked image and create a new one
|
||||
if (this.maskedImage) {
|
||||
this.maskedImage.destroy();
|
||||
}
|
||||
this.maskedImage = new CanvasObjectImage(this.imageState, this);
|
||||
|
||||
// Force update the masked image - after awaiting, the image will be rendered (in memory)
|
||||
await this.maskedImage.update(this.imageState, true);
|
||||
|
||||
// Update the compositing rect to match the image size
|
||||
this.konva.compositingRect.setAttrs({
|
||||
width: this.imageState.image.width,
|
||||
height: this.imageState.image.height,
|
||||
visible: true,
|
||||
});
|
||||
|
||||
// Now we can add the masked image to the mask group. It will be rendered above the compositing rect, but should be
|
||||
// under it, so we will move the compositing rect to the top
|
||||
this.konva.maskGroup.add(this.maskedImage.konva.group);
|
||||
this.konva.compositingRect.moveToTop();
|
||||
|
||||
// Cache the group to ensure the mask is rendered correctly w/ opacity
|
||||
this.konva.maskGroup.cache();
|
||||
|
||||
// We are done processing (still segmenting though!)
|
||||
this.$isProcessing.set(false);
|
||||
|
||||
// The current points have been processed
|
||||
this.$hasProcessed.set(true);
|
||||
|
||||
// Clean up the abort controller as needed
|
||||
if (!this.abortController.signal.aborted) {
|
||||
this.abortController.abort();
|
||||
}
|
||||
this.abortController = null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Debounced version of processImmediate.
|
||||
*/
|
||||
process = debounce(this.processImmediate, this.config.PROCESS_DEBOUNCE_MS);
|
||||
|
||||
/**
|
||||
* Applies the segmented image to the entity.
|
||||
*/
|
||||
apply = () => {
|
||||
if (!this.$hasProcessed.get()) {
|
||||
this.log.error('Cannot apply unprocessed points');
|
||||
return;
|
||||
}
|
||||
const imageState = this.imageState;
|
||||
if (!imageState) {
|
||||
this.log.error('No image state to apply');
|
||||
return;
|
||||
}
|
||||
this.log.trace('Applying');
|
||||
|
||||
// Commit the buffer, which will move the buffer from the layer's buffer renderer to its main renderer
|
||||
this.parent.bufferRenderer.commitBuffer();
|
||||
|
||||
// Rasterize the entity, this time replacing the objects with the masked image
|
||||
const rect = this.parent.transformer.getRelativeRect();
|
||||
this.manager.stateApi.rasterizeEntity({
|
||||
entityIdentifier: this.parent.entityIdentifier,
|
||||
imageObject: imageState,
|
||||
position: {
|
||||
x: Math.round(rect.x),
|
||||
y: Math.round(rect.y),
|
||||
},
|
||||
replaceObjects: true,
|
||||
});
|
||||
|
||||
// Final cleanup and teardown, returning user to main canvas UI
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
/**
|
||||
* Resets the module (e.g. remove all points and the mask image).
|
||||
*
|
||||
* Does not cancel or otherwise complete the segmenting process.
|
||||
*/
|
||||
reset = () => {
|
||||
this.log.trace('Resetting');
|
||||
this.resetEphemeralState();
|
||||
};
|
||||
|
||||
/**
|
||||
* Cancels the segmenting process.
|
||||
*/
|
||||
cancel = () => {
|
||||
this.log.trace('Canceling');
|
||||
// Reset the module's state and tear down, returning user to main canvas UI
|
||||
this.resetEphemeralState();
|
||||
this.teardown();
|
||||
};
|
||||
|
||||
/**
|
||||
* Performs teardown of the module. This shared logic is used for canceling and applying - when the segmenting is
|
||||
* complete and the module is deactivated.
|
||||
*
|
||||
* This method:
|
||||
* - Removes the module's main Konva node from the parent adapter's layer
|
||||
* - Removes segmenting event listeners (e.g. window pointerup)
|
||||
* - Resets the segmenting state
|
||||
* - Resets the global segmenting adapter
|
||||
*/
|
||||
teardown = () => {
|
||||
this.konva.group.remove();
|
||||
this.unsubscribe();
|
||||
this.$isSegmenting.set(false);
|
||||
this.manager.stateApi.$segmentingAdapter.set(null);
|
||||
};
|
||||
|
||||
/**
|
||||
* Resets the module's ephemeral state. This shared logic is used for resetting, canceling, and applying.
|
||||
*
|
||||
* This method:
|
||||
* - Aborts any processing
|
||||
* - Destroys ephemeral Konva nodes
|
||||
* - Resets internal module state
|
||||
* - Resets non-ephemeral Konva nodes
|
||||
* - Clears the parent module's buffer
|
||||
*/
|
||||
resetEphemeralState = () => {
|
||||
// First we need to bail out of any processing
|
||||
this.abortController?.abort();
|
||||
this.abortController = null;
|
||||
|
||||
// Destroy ephemeral konva nodes
|
||||
for (const point of this.$points.get()) {
|
||||
point.konva.circle.destroy();
|
||||
}
|
||||
if (this.maskedImage) {
|
||||
this.maskedImage.destroy();
|
||||
}
|
||||
|
||||
// Empty internal module state
|
||||
this.$points.set([]);
|
||||
this.imageState = null;
|
||||
this.$pointType.set(1);
|
||||
this.$hasProcessed.set(false);
|
||||
this.$isProcessing.set(false);
|
||||
|
||||
// Reset non-ephemeral konva nodes
|
||||
this.konva.compositingRect.visible(false);
|
||||
this.konva.maskGroup.clearCache();
|
||||
|
||||
// The parent module's buffer should be reset & forcibly sync the cache
|
||||
this.parent.bufferRenderer.clearBuffer();
|
||||
this.parent.renderer.syncKonvaCache(true);
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds a graph for segmenting an image with the given image DTO.
|
||||
*/
|
||||
buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
|
||||
const graph = new Graph(getPrefixedId('canvas_segment_anything'));
|
||||
|
||||
// TODO(psyche): When SAM2 is available in transformers, use it here
|
||||
// See: https://github.com/huggingface/transformers/pull/32317
|
||||
const segmentAnything = graph.addNode({
|
||||
id: getPrefixedId('segment_anything'),
|
||||
type: 'segment_anything',
|
||||
model: 'segment-anything-huge',
|
||||
image: { image_name },
|
||||
point_lists: [{ points: this.getSAMPoints() }],
|
||||
mask_filter: 'largest',
|
||||
});
|
||||
|
||||
// Apply the mask to the image, outputting an image w/ alpha transparency
|
||||
const applyMask = graph.addNode({
|
||||
id: getPrefixedId('apply_tensor_mask_to_image'),
|
||||
type: 'apply_tensor_mask_to_image',
|
||||
image: { image_name },
|
||||
});
|
||||
graph.addEdge(segmentAnything, 'mask', applyMask, 'mask');
|
||||
|
||||
return {
|
||||
graph,
|
||||
outputNodeId: applyMask.id,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the color of a SAM point based on its label.
|
||||
*/
|
||||
getSAMPointColor(label: SAMPointLabel): RgbaColor {
|
||||
if (label === 0) {
|
||||
return this.config.SAM_POINT_NEUTRAL_COLOR;
|
||||
} else if (label === 1) {
|
||||
return this.config.SAM_POINT_FOREGROUND_COLOR;
|
||||
} else {
|
||||
// label === -1
|
||||
return this.config.SAM_POINT_BACKGROUND_COLOR;
|
||||
}
|
||||
}
|
||||
|
||||
repr = () => {
|
||||
return {
|
||||
id: this.id,
|
||||
type: this.type,
|
||||
path: this.path,
|
||||
parent: this.parent.id,
|
||||
points: this.$points.get().map(({ id, konva, label }) => ({
|
||||
id,
|
||||
label,
|
||||
circle: getKonvaNodeDebugAttrs(konva.circle),
|
||||
})),
|
||||
imageState: deepClone(this.imageState),
|
||||
maskedImage: this.maskedImage?.repr(),
|
||||
config: deepClone(this.config),
|
||||
$isSegmenting: this.$isSegmenting.get(),
|
||||
$hasProcessed: this.$hasProcessed.get(),
|
||||
$isProcessing: this.$isProcessing.get(),
|
||||
$pointType: this.$pointType.get(),
|
||||
$pointTypeString: this.$pointTypeString.get(),
|
||||
$isDraggingPoint: this.$isDraggingPoint.get(),
|
||||
konva: {
|
||||
group: getKonvaNodeDebugAttrs(this.konva.group),
|
||||
compositingRect: getKonvaNodeDebugAttrs(this.konva.compositingRect),
|
||||
maskGroup: getKonvaNodeDebugAttrs(this.konva.maskGroup),
|
||||
pointGroup: getKonvaNodeDebugAttrs(this.konva.pointGroup),
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
destroy = () => {
|
||||
this.log.debug('Destroying module');
|
||||
if (this.abortController && !this.abortController.signal.aborted) {
|
||||
this.abortController.abort();
|
||||
}
|
||||
this.abortController = null;
|
||||
this.unsubscribe();
|
||||
this.konva.group.destroy();
|
||||
};
|
||||
}
|
||||
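Taken together, the module above implements a start -> point placement -> process -> apply/cancel lifecycle. A hedged end-to-end sketch using only names that appear in this diff; the surrounding async handler and the canvasManager/entityIdentifier variables are assumed to be in scope:

// Sketch only - not part of this changeset.
const adapter = canvasManager.getAdapter(entityIdentifier);
const sam = adapter?.segmentAnything;
if (sam) {
  sam.start(); // enters segmenting mode and registers the stage pointerup listener
  // ...the user places SAM points by clicking the layer...
  await sam.processImmediate(); // rasterize -> segment_anything graph -> masked preview
  sam.apply(); // commits the masked image back to the entity and tears down
}
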
@@ -311,7 +311,7 @@ export class CanvasStageModule extends CanvasModuleBase {
    this.setIsDraggable(true);

    // Then start dragging the stage if it's not already being dragged
    if (!this.konva.stage.isDragging()) {
    if (!this.getIsDragging()) {
      this.konva.stage.startDrag();
    }

@@ -328,7 +328,7 @@ export class CanvasStageModule extends CanvasModuleBase {
    this.setIsDraggable(this.manager.tool.$tool.get() === 'view');

    // Stop dragging the stage if it's being dragged
    if (this.konva.stage.isDragging()) {
    if (this.getIsDragging()) {
      this.konva.stage.stopDrag();
    }

@@ -404,6 +404,10 @@ export class CanvasStageModule extends CanvasModuleBase {
    this.konva.stage.draggable(isDraggable);
  };

  getIsDragging = () => {
    return this.konva.stage.isDragging();
  };

  addLayer = (layer: Konva.Layer) => {
    this.konva.stage.add(layer);
  };

@@ -613,10 +613,20 @@ export class CanvasStateApiModule extends CanvasModuleBase {
  $rasterizingAdapter = atom<CanvasEntityAdapter | null>(null);

  /**
   * Whether an entity is currently being transformed. Derived from `$transformingAdapter`.
   * Whether an entity is currently being rasterized. Derived from `$rasterizingAdapter`.
   */
  $isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter));

  /**
   * The entity adapter being segmented, if any.
   */
  $segmentingAdapter = atom<CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer | null>(null);

  /**
   * Whether an entity is currently being segmented. Derived from `$segmentingAdapter`.
   */
  $isSegmenting = computed(this.$segmentingAdapter, (segmentingAdapter) => Boolean(segmentingAdapter));

  /**
   * Whether the space key is currently pressed.
   */

@@ -6,11 +6,13 @@ import {
} from 'common/util/roundDownToMultiple';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { assert } from 'tsafe';
@@ -31,11 +33,11 @@ const NO_ANCHORS: string[] = [];
/**
 * Renders the bounding box. The bounding box can be transformed by the user.
 */
export class CanvasBboxModule extends CanvasModuleBase {
export class CanvasBboxToolModule extends CanvasModuleBase {
  readonly type = 'bbox';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasManager;
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

@@ -61,18 +63,18 @@ export class CanvasBboxModule extends CanvasModuleBase {
   */
  $aspectRatioBuffer = atom(1);

  constructor(manager: CanvasManager) {
  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = manager;
    this.manager = manager;
    this.parent = parent;
    this.manager = parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating bbox module');

    this.konva = {
      group: new Konva.Group({ name: `${this.type}:group`, listening: true }),
      group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
      // We will use a Konva.Transformer for the generation bbox. Transformers need some shape to transform, so we will
      // create a transparent rect for this purpose.
      proxyRect: new Konva.Rect({
@@ -127,6 +129,7 @@ export class CanvasBboxModule extends CanvasModuleBase {
        perfectDrawEnabled: false,
      }),
      transformer: new Konva.Transformer({
        listening: false,
        name: `${this.type}:transformer`,
        borderDash: [5, 5],
        borderStroke: 'rgba(212,216,234,1)',
@@ -135,7 +138,6 @@ export class CanvasBboxModule extends CanvasModuleBase {
        rotateEnabled: false,
        keepRatio: false,
        ignoreStroke: true,
        listening: false,
        flipEnabled: false,
        anchorFill: 'rgba(212,216,234,1)',
        anchorStroke: 'rgb(42,42,42)',
@@ -149,9 +151,18 @@ export class CanvasBboxModule extends CanvasModuleBase {
    };

    this.konva.proxyRect.on('dragmove', this.onDragMove);
    this.konva.proxyRect.on('pointerenter', () => {
      this.manager.stage.setCursor('move');
    });
    this.konva.proxyRect.on('pointerleave', () => {
      this.manager.stage.setCursor('default');
    });
    this.konva.transformer.on('transform', this.onTransform);
    this.konva.transformer.on('transformend', this.onTransformEnd);

    this.subscriptions.add(() => {
      this.konva.proxyRect.off('dragmove pointerenter pointerleave');
      this.konva.transformer.off('transform transformend');
    });
    // The transformer will always be transforming the proxy rect
    this.konva.transformer.nodes([this.konva.proxyRect]);

@@ -161,7 +172,7 @@ export class CanvasBboxModule extends CanvasModuleBase {
    this.konva.group.add(this.konva.transformer);

    // We will listen to the tool state to determine if the bbox should be visible or not.
    this.subscriptions.add(this.manager.tool.$tool.listen(this.render));
    this.subscriptions.add(this.parent.$tool.listen(this.render));

    // Also listen to redux state to update the bbox's position and dimensions.
    this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBbox, this.render));
@@ -176,6 +187,9 @@ export class CanvasBboxModule extends CanvasModuleBase {
    this.subscriptions.add(this.manager.$isBusy.listen(this.render));
  }

  // This is a noop. The cursor is changed when the cursor enters or leaves the bbox.
  syncCursorStyle = noop;

  initialize = () => {
    this.log.debug('Initializing module');
    // We need to retain a copy of the bbox state because
@@ -189,16 +203,13 @@ export class CanvasBboxModule extends CanvasModuleBase {
   * Renders the bbox. The bbox is only visible when the tool is set to 'bbox'.
   */
  render = () => {
    this.log.trace('Rendering');

    const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;
    const tool = this.manager.tool.$tool.get();

    this.konva.group.visible(true);
    const { x, y, width, height } = this.manager.stateApi.runSelector(selectBbox).rect;

    // We need to reach up to the preview layer to enable/disable listening so that the bbox can be interacted with.
    // If the mangaer is busy, we disable listening so the bbox cannot be interacted with.
    this.manager.konva.previewLayer.listening(tool === 'bbox' && !this.manager.$isBusy.get());
    this.konva.group.listening(tool === 'bbox' && !this.manager.$isBusy.get());

    this.konva.proxyRect.setAttrs({
      x,
@@ -0,0 +1,427 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import {
  alignCoordForTool,
  getLastPointOfLastLine,
  getLastPointOfLastLineWithPressure,
  getLastPointOfLine,
  getPrefixedId,
  isDistanceMoreThanMin,
  offsetCoord,
} from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';

type CanvasBrushToolModuleConfig = {
  /**
   * The inner border color for the brush tool preview.
   */
  BORDER_INNER_COLOR: string;
  /**
   * The outer border color for the brush tool preview.
   */
  BORDER_OUTER_COLOR: string;
  /**
   * The number of milliseconds to wait before hiding the brush preview's fill circle after the mouse is released.
   */
  HIDE_FILL_TIMEOUT_MS: number;
};

const DEFAULT_CONFIG: CanvasBrushToolModuleConfig = {
  BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
  BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
  HIDE_FILL_TIMEOUT_MS: 1500, // same as Affinity
};

/**
 * Renders a preview of the brush tool on the canvas.
 */
export class CanvasBrushToolModule extends CanvasModuleBase {
  readonly type = 'brush_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  config: CanvasBrushToolModuleConfig = DEFAULT_CONFIG;
  hideFillTimeoutId: number | null = null;

  /**
   * The Konva objects that make up the brush tool preview:
   * - A group to hold the fill circle and borders
   * - A circle to fill the brush area
   * - An inner border ring
   * - An outer border ring
   */
  konva: {
    group: Konva.Group;
    fillCircle: Konva.Circle;
    innerBorder: Konva.Ring;
    outerBorder: Konva.Ring;
  };

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');

    this.konva = {
      group: new Konva.Group({ name: `${this.type}:brush_group`, listening: false }),
      fillCircle: new Konva.Circle({
        name: `${this.type}:brush_fill_circle`,
        listening: false,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      innerBorder: new Konva.Ring({
        name: `${this.type}:brush_inner_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_INNER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      outerBorder: new Konva.Ring({
        name: `${this.type}:brush_outer_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_OUTER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
    };
    this.konva.group.add(this.konva.fillCircle, this.konva.innerBorder, this.konva.outerBorder);
  }

  syncCursorStyle = () => {
    this.manager.stage.setCursor('none');
  };

  render = () => {
    if (this.parent.$tool.get() !== 'brush') {
      this.setVisibility(false);
      return;
    }

    if (!this.parent.getCanDraw()) {
      this.setVisibility(false);
      return;
    }

    const cursorPos = this.parent.$cursorPos.get();

    if (!cursorPos) {
      this.setVisibility(false);
      return;
    }

    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const lastPointerType = this.parent.$lastPointerType.get();

    if (lastPointerType !== 'mouse' && isPrimaryPointerDown) {
      this.setVisibility(false);
      return;
    }

    this.setVisibility(true);

    if (this.hideFillTimeoutId !== null) {
      window.clearTimeout(this.hideFillTimeoutId);
      this.hideFillTimeoutId = null;
    }

    const settings = this.manager.stateApi.getSettings();
    const brushPreviewFill = this.manager.stateApi.getBrushPreviewColor();
    const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.brushWidth);
    const radius = settings.brushWidth / 2;

    // The circle is scaled
    this.konva.fillCircle.setAttrs({
      x: alignedCursorPos.x,
      y: alignedCursorPos.y,
      radius,
      fill: rgbaColorToString(brushPreviewFill),
      visible: !isPrimaryPointerDown && lastPointerType === 'mouse',
    });

    // But the borders are in screen-pixels
    const onePixel = this.manager.stage.unscale(1);
    const twoPixels = this.manager.stage.unscale(2);

    this.konva.innerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius,
      outerRadius: radius + onePixel,
    });
    this.konva.outerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius + onePixel,
      outerRadius: radius + twoPixels,
    });

    this.hideFillTimeoutId = window.setTimeout(() => {
      this.konva.fillCircle.visible(false);
      this.hideFillTimeoutId = null;
    }, this.config.HIDE_FILL_TIMEOUT_MS);
  };

  setVisibility = (visible: boolean) => {
    this.konva.group.visible(visible);
  };

  /**
   * Handles the pointer enter event on the stage, when the brush tool is active. This may create a new brush line if
   * the mouse is down as the cursor enters the stage.
   *
   * The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerEnter = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();
    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
      /**
       * Can't do anything without:
       * - A cursor position: the cursor is not on the stage
       * - The mouse is down: the user is not drawing
       * - A selected entity: there is no entity to draw on
       */
      return;
    }

    const settings = this.manager.stateApi.getSettings();

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

    if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
      // If the pen is down and pressure sensitivity is enabled, add the point with pressure
      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('brush_line_with_pressure'),
        type: 'brush_line_with_pressure',
        points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
        strokeWidth: settings.brushWidth,
        color: this.manager.stateApi.getCurrentColor(),
        clip: this.parent.getClip(selectedEntity.state),
      });
    } else {
      // Else, add the point without pressure
      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('brush_line'),
        type: 'brush_line',
        points: [alignedPoint.x, alignedPoint.y],
        strokeWidth: settings.brushWidth,
        color: this.manager.stateApi.getCurrentColor(),
        clip: this.parent.getClip(selectedEntity.state),
      });
    }
  };

  /**
   * Handles the pointer down event on the stage, when the brush tool is active. If the shift key is held, this will
   * create a straight line from the last point of the last line to the current point. Else, it will create a new line
   * with the current point.
   *
   * The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();
    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!cursorPos || !selectedEntity || !isPrimaryPointerDown) {
      /**
       * Can't do anything without:
       * - A cursor position: the cursor is not on the stage
       * - The mouse is down: the user is not drawing
       * - A selected entity: there is no entity to draw on
       */
      return;
    }

    if (selectedEntity.bufferRenderer.hasBuffer()) {
      selectedEntity.bufferRenderer.commitBuffer();
    }

    const settings = this.manager.stateApi.getSettings();

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

    if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
      // We need to get the last point of the last line to create a straight line if shift is held
      const lastLinePoint = getLastPointOfLastLineWithPressure(
        selectedEntity.state.objects,
        'brush_line_with_pressure'
      );

      let points: number[];

      if (e.evt.shiftKey && lastLinePoint) {
        // Create a straight line from the last line point
        points = [
          lastLinePoint.x,
          lastLinePoint.y,
          lastLinePoint.pressure,
          alignedPoint.x,
          alignedPoint.y,
          e.evt.pressure,
        ];
      } else {
        // Create a new line with the current point
        points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
      }

      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('brush_line_with_pressure'),
        type: 'brush_line_with_pressure',
        points,
        strokeWidth: settings.brushWidth,
        color: this.manager.stateApi.getCurrentColor(),
        clip: this.parent.getClip(selectedEntity.state),
      });
    } else {
      const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');

      let points: number[];

      if (e.evt.shiftKey && lastLinePoint) {
        // Create a straight line from the last line point
        points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
      } else {
        // Create a new line with the current point
        points = [alignedPoint.x, alignedPoint.y];
      }

      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('brush_line'),
        type: 'brush_line',
        points,
        strokeWidth: settings.brushWidth,
        color: this.manager.stateApi.getCurrentColor(),
        clip: this.parent.getClip(selectedEntity.state),
      });
    }
  };

  /**
   * Handles the pointer up event on the stage, when the brush tool is active. This handles finalizing the brush line
   * that was being drawn (if any).
   *
   * The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
    if (!selectedEntity) {
      return;
    }
    if (
      (selectedEntity.bufferRenderer.state?.type === 'brush_line' ||
        selectedEntity.bufferRenderer.state?.type === 'brush_line_with_pressure') &&
      selectedEntity.bufferRenderer.hasBuffer()
    ) {
      selectedEntity.bufferRenderer.commitBuffer();
    } else {
      selectedEntity.bufferRenderer.clearBuffer();
    }
  };

  /**
   * Handles the pointer move event on the stage, when the brush tool is active. This handles extending the brush line
   * that is being drawn (if any).
   *
   * The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();

    if (!cursorPos) {
      return;
    }

    if (!this.parent.$isPrimaryPointerDown.get()) {
      return;
    }

    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!selectedEntity) {
      return;
    }

    const bufferState = selectedEntity.bufferRenderer.state;

    if (!bufferState) {
      return;
    }

    if (bufferState.type !== 'brush_line' && bufferState.type !== 'brush_line_with_pressure') {
      return;
    }

    const settings = this.manager.stateApi.getSettings();

    const lastPoint = getLastPointOfLine(bufferState.points);
    const minDistance = settings.brushWidth * this.parent.config.BRUSH_SPACING_TARGET_SCALE;
    if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
      return;
    }

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

    if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
      // Do not add duplicate points
      return;
    }

    bufferState.points.push(alignedPoint.x, alignedPoint.y);

    // Add pressure if the pen is down and pressure sensitivity is enabled
    if (bufferState.type === 'brush_line_with_pressure' && settings.pressureSensitivity) {
      bufferState.points.push(e.evt.pressure);
    }

    await selectedEntity.bufferRenderer.setBuffer(bufferState);
  };

  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      config: this.config,
    };
  };

  destroy = () => {
    this.log.debug('Destroying module');
    this.konva.group.destroy();
  };
}
@@ -2,11 +2,16 @@ import { rgbColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { getColorAtCoordinate, getPrefixedId } from 'features/controlLayers/konva/util';
import type { RgbColor } from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import rafThrottle from 'raf-throttle';
import type { Logger } from 'roarr';

type CanvasToolColorPickerConfig = {
type CanvasColorPickerToolModuleConfig = {
  /**
   * The inner radius of the ring.
   */
@@ -49,7 +54,7 @@ type CanvasToolColorPickerConfig = {
  CROSSHAIR_BORDER_COLOR: string;
};

const DEFAULT_CONFIG: CanvasToolColorPickerConfig = {
const DEFAULT_CONFIG: CanvasColorPickerToolModuleConfig = {
  RING_INNER_RADIUS: 25,
  RING_OUTER_RADIUS: 35,
  RING_BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
@@ -65,7 +70,7 @@ const DEFAULT_CONFIG: CanvasToolColorPickerConfig = {
/**
 * Renders a preview of the color picker tool on the canvas.
 */
export class CanvasToolColorPicker extends CanvasModuleBase {
export class CanvasColorPickerToolModule extends CanvasModuleBase {
  readonly type = 'color_picker_tool';
  readonly id: string;
  readonly path: string[];
@@ -73,7 +78,12 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
  readonly manager: CanvasManager;
  readonly log: Logger;

  config: CanvasToolColorPickerConfig = DEFAULT_CONFIG;
  config: CanvasColorPickerToolModuleConfig = DEFAULT_CONFIG;

  /**
   * The color currently under the cursor. Only has a value when the color picker tool is active.
   */
  $colorUnderCursor = atom<RgbColor>(RGBA_BLACK);

  /**
   * The Konva objects that make up the color picker tool preview:
@@ -110,6 +120,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
    this.konva = {
      group: new Konva.Group({ name: `${this.type}:color_picker_group`, listening: false }),
      ringCandidateColor: new Konva.Ring({
        listening: false,
        name: `${this.type}:color_picker_candidate_color_ring`,
        innerRadius: 0,
        outerRadius: 0,
@@ -117,6 +128,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
        perfectDrawEnabled: false,
      }),
      ringCurrentColor: new Konva.Arc({
        listening: false,
        name: `${this.type}:color_picker_current_color_arc`,
        innerRadius: 0,
        outerRadius: 0,
@@ -125,6 +137,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
        perfectDrawEnabled: false,
      }),
      ringInnerBorder: new Konva.Ring({
        listening: false,
        name: `${this.type}:color_picker_inner_border_ring`,
        innerRadius: 0,
        outerRadius: 0,
@@ -133,6 +146,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
        perfectDrawEnabled: false,
      }),
      ringOuterBorder: new Konva.Ring({
        listening: false,
        name: `${this.type}:color_picker_outer_border_ring`,
        innerRadius: 0,
        outerRadius: 0,
@@ -141,41 +155,49 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
        perfectDrawEnabled: false,
      }),
      crosshairNorthInner: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_north1_line`,
        stroke: this.config.CROSSHAIR_LINE_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairNorthOuter: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_north2_line`,
        stroke: this.config.CROSSHAIR_BORDER_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairEastInner: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_east1_line`,
        stroke: this.config.CROSSHAIR_LINE_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairEastOuter: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_east2_line`,
        stroke: this.config.CROSSHAIR_BORDER_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairSouthInner: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_south1_line`,
        stroke: this.config.CROSSHAIR_LINE_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairSouthOuter: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_south2_line`,
        stroke: this.config.CROSSHAIR_BORDER_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairWestInner: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_west1_line`,
        stroke: this.config.CROSSHAIR_LINE_COLOR,
        perfectDrawEnabled: false,
      }),
      crosshairWestOuter: new Konva.Line({
        listening: false,
        name: `${this.type}:color_picker_crosshair_west2_line`,
        stroke: this.config.CROSSHAIR_BORDER_COLOR,
        perfectDrawEnabled: false,
@@ -198,21 +220,27 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
    );
  }

  syncCursorStyle = () => {
    this.manager.stage.setCursor('none');
  };

  /**
   * Renders the color picker tool preview on the canvas.
   */
  render = () => {
    const tool = this.parent.$tool.get();
    if (this.parent.$tool.get() !== 'colorPicker') {
      this.setVisibility(false);
      return;
    }

    if (tool !== 'colorPicker') {
    if (!this.parent.getCanDraw()) {
      this.setVisibility(false);
      return;
    }

    const cursorPos = this.parent.$cursorPos.get();
    const canDraw = this.parent.getCanDraw();

    if (!cursorPos || tool !== 'colorPicker' || !canDraw) {
    if (!cursorPos) {
      this.setVisibility(false);
      return;
    }
@@ -222,7 +250,7 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
    const { x, y } = cursorPos.relative;

    const settings = this.manager.stateApi.getSettings();
    const colorUnderCursor = this.parent.$colorUnderCursor.get();
    const colorUnderCursor = this.$colorUnderCursor.get();
    const colorPickerInnerRadius = this.manager.stage.unscale(this.config.RING_INNER_RADIUS);
    const colorPickerOuterRadius = this.manager.stage.unscale(this.config.RING_OUTER_RADIUS);
    const onePixel = this.manager.stage.unscale(1);
@@ -299,12 +327,38 @@ export class CanvasToolColorPicker extends CanvasModuleBase {
    this.konva.group.visible(visible);
  };

  onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
    const color = this.$colorUnderCursor.get();
    if (color) {
      const settings = this.manager.stateApi.getSettings();
      // This will update the color but not the alpha value
      this.manager.stateApi.setColor({ ...settings.color, ...color });
    }
  };

  onStagePointerMove = (_e: KonvaEventObject<PointerEvent>) => {
    this.syncColorUnderCursor();
  };

  syncColorUnderCursor = rafThrottle(() => {
    const cursorPos = this.parent.$cursorPos.get();
    if (!cursorPos) {
      return;
    }

    const color = getColorAtCoordinate(this.manager.stage.konva.stage, cursorPos.absolute);
    if (color) {
      this.$colorUnderCursor.set(color);
    }
  });

  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      config: this.config,
      $colorUnderCursor: this.$colorUnderCursor.get(),
    };
  };

@@ -0,0 +1,394 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import {
  alignCoordForTool,
  getLastPointOfLastLine,
  getLastPointOfLastLineWithPressure,
  getLastPointOfLine,
  getPrefixedId,
  isDistanceMoreThanMin,
  offsetCoord,
} from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';

type CanvasEraserToolModuleConfig = {
  /**
   * The inner border color for the eraser tool preview.
   */
  BORDER_INNER_COLOR: string;
  /**
   * The outer border color for the eraser tool preview.
   */
  BORDER_OUTER_COLOR: string;
};

const DEFAULT_CONFIG: CanvasEraserToolModuleConfig = {
  BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
  BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
};

export class CanvasEraserToolModule extends CanvasModuleBase {
  readonly type = 'eraser_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  config: CanvasEraserToolModuleConfig = DEFAULT_CONFIG;

  konva: {
    group: Konva.Group;
    cutoutCircle: Konva.Circle;
    innerBorder: Konva.Ring;
    outerBorder: Konva.Ring;
  };

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');

    this.konva = {
      group: new Konva.Group({ name: `${this.type}:eraser_group`, listening: false }),
      cutoutCircle: new Konva.Circle({
        name: `${this.type}:eraser_cutout_circle`,
        listening: false,
        strokeEnabled: false,
        // The fill is used only to erase what is underneath it, so its color doesn't matter - just needs to be opaque
        fill: 'white',
        globalCompositeOperation: 'destination-out',
        perfectDrawEnabled: false,
      }),
      innerBorder: new Konva.Ring({
        name: `${this.type}:eraser_inner_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_INNER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      outerBorder: new Konva.Ring({
        listening: false,
        name: `${this.type}:eraser_outer_border_ring`,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_OUTER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
    };
    this.konva.group.add(this.konva.cutoutCircle, this.konva.innerBorder, this.konva.outerBorder);
  }

  syncCursorStyle = () => {
    this.manager.stage.setCursor('none');
  };

  render = () => {
    if (this.parent.$tool.get() !== 'eraser') {
      this.setVisibility(false);
      return;
    }

    if (!this.parent.getCanDraw()) {
      this.setVisibility(false);
      return;
    }

    const cursorPos = this.parent.$cursorPos.get();

    if (!cursorPos) {
      this.setVisibility(false);
      return;
    }

    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const lastPointerType = this.parent.$lastPointerType.get();

    if (lastPointerType !== 'mouse' && isPrimaryPointerDown) {
      this.setVisibility(false);
      return;
    }

    this.setVisibility(true);

    const settings = this.manager.stateApi.getSettings();
    const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.eraserWidth);
    const radius = settings.eraserWidth / 2;

    // The circle is scaled
    this.konva.cutoutCircle.setAttrs({
      x: alignedCursorPos.x,
      y: alignedCursorPos.y,
      radius,
    });

    // But the borders are in screen-pixels
    const onePixel = this.manager.stage.unscale(1);
    const twoPixels = this.manager.stage.unscale(2);

    this.konva.innerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius,
      outerRadius: radius + onePixel,
    });
    this.konva.outerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius + onePixel,
      outerRadius: radius + twoPixels,
    });
  };

  setVisibility = (visible: boolean) => {
    this.konva.group.visible(visible);
  };

  /**
   * Handles the pointer enter event on the stage, when the eraser tool is active. This may create a new eraser line if
   * the mouse is down as the cursor enters the stage.
   *
   * The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerEnter = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();
    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
      /**
       * Can't do anything without:
       * - A cursor position: the cursor is not on the stage
       * - The mouse is down: the user is not drawing
       * - A selected entity: there is no entity to draw on
       */
      return;
    }

    const settings = this.manager.stateApi.getSettings();
    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

    if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
      // If the pen is down and pressure sensitivity is enabled, add the point with pressure
      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('eraser_line_with_pressure'),
        type: 'eraser_line_with_pressure',
        points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
        strokeWidth: settings.eraserWidth,
        clip: this.parent.getClip(selectedEntity.state),
      });
    } else {
      // Else, add the point without pressure
      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('eraser_line'),
        type: 'eraser_line',
        points: [alignedPoint.x, alignedPoint.y],
        strokeWidth: settings.eraserWidth,
        clip: this.parent.getClip(selectedEntity.state),
      });
    }
  };

  /**
   * Handles the pointer down event on the stage, when the eraser tool is active. If the shift key is held, this will
   * create a straight line from the last point of the last line to the current point. Else, it will create a new line
   * with the current point.
   *
   * The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!cursorPos || !selectedEntity) {
      /**
       * Can't do anything without:
       * - A cursor position: the cursor is not on the stage
       * - A selected entity: there is no entity to draw on
       */
      return;
    }

    const settings = this.manager.stateApi.getSettings();

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);

    if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
      // We need to get the last point of the last line to create a straight line if shift is held
      const lastLinePoint = getLastPointOfLastLineWithPressure(
        selectedEntity.state.objects,
        'eraser_line_with_pressure'
      );
      const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
      if (selectedEntity.bufferRenderer.hasBuffer()) {
        selectedEntity.bufferRenderer.commitBuffer();
      }
      let points: number[];
      if (e.evt.shiftKey && lastLinePoint) {
        // Create a straight line from the last line point
        points = [
          lastLinePoint.x,
          lastLinePoint.y,
          lastLinePoint.pressure,
          alignedPoint.x,
          alignedPoint.y,
          e.evt.pressure,
        ];
      } else {
        // Create a new line with the current point
        points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
      }
      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('eraser_line_with_pressure'),
        type: 'eraser_line_with_pressure',
        points,
        strokeWidth: settings.eraserWidth,
        clip: this.parent.getClip(selectedEntity.state),
      });
    } else {
      // We need to get the last point of the last line to create a straight line if shift is held
      const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'eraser_line');
      const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);

      if (selectedEntity.bufferRenderer.hasBuffer()) {
        selectedEntity.bufferRenderer.commitBuffer();
      }

      let points: number[];
      if (e.evt.shiftKey && lastLinePoint) {
        // Create a straight line from the last line point
        points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
      } else {
        // Create a new line with the current point
        points = [alignedPoint.x, alignedPoint.y];
      }

      await selectedEntity.bufferRenderer.setBuffer({
        id: getPrefixedId('eraser_line'),
        type: 'eraser_line',
        points,
        strokeWidth: settings.eraserWidth,
        clip: this.parent.getClip(selectedEntity.state),
      });
    }
  };

  /**
   * Handles the pointer up event on the stage, when the eraser tool is active. This handles finalizing the eraser line
   * that was being drawn (if any).
   *
   * The tool module will pass on the event to this method if the tool is 'eraser', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
    if (!selectedEntity) {
      return;
    }

    if (
      (selectedEntity.bufferRenderer.state?.type === 'eraser_line' ||
        selectedEntity.bufferRenderer.state?.type === 'eraser_line_with_pressure') &&
      selectedEntity.bufferRenderer.hasBuffer()
    ) {
      selectedEntity.bufferRenderer.commitBuffer();
    } else {
      selectedEntity.bufferRenderer.clearBuffer();
    }
  };

  /**
   * Handles the pointer move event on the stage, when the brush tool is active. This handles extending the brush line
   * that is being drawn (if any).
   *
   * The tool module will pass on the event to this method if the tool is 'brush', after doing any necessary checks
   * and non-tool-specific handling.
   *
   * @param e The Konva event object.
   */
  onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();

    if (!cursorPos) {
      return;
    }

    if (!this.parent.$isPrimaryPointerDown.get()) {
      return;
    }

    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!selectedEntity) {
      return;
    }

    const bufferState = selectedEntity.bufferRenderer.state;

    if (!bufferState) {
      return;
    }

    if (bufferState.type !== 'eraser_line' && bufferState.type !== 'eraser_line_with_pressure') {
      return;
    }
    const settings = this.manager.stateApi.getSettings();

    const lastPoint = getLastPointOfLine(bufferState.points);
    const minDistance = settings.eraserWidth * this.parent.config.BRUSH_SPACING_TARGET_SCALE;
    if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
      return;
    }

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);

    if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
      // Do not add duplicate points
      return;
    }

    bufferState.points.push(alignedPoint.x, alignedPoint.y);

    // Add pressure if the pen is down and pressure sensitivity is enabled
    if (bufferState.type === 'eraser_line_with_pressure' && settings.pressureSensitivity) {
      bufferState.points.push(e.evt.pressure);
    }

    await selectedEntity.bufferRenderer.setBuffer(bufferState);
  };

  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      config: this.config,
    };
  };

  destroy = () => {
    this.log.debug('Destroying eraser tool preview module');
    this.konva.group.destroy();
  };
}
@@ -0,0 +1,31 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { noop } from 'lodash-es';
import type { Logger } from 'roarr';

export class CanvasMoveToolModule extends CanvasModuleBase {
  readonly type = 'move_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');
  }

  /**
   * This is a noop. Entity transformers handle cursor style when the move tool is active.
   */
  syncCursorStyle = noop;
}
@@ -0,0 +1,102 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { floorCoord, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';

export class CanvasRectToolModule extends CanvasModuleBase {
  readonly type = 'rect_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');
  }

  syncCursorStyle = () => {
    this.manager.stage.setCursor('crosshair');
  };

  onStagePointerDown = async (_e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();
    const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
      /**
       * Can't do anything without:
       * - A cursor position: the cursor is not on the stage
       * - The mouse is down: the user is not drawing
       * - A selected entity: there is no entity to draw on
       */
      return;
    }

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);

    await selectedEntity.bufferRenderer.setBuffer({
      id: getPrefixedId('rect'),
      type: 'rect',
      rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
      color: this.manager.stateApi.getCurrentColor(),
    });
  };

  onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
    if (!selectedEntity) {
      return;
    }

    if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
      selectedEntity.bufferRenderer.commitBuffer();
    } else {
      selectedEntity.bufferRenderer.clearBuffer();
    }
  };

  onStagePointerMove = async (_e: KonvaEventObject<PointerEvent>) => {
    const cursorPos = this.parent.$cursorPos.get();

    if (!cursorPos) {
      return;
    }

    if (!this.parent.$isPrimaryPointerDown.get()) {
      return;
    }

    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!selectedEntity) {
      return;
    }

    const bufferState = selectedEntity.bufferRenderer.state;

    if (!bufferState) {
      return;
    }

    if (bufferState.type !== 'rect') {
      return;
    }

    const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
    const alignedPoint = floorCoord(normalizedPoint);
    bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
    bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
    await selectedEntity.bufferRenderer.setBuffer(bufferState);
  };
}
@@ -1,182 +0,0 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { Logger } from 'roarr';

type CanvasToolBrushConfig = {
  /**
   * The inner border color for the brush tool preview.
   */
  BORDER_INNER_COLOR: string;
  /**
   * The outer border color for the brush tool preview.
   */
  BORDER_OUTER_COLOR: string;
  /**
   * The number of milliseconds to wait before hiding the brush preview's fill circle after the mouse is released.
   */
  HIDE_FILL_TIMEOUT_MS: number;
};

const DEFAULT_CONFIG: CanvasToolBrushConfig = {
  BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
  BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
  HIDE_FILL_TIMEOUT_MS: 1500, // same as Affinity
};

/**
 * Renders a preview of the brush tool on the canvas.
 */
export class CanvasToolBrush extends CanvasModuleBase {
  readonly type = 'brush_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  config: CanvasToolBrushConfig = DEFAULT_CONFIG;
  hideFillTimeoutId: number | null = null;

  /**
   * The Konva objects that make up the brush tool preview:
   * - A group to hold the fill circle and borders
   * - A circle to fill the brush area
   * - An inner border ring
   * - An outer border ring
   */
  konva: {
    group: Konva.Group;
    fillCircle: Konva.Circle;
    innerBorder: Konva.Ring;
    outerBorder: Konva.Ring;
  };

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');

    this.konva = {
      group: new Konva.Group({ name: `${this.type}:brush_group`, listening: false }),
      fillCircle: new Konva.Circle({
        name: `${this.type}:brush_fill_circle`,
        listening: false,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      innerBorder: new Konva.Ring({
        name: `${this.type}:brush_inner_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_INNER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      outerBorder: new Konva.Ring({
        name: `${this.type}:brush_outer_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_OUTER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
    };
    this.konva.group.add(this.konva.fillCircle, this.konva.innerBorder, this.konva.outerBorder);
  }
  render = () => {
    const tool = this.parent.$tool.get();

    if (tool !== 'brush') {
      this.setVisibility(false);
      return;
    }

    const cursorPos = this.parent.$cursorPos.get();
    const canDraw = this.parent.getCanDraw();

    if (!cursorPos || !canDraw) {
      this.setVisibility(false);
      return;
    }

    const isMouseDown = this.parent.$isMouseDown.get();
    const lastPointerType = this.parent.$lastPointerType.get();

    if (lastPointerType !== 'mouse' && isMouseDown) {
      this.setVisibility(false);
      return;
    }

    this.setVisibility(true);

    if (this.hideFillTimeoutId !== null) {
      window.clearTimeout(this.hideFillTimeoutId);
      this.hideFillTimeoutId = null;
    }

    const settings = this.manager.stateApi.getSettings();
    const brushPreviewFill = this.manager.stateApi.getBrushPreviewColor();
    const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.brushWidth);
    const radius = settings.brushWidth / 2;

    // The circle is scaled
    this.konva.fillCircle.setAttrs({
      x: alignedCursorPos.x,
      y: alignedCursorPos.y,
      radius,
      fill: rgbaColorToString(brushPreviewFill),
      visible: !isMouseDown && lastPointerType === 'mouse',
    });

    // But the borders are in screen-pixels
    const onePixel = this.manager.stage.unscale(1);
    const twoPixels = this.manager.stage.unscale(2);

    this.konva.innerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius,
      outerRadius: radius + onePixel,
    });
    this.konva.outerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius + onePixel,
      outerRadius: radius + twoPixels,
    });

    this.hideFillTimeoutId = window.setTimeout(() => {
      this.konva.fillCircle.visible(false);
      this.hideFillTimeoutId = null;
    }, this.config.HIDE_FILL_TIMEOUT_MS);
  };

  setVisibility = (visible: boolean) => {
    this.konva.group.visible(visible);
  };

  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      config: this.config,
    };
  };

  destroy = () => {
    this.log.debug('Destroying module');
    this.konva.group.destroy();
  };
}
@@ -1,155 +0,0 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { Logger } from 'roarr';

type CanvasToolEraserConfig = {
  /**
   * The inner border color for the eraser tool preview.
   */
  BORDER_INNER_COLOR: string;
  /**
   * The outer border color for the eraser tool preview.
   */
  BORDER_OUTER_COLOR: string;
};

const DEFAULT_CONFIG: CanvasToolEraserConfig = {
  BORDER_INNER_COLOR: 'rgba(0,0,0,1)',
  BORDER_OUTER_COLOR: 'rgba(255,255,255,0.8)',
};

export class CanvasToolEraser extends CanvasModuleBase {
  readonly type = 'eraser_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  config: CanvasToolEraserConfig = DEFAULT_CONFIG;

  konva: {
    group: Konva.Group;
    cutoutCircle: Konva.Circle;
    innerBorder: Konva.Ring;
    outerBorder: Konva.Ring;
  };

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');

    this.konva = {
      group: new Konva.Group({ name: `${this.type}:eraser_group`, listening: false }),
      cutoutCircle: new Konva.Circle({
        name: `${this.type}:eraser_cutout_circle`,
        listening: false,
        strokeEnabled: false,
        // The fill is used only to erase what is underneath it, so its color doesn't matter - just needs to be opaque
        fill: 'white',
        globalCompositeOperation: 'destination-out',
        perfectDrawEnabled: false,
      }),
      innerBorder: new Konva.Ring({
        name: `${this.type}:eraser_inner_border_ring`,
        listening: false,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_INNER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
      outerBorder: new Konva.Ring({
        name: `${this.type}:eraser_outer_border_ring`,
        innerRadius: 0,
        outerRadius: 0,
        fill: this.config.BORDER_OUTER_COLOR,
        strokeEnabled: false,
        perfectDrawEnabled: false,
      }),
    };
    this.konva.group.add(this.konva.cutoutCircle, this.konva.innerBorder, this.konva.outerBorder);
  }

  render = () => {
    const tool = this.parent.$tool.get();

    if (tool !== 'eraser') {
      this.setVisibility(false);
      return;
    }

    const cursorPos = this.parent.$cursorPos.get();
    const canDraw = this.parent.getCanDraw();

    if (!cursorPos || !canDraw) {
      this.setVisibility(false);
      return;
    }

    const isMouseDown = this.parent.$isMouseDown.get();
    const lastPointerType = this.parent.$lastPointerType.get();

    if (lastPointerType !== 'mouse' && isMouseDown) {
      this.setVisibility(false);
      return;
    }

    this.setVisibility(true);

    const settings = this.manager.stateApi.getSettings();
    const alignedCursorPos = alignCoordForTool(cursorPos.relative, settings.eraserWidth);
    const radius = settings.eraserWidth / 2;

    // The circle is scaled
    this.konva.cutoutCircle.setAttrs({
      x: alignedCursorPos.x,
      y: alignedCursorPos.y,
      radius,
    });

    // But the borders are in screen-pixels
    const onePixel = this.manager.stage.unscale(1);
    const twoPixels = this.manager.stage.unscale(2);

    this.konva.innerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius,
      outerRadius: radius + onePixel,
    });
    this.konva.outerBorder.setAttrs({
      x: cursorPos.relative.x,
      y: cursorPos.relative.y,
      innerRadius: radius + onePixel,
      outerRadius: radius + twoPixels,
    });
  };

  setVisibility = (visible: boolean) => {
    this.konva.group.visible(visible);
  };

  repr = () => {
    return {
      id: this.id,
      type: this.type,
      path: this.path,
      config: this.config,
    };
  };

  destroy = () => {
    this.log.debug('Destroying eraser tool preview module');
    this.konva.group.destroy();
  };
}
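The eraser preview above relies on Konva's globalCompositeOperation: 'destination-out': the opaque cutout circle erases whatever it overlaps within its layer instead of painting over it, which is what makes the preview look like a hole punched through the tool overlay. A minimal, self-contained sketch of the same compositing trick (the shapes and sizes are illustrative, not from this codebase):

import Konva from 'konva';

// Sketch: the rect is painted first, then the circle erases whatever it overlaps
// instead of drawing on top of it.
const stage = new Konva.Stage({ container: 'container', width: 300, height: 300 });
const layer = new Konva.Layer();
stage.add(layer);

layer.add(new Konva.Rect({ x: 50, y: 50, width: 200, height: 200, fill: 'dodgerblue' }));

layer.add(
  new Konva.Circle({
    x: 150,
    y: 150,
    radius: 60,
    // The fill color is irrelevant - it only needs to be opaque to erase.
    fill: 'white',
    globalCompositeOperation: 'destination-out',
  })
);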
@@ -1,20 +1,16 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasToolBrush } from 'features/controlLayers/konva/CanvasTool/CanvasToolBrush';
import { CanvasToolColorPicker } from 'features/controlLayers/konva/CanvasTool/CanvasToolColorPicker';
import { CanvasToolEraser } from 'features/controlLayers/konva/CanvasTool/CanvasToolEraser';
import { CanvasBboxToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBboxToolModule';
import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBrushToolModule';
import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule';
import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasEraserToolModule';
import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule';
import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule';
import { CanvasViewToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasViewToolModule';
import {
  alignCoordForTool,
  calculateNewBrushSizeFromWheelDelta,
  floorCoord,
  getColorAtCoordinate,
  getIsPrimaryMouseDown,
  getLastPointOfLastLine,
  getLastPointOfLastLineWithPressure,
  getLastPointOfLine,
  getPrefixedId,
  isDistanceMoreThanMin,
  offsetCoord,
} from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
@@ -24,14 +20,11 @@ import type {
  CanvasRasterLayerState,
  CanvasRegionalGuidanceState,
  Coordinate,
  RgbColor,
  Tool,
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import { atom } from 'nanostores';
import rafThrottle from 'raf-throttle';
import type { Logger } from 'roarr';

// Konva's docs say the default drag buttons are [0], but it's actually [0,1]. We only want left-click to drag, so we
@@ -63,9 +56,15 @@ export class CanvasToolModule extends CanvasModuleBase {

  config: CanvasToolModuleConfig = DEFAULT_CONFIG;

  brushToolPreview: CanvasToolBrush;
  eraserToolPreview: CanvasToolEraser;
  colorPickerToolPreview: CanvasToolColorPicker;
  tools: {
    brush: CanvasBrushToolModule;
    eraser: CanvasEraserToolModule;
    rect: CanvasRectToolModule;
    colorPicker: CanvasColorPickerToolModule;
    bbox: CanvasBboxToolModule;
    view: CanvasViewToolModule;
    move: CanvasMoveToolModule;
  };

  /**
   * The currently selected tool.
@@ -77,17 +76,22 @@ export class CanvasToolModule extends CanvasModuleBase {
   */
  $toolBuffer = atom<Tool | null>(null);
  /**
   * Whether the mouse is currently down.
   * Whether the primary pointer (left mouse, pen, first touch) is currently down on the stage.
   *
   * This is set true when the pointer down is fired on the stage and false when the pointer up is fired anywhere,
   * including outside of the stage. This flag is thus true when the user is actively drawing on the stage.
   *
   * For example, if the pointer down was fired on the stage and the cursor then leaves the stage without a pointer up
   * event, this will still be true. If the cursor then moves back onto the stage, this will still be true.
   *
   * However, if the pointer down was initially fired _outside_ the stage, and the cursor moves onto the stage, this
   * will be false.
   */
  $isMouseDown = atom<boolean>(false);
  $isPrimaryPointerDown = atom<boolean>(false);
  /**
   * The last cursor position.
   */
  $cursorPos = atom<{ relative: Coordinate; absolute: Coordinate } | null>(null);
  /**
   * The color currently under the cursor. Only has a value when the color picker tool is active.
   */
  $colorUnderCursor = atom<RgbColor>(RGBA_BLACK);
  /**
   * The last pointer type that was used on the stage. This is used to determine if we should show a tool preview. For
   * example, when using a pen, we should not show a brush preview.
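The doc comment above spells out the contract for $isPrimaryPointerDown (which replaces $isMouseDown): set when a primary pointer goes down on the stage, cleared when a pointer up fires anywhere, including outside the stage. A minimal sketch of that pattern in isolation, assuming a plain Konva stage and a nanostores atom (names are illustrative, not this codebase's wiring):

import Konva from 'konva';
import { atom } from 'nanostores';

const $isPrimaryPointerDown = atom<boolean>(false);

const stage = new Konva.Stage({ container: 'container', width: 512, height: 512 });

// Only a pointer down that starts on the stage counts as the beginning of a stroke.
stage.on('pointerdown', (e) => {
  // buttons === 1 means only the primary button/contact is down.
  $isPrimaryPointerDown.set(e.evt.buttons === 1);
});

// Clear on pointer up anywhere, so releasing outside the stage still ends the stroke.
window.addEventListener('pointerup', () => {
  $isPrimaryPointerDown.set(false);
});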
@@ -109,18 +113,25 @@ export class CanvasToolModule extends CanvasModuleBase {

    this.log.debug('Creating tool module');

    this.brushToolPreview = new CanvasToolBrush(this);
    this.eraserToolPreview = new CanvasToolEraser(this);
    this.colorPickerToolPreview = new CanvasToolColorPicker(this);
    this.tools = {
      brush: new CanvasBrushToolModule(this),
      eraser: new CanvasEraserToolModule(this),
      rect: new CanvasRectToolModule(this),
      colorPicker: new CanvasColorPickerToolModule(this),
      bbox: new CanvasBboxToolModule(this),
      view: new CanvasViewToolModule(this),
      move: new CanvasMoveToolModule(this),
    };

    this.konva = {
      stage: this.manager.stage.konva.stage,
      group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
      group: new Konva.Group({ name: `${this.type}:group`, listening: true }),
    };

    this.konva.group.add(this.brushToolPreview.konva.group);
    this.konva.group.add(this.eraserToolPreview.konva.group);
    this.konva.group.add(this.colorPickerToolPreview.konva.group);
    this.konva.group.add(this.tools.brush.konva.group);
    this.konva.group.add(this.tools.eraser.konva.group);
    this.konva.group.add(this.tools.colorPicker.konva.group);
    this.konva.group.add(this.tools.bbox.konva.group);

    this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render));
    this.subscriptions.add(this.manager.$isBusy.listen(this.render));
@@ -129,7 +140,7 @@ export class CanvasToolModule extends CanvasModuleBase {
    this.subscriptions.add(
      this.$tool.listen(() => {
        // On tool switch, reset mouse state
        this.manager.tool.$isMouseDown.set(false);
        this.manager.tool.$isPrimaryPointerDown.set(false);
        this.render();
      })
    );
@@ -145,71 +156,47 @@ export class CanvasToolModule extends CanvasModuleBase {
    this.syncCursorStyle();
  };

  setToolVisibility = (tool: Tool, isDrawable: boolean) => {
    this.brushToolPreview.setVisibility(isDrawable && tool === 'brush');
    this.eraserToolPreview.setVisibility(isDrawable && tool === 'eraser');
    this.colorPickerToolPreview.setVisibility(tool === 'colorPicker');
  };

  syncCursorStyle = () => {
    const stage = this.manager.stage;
    const tool = this.$tool.get();
    const isStageDragging = this.manager.stage.konva.stage.isDragging();
    const segmentingAdapter = this.manager.stateApi.$segmentingAdapter.get();

    if (tool === 'view' && !isStageDragging) {
      stage.setCursor('grab');
    } else if (this.manager.stage.konva.stage.isDragging()) {
      stage.setCursor('grabbing');
    } else if (this.manager.stateApi.$isTransforming.get()) {
      stage.setCursor('default');
    if ((this.manager.stage.getIsDragging() || tool === 'view') && !segmentingAdapter) {
      this.tools.view.syncCursorStyle();
    } else if (segmentingAdapter) {
      segmentingAdapter.segmentAnything.syncCursorStyle();
    } else if (this.manager.stateApi.$isFiltering.get()) {
      stage.setCursor('not-allowed');
    } else if (this.manager.stagingArea.$isStaging.get()) {
      stage.setCursor('not-allowed');
    } else if (tool === 'bbox') {
      stage.setCursor('default');
      this.tools.bbox.syncCursorStyle();
    } else if (this.manager.stateApi.getRenderedEntityCount() === 0) {
      stage.setCursor('not-allowed');
    } else if (!this.manager.stateApi.getSelectedEntityAdapter()?.$isInteractable.get()) {
      stage.setCursor('not-allowed');
    } else if (tool === 'colorPicker' || tool === 'brush' || tool === 'eraser') {
      stage.setCursor('none');
    } else if (tool === 'brush') {
      this.tools.brush.syncCursorStyle();
    } else if (tool === 'eraser') {
      this.tools.eraser.syncCursorStyle();
    } else if (tool === 'colorPicker') {
      this.tools.colorPicker.syncCursorStyle();
    } else if (tool === 'move') {
      stage.setCursor('default');
      this.tools.move.syncCursorStyle();
    } else if (tool === 'rect') {
      stage.setCursor('crosshair');
      this.tools.rect.syncCursorStyle();
    } else {
      stage.setCursor('not-allowed');
    }
  };

  render = () => {
    const renderedEntityCount = this.manager.stateApi.getRenderedEntityCount();
    const cursorPos = this.$cursorPos.get();
    const isFiltering = this.manager.stateApi.$isFiltering.get();
    const isStaging = this.manager.stagingArea.$isStaging.get();
    const isStageDragging = this.manager.stage.konva.stage.isDragging();

    this.syncCursorStyle();

    /**
     * The tool should not be rendered when:
     * - There is no cursor position (i.e. the cursor is outside of the stage)
     * - The user is filtering, in which case the user is not allowed to use the tools. Note that we do not disable
     *   the group while transforming, bc that requires use of the move tool.
     * - The canvas is staging, in which case the user is not allowed to use the tools.
     * - There are no entities rendered on the canvas. Maybe we should allow the user to draw on an empty canvas,
     *   creating a new layer when they start?
     * - The stage is being dragged, in which case the user is not allowed to use the tools.
     */
    if (!cursorPos || isFiltering || isStaging || renderedEntityCount === 0 || isStageDragging) {
      this.konva.group.visible(false);
    } else {
      this.konva.group.visible(true);
      this.brushToolPreview.render();
      this.eraserToolPreview.render();
      this.colorPickerToolPreview.render();
    }
    this.tools.brush.render();
    this.tools.eraser.render();
    this.tools.colorPicker.render();
    this.tools.bbox.render();
  };

  syncCursorPositions = () => {
@@ -282,6 +269,14 @@ export class CanvasToolModule extends CanvasModuleBase {
    };
  };

  /**
   * Gets whether the user is allowed to draw on the canvas.
   * - There must be at least one entity rendered on the canvas.
   * - The canvas must not be busy (e.g. transforming, filtering, rasterizing, staging, compositing, segment-anything-ing).
   * - There must be a selected entity.
   * - The selected entity must be interactable (e.g. not hidden, disabled or locked).
   * @returns Whether the user is allowed to draw on the canvas.
   */
  getCanDraw = (): boolean => {
    if (this.manager.stateApi.getRenderedEntityCount() === 0) {
      return false;
@@ -291,6 +286,10 @@ export class CanvasToolModule extends CanvasModuleBase {
      return false;
    }

    if (this.manager.stage.getIsDragging()) {
      return false;
    }

    const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

    if (!selectedEntity) {
@@ -313,71 +312,19 @@ export class CanvasToolModule extends CanvasModuleBase {
      }

      this.syncCursorPositions();
      const cursorPos = this.$cursorPos.get();

      const isMouseDown = this.$isMouseDown.get();
      const settings = this.manager.stateApi.getSettings();
      const tool = this.$tool.get();
      const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

      if (!cursorPos || !isMouseDown || !selectedEntity?.$isInteractable.get()) {
        return;
      }

      if (selectedEntity.bufferRenderer.state?.type !== 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
      if (selectedEntity?.bufferRenderer.state?.type !== 'rect' && selectedEntity?.bufferRenderer.hasBuffer()) {
        selectedEntity.bufferRenderer.commitBuffer();
        return;
      }

      if (tool === 'brush') {
        const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
        const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
        if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('brush_line_with_pressure'),
            type: 'brush_line_with_pressure',
            points: [alignedPoint.x, alignedPoint.y, e.evt.pressure],
            strokeWidth: settings.brushWidth,
            color: this.manager.stateApi.getCurrentColor(),
            clip: this.getClip(selectedEntity.state),
          });
        } else {
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('brush_line'),
            type: 'brush_line',
            points: [alignedPoint.x, alignedPoint.y],
            strokeWidth: settings.brushWidth,
            color: this.manager.stateApi.getCurrentColor(),
            clip: this.getClip(selectedEntity.state),
          });
        }
        return;
      }

      if (tool === 'eraser') {
        const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
        const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
        if (selectedEntity.bufferRenderer.state && selectedEntity.bufferRenderer.hasBuffer()) {
          selectedEntity.bufferRenderer.commitBuffer();
        }
        if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('eraser_line_with_pressure'),
            type: 'eraser_line_with_pressure',
            points: [alignedPoint.x, alignedPoint.y],
            strokeWidth: settings.eraserWidth,
            clip: this.getClip(selectedEntity.state),
          });
        } else {
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('eraser_line'),
            type: 'eraser_line',
            points: [alignedPoint.x, alignedPoint.y],
            strokeWidth: settings.eraserWidth,
            clip: this.getClip(selectedEntity.state),
          });
        }
        return;
        await this.tools.brush.onStagePointerEnter(e);
      } else if (tool === 'eraser') {
        await this.tools.eraser.onStagePointerEnter(e);
      }
    } finally {
      this.render();
@@ -385,6 +332,10 @@ export class CanvasToolModule extends CanvasModuleBase {
  };

  onStagePointerDown = async (e: KonvaEventObject<PointerEvent>) => {
    if (e.target !== this.konva.stage) {
      return;
    }

    try {
      this.$lastPointerType.set(e.evt.pointerType);

@@ -392,147 +343,18 @@ export class CanvasToolModule extends CanvasModuleBase {
        return;
      }

      const isMouseDown = getIsPrimaryMouseDown(e);
      this.$isMouseDown.set(isMouseDown);
      this.$isPrimaryPointerDown.set(getIsPrimaryMouseDown(e));

      this.syncCursorPositions();
      const cursorPos = this.$cursorPos.get();

      const tool = this.$tool.get();
      const settings = this.manager.stateApi.getSettings();
      const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

      if (!cursorPos || !isMouseDown || !selectedEntity?.$isInteractable.get()) {
        return;
      }

      const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);

      if (tool === 'brush') {
        if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
          const lastLinePoint = getLastPointOfLastLineWithPressure(
            selectedEntity.state.objects,
            'brush_line_with_pressure'
          );
          const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);
          if (selectedEntity.bufferRenderer.hasBuffer()) {
            selectedEntity.bufferRenderer.commitBuffer();
          }
          let points: number[];
          if (e.evt.shiftKey && lastLinePoint) {
            // Create a straight line from the last line point
            points = [
              lastLinePoint.x,
              lastLinePoint.y,
              lastLinePoint.pressure,
              alignedPoint.x,
              alignedPoint.y,
              e.evt.pressure,
            ];
          } else {
            points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
          }
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('brush_line_with_pressure'),
            type: 'brush_line_with_pressure',
            points,
            strokeWidth: settings.brushWidth,
            color: this.manager.stateApi.getCurrentColor(),
            clip: this.getClip(selectedEntity.state),
          });
        } else {
          const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');
          const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

          if (selectedEntity.bufferRenderer.hasBuffer()) {
            selectedEntity.bufferRenderer.commitBuffer();
          }

          let points: number[];
          if (e.evt.shiftKey && lastLinePoint) {
            // Create a straight line from the last line point
            points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
          } else {
            points = [alignedPoint.x, alignedPoint.y];
          }

          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('brush_line'),
            type: 'brush_line',
            points,
            strokeWidth: settings.brushWidth,
            color: this.manager.stateApi.getCurrentColor(),
            clip: this.getClip(selectedEntity.state),
          });
        }
      }

      if (tool === 'eraser') {
        if (e.evt.pointerType === 'pen' && settings.pressureSensitivity) {
          const lastLinePoint = getLastPointOfLastLineWithPressure(
            selectedEntity.state.objects,
            'eraser_line_with_pressure'
          );
          const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);
          if (selectedEntity.bufferRenderer.hasBuffer()) {
            selectedEntity.bufferRenderer.commitBuffer();
          }
          let points: number[];
          if (e.evt.shiftKey && lastLinePoint) {
            // Create a straight line from the last line point
            points = [
              lastLinePoint.x,
              lastLinePoint.y,
              lastLinePoint.pressure,
              alignedPoint.x,
              alignedPoint.y,
              e.evt.pressure,
            ];
          } else {
            points = [alignedPoint.x, alignedPoint.y, e.evt.pressure];
          }
          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('eraser_line_with_pressure'),
            type: 'eraser_line_with_pressure',
            points,
            strokeWidth: settings.eraserWidth,
            clip: this.getClip(selectedEntity.state),
          });
        } else {
          const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'eraser_line');
          const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);

          if (selectedEntity.bufferRenderer.hasBuffer()) {
            selectedEntity.bufferRenderer.commitBuffer();
          }

          let points: number[];
          if (e.evt.shiftKey && lastLinePoint) {
            // Create a straight line from the last line point
            points = [lastLinePoint.x, lastLinePoint.y, alignedPoint.x, alignedPoint.y];
          } else {
            points = [alignedPoint.x, alignedPoint.y];
          }

          await selectedEntity.bufferRenderer.setBuffer({
            id: getPrefixedId('eraser_line'),
            type: 'eraser_line',
            points,
            strokeWidth: settings.eraserWidth,
            clip: this.getClip(selectedEntity.state),
          });
        }
      }

      if (tool === 'rect') {
        if (selectedEntity.bufferRenderer.hasBuffer()) {
          selectedEntity.bufferRenderer.commitBuffer();
        }
        await selectedEntity.bufferRenderer.setBuffer({
          id: getPrefixedId('rect'),
          type: 'rect',
          rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
          color: this.manager.stateApi.getCurrentColor(),
        });
        await this.tools.brush.onStagePointerDown(e);
      } else if (tool === 'eraser') {
        await this.tools.eraser.onStagePointerDown(e);
      } else if (tool === 'rect') {
        await this.tools.rect.onStagePointerDown(e);
      }
    } finally {
      this.render();
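One detail worth calling out in the pointer-down handler above: holding Shift continues the previous stroke as a straight segment by seeding the new line's points array with the last point of the previous line followed by the newly aligned point (with pressure values interleaved for pen input). A small sketch of just that points-array construction, using flat [x, y, ...] arrays without pressure (helper names are illustrative):

// Sketch: build the starting points array for a new brush line, optionally
// continuing the previous line as a straight segment when Shift is held.
type Point = { x: number; y: number };

const buildInitialPoints = (newPoint: Point, lastPoint: Point | null, shiftKey: boolean): number[] => {
  if (shiftKey && lastPoint) {
    // Straight segment from the previous stroke's end point to the new point.
    return [lastPoint.x, lastPoint.y, newPoint.x, newPoint.y];
  }
  // Otherwise the stroke starts fresh at the new point.
  return [newPoint.x, newPoint.y];
};

// e.g. buildInitialPoints({ x: 10, y: 12 }, { x: 0, y: 0 }, true) -> [0, 0, 10, 12]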
@@ -540,6 +362,10 @@ export class CanvasToolModule extends CanvasModuleBase {
  };

  onStagePointerUp = (e: KonvaEventObject<PointerEvent>) => {
    if (e.target !== this.konva.stage) {
      return;
    }

    try {
      this.$lastPointerType.set(e.evt.pointerType);

@@ -548,160 +374,46 @@ export class CanvasToolModule extends CanvasModuleBase {
      }

      const tool = this.$tool.get();
      const settings = this.manager.stateApi.getSettings();

      if (tool === 'colorPicker') {
        const color = this.$colorUnderCursor.get();
        if (color) {
          this.manager.stateApi.setColor({ ...settings.color, ...color });
        }
        return;
      }

      const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
      if (!selectedEntity?.$isInteractable.get()) {
        return;
      }

      if (tool === 'brush') {
        if (
          (selectedEntity.bufferRenderer.state?.type === 'brush_line' ||
            selectedEntity.bufferRenderer.state?.type === 'brush_line_with_pressure') &&
          selectedEntity.bufferRenderer.hasBuffer()
        ) {
          selectedEntity.bufferRenderer.commitBuffer();
        } else {
          selectedEntity.bufferRenderer.clearBuffer();
        }
      }

      if (tool === 'eraser') {
        if (
          (selectedEntity.bufferRenderer.state?.type === 'eraser_line' ||
            selectedEntity.bufferRenderer.state?.type === 'eraser_line_with_pressure') &&
          selectedEntity.bufferRenderer.hasBuffer()
        ) {
          selectedEntity.bufferRenderer.commitBuffer();
        } else {
          selectedEntity.bufferRenderer.clearBuffer();
        }
      }

      if (tool === 'rect') {
        if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
          selectedEntity.bufferRenderer.commitBuffer();
        } else {
          selectedEntity.bufferRenderer.clearBuffer();
        }
        this.tools.colorPicker.onStagePointerUp(e);
      } else if (tool === 'brush') {
        this.tools.brush.onStagePointerUp(e);
      } else if (tool === 'eraser') {
        this.tools.eraser.onStagePointerUp(e);
      } else if (tool === 'rect') {
        this.tools.rect.onStagePointerUp(e);
      }
    } finally {
      this.render();
    }
  };

  syncColorUnderCursor = rafThrottle(() => {
    const cursorPos = this.$cursorPos.get();
    if (!cursorPos) {
  onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
    if (e.target !== this.konva.stage) {
      return;
    }

    const color = getColorAtCoordinate(this.konva.stage, cursorPos.absolute);
    if (color) {
      this.$colorUnderCursor.set(color);
    }
  });

  onStagePointerMove = async (e: KonvaEventObject<PointerEvent>) => {
    try {
      this.$lastPointerType.set(e.evt.pointerType);
      this.syncCursorPositions();

      if (!this.getCanDraw()) {
        return;
      }

      this.syncCursorPositions();
      const cursorPos = this.$cursorPos.get();

      if (!cursorPos) {
        return;
      }

      const tool = this.$tool.get();

      if (tool === 'colorPicker') {
        this.syncColorUnderCursor();
      }

      const isMouseDown = this.$isMouseDown.get();
      const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

      if (!isMouseDown || !selectedEntity?.$isInteractable.get()) {
        return;
      }

      const bufferState = selectedEntity.bufferRenderer.state;

      if (!bufferState) {
        return;
      }

      const settings = this.manager.stateApi.getSettings();

      if (tool === 'brush' && (bufferState.type === 'brush_line' || bufferState.type === 'brush_line_with_pressure')) {
        const lastPoint = getLastPointOfLine(bufferState.points);
        const minDistance = settings.brushWidth * this.config.BRUSH_SPACING_TARGET_SCALE;
        if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
          return;
        }

        const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
        const alignedPoint = alignCoordForTool(normalizedPoint, settings.brushWidth);

        if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
          // Do not add duplicate points
          return;
        }

        bufferState.points.push(alignedPoint.x, alignedPoint.y);

        if (bufferState.type === 'brush_line_with_pressure') {
          bufferState.points.push(e.evt.pressure);
        }

        await selectedEntity.bufferRenderer.setBuffer(bufferState);
      } else if (
        tool === 'eraser' &&
        (bufferState.type === 'eraser_line' || bufferState.type === 'eraser_line_with_pressure')
      ) {
        const lastPoint = getLastPointOfLine(bufferState.points);
        const minDistance = settings.eraserWidth * this.config.BRUSH_SPACING_TARGET_SCALE;
        if (!lastPoint || !isDistanceMoreThanMin(cursorPos.relative, lastPoint, minDistance)) {
          return;
        }

        const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
        const alignedPoint = alignCoordForTool(normalizedPoint, settings.eraserWidth);

        if (lastPoint.x === alignedPoint.x && lastPoint.y === alignedPoint.y) {
          // Do not add duplicate points
          return;
        }

        bufferState.points.push(alignedPoint.x, alignedPoint.y);

        if (bufferState.type === 'eraser_line_with_pressure') {
          bufferState.points.push(e.evt.pressure);
        }

        await selectedEntity.bufferRenderer.setBuffer(bufferState);
      } else if (tool === 'rect' && bufferState.type === 'rect') {
        const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
        const alignedPoint = floorCoord(normalizedPoint);
        bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
        bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
        await selectedEntity.bufferRenderer.setBuffer(bufferState);
        this.tools.colorPicker.onStagePointerMove(e);
      } else if (tool === 'brush') {
        await this.tools.brush.onStagePointerMove(e);
      } else if (tool === 'eraser') {
        await this.tools.eraser.onStagePointerMove(e);
      } else if (tool === 'rect') {
        await this.tools.rect.onStagePointerMove(e);
      } else {
        selectedEntity?.bufferRenderer.clearBuffer();
        this.manager.stateApi.getSelectedEntityAdapter()?.bufferRenderer.clearBuffer();
      }
    } finally {
      this.render();
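The pointer-move handler above also thins strokes: a new point is appended only when the cursor has moved more than brushWidth * BRUSH_SPACING_TARGET_SCALE away from the last recorded point, and exact duplicates are dropped. A simplified sketch of that spacing check, assuming a flat [x, y, ...] points array without pressure values (the constant and names are illustrative):

// Sketch: only record a new stroke point when it is far enough from the last one.
type Point = { x: number; y: number };

const SPACING_TARGET_SCALE = 0.1; // assumed value; stands in for BRUSH_SPACING_TARGET_SCALE

const isFarEnough = (a: Point, b: Point, minDistance: number): boolean =>
  Math.hypot(a.x - b.x, a.y - b.y) > minDistance;

const maybeAppendPoint = (points: number[], next: Point, brushWidth: number): void => {
  const lastX = points[points.length - 2];
  const lastY = points[points.length - 1];
  if (lastX === undefined || lastY === undefined) {
    points.push(next.x, next.y);
    return;
  }
  const last = { x: lastX, y: lastY };
  const minDistance = brushWidth * SPACING_TARGET_SCALE;
  // Skip points that are too close (or identical) to keep the line lightweight.
  if (!isFarEnough(next, last, minDistance) || (last.x === next.x && last.y === next.y)) {
    return;
  }
  points.push(next.x, next.y);
};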
@@ -709,6 +421,10 @@ export class CanvasToolModule extends CanvasModuleBase {
  };

  onStagePointerLeave = (e: PointerEvent) => {
    if (e.target !== this.manager.stage.container) {
      return;
    }

    try {
      this.$lastPointerType.set(e.pointerType);
      this.$cursorPos.set(null);
@@ -732,6 +448,10 @@ export class CanvasToolModule extends CanvasModuleBase {
  };

  onStageMouseWheel = (e: KonvaEventObject<WheelEvent>) => {
    if (e.target !== this.konva.stage) {
      return;
    }

    if (!this.getCanDraw()) {
      return;
    }
@@ -770,7 +490,7 @@ export class CanvasToolModule extends CanvasModuleBase {
   */
  onWindowPointerUp = (_: PointerEvent) => {
    try {
      this.$isMouseDown.set(false);
      this.$isPrimaryPointerDown.set(false);
      const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();

      if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) {
@@ -872,12 +592,18 @@ export class CanvasToolModule extends CanvasModuleBase {
      config: this.config,
      $tool: this.$tool.get(),
      $toolBuffer: this.$toolBuffer.get(),
      $isMouseDown: this.$isMouseDown.get(),
      $isPrimaryPointerDown: this.$isPrimaryPointerDown.get(),
      $cursorPos: this.$cursorPos.get(),
      $colorUnderCursor: this.$colorUnderCursor.get(),
      brushToolPreview: this.brushToolPreview.repr(),
      eraserToolPreview: this.eraserToolPreview.repr(),
      colorPickerToolPreview: this.colorPickerToolPreview.repr(),
      $lastPointerType: this.$lastPointerType.get(),
      tools: {
        brush: this.tools.brush.repr(),
        eraser: this.tools.eraser.repr(),
        colorPicker: this.tools.colorPicker.repr(),
        rect: this.tools.rect.repr(),
        bbox: this.tools.bbox.repr(),
        view: this.tools.view.repr(),
        move: this.tools.move.repr(),
      },
    };
  };
@@ -0,0 +1,29 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Logger } from 'roarr';

export class CanvasViewToolModule extends CanvasModuleBase {
  readonly type = 'view_tool';
  readonly id: string;
  readonly path: string[];
  readonly parent: CanvasToolModule;
  readonly manager: CanvasManager;
  readonly log: Logger;

  constructor(parent: CanvasToolModule) {
    super();
    this.id = getPrefixedId(this.type);
    this.parent = parent;
    this.manager = this.parent.manager;
    this.path = this.manager.buildPath(this);
    this.log = this.manager.buildLogger(this);

    this.log.debug('Creating module');
  }

  syncCursorStyle = () => {
    this.manager.stage.setCursor(this.manager.stage.getIsDragging() ? 'grabbing' : 'grab');
  };
}
@@ -1,6 +1,12 @@
import type { Selector, Store } from '@reduxjs/toolkit';
import { $authToken } from 'app/store/nanostores/authToken';
import type { CanvasEntityIdentifier, CanvasObjectState, Coordinate, Rect } from 'features/controlLayers/store/types';
import type {
  CanvasEntityIdentifier,
  CanvasObjectState,
  Coordinate,
  CoordinateWithPressure,
  Rect,
} from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Vector2d } from 'konva/lib/types';
@@ -74,6 +80,18 @@ export const offsetCoord = (coord: Coordinate, offset: Coordinate): Coordinate =
  };
};

/**
 * Adds two coordinates together.
 * @param a The first coordinate
 * @param b The second coordinate
 */
export const addCoords = (a: Coordinate, b: Coordinate): Coordinate => {
  return {
    x: a.x + b.x,
    y: a.y + b.y,
  };
};

/**
 * Snaps a position to the edge of the stage if within a threshold of the edge
 * @param pos The position to snap
@@ -134,7 +152,7 @@ export const snapToRect = (pos: Vector2d, rect: Rect, threshold = 10): Vector2d
 * Checks if the left mouse button is currently pressed
 * @param e The konva event
 */
export const getIsMouseDown = (e: KonvaEventObject<MouseEvent>): boolean => e.evt.buttons === 1;
export const getIsPrimaryPointerDown = (e: KonvaEventObject<PointerEvent>): boolean => e.evt.buttons === 1;

/**
 * Checks if the stage is currently focused
@@ -545,11 +563,6 @@ export const exhaustiveCheck = (value: never): never => {
  assert(false, `Unhandled value: ${value}`);
};

type CoordinateWithPressure = {
  x: number;
  y: number;
  pressure: number;
};
export const getLastPointOfLastLineWithPressure = (
  objects: CanvasObjectState[],
  type: 'brush_line_with_pressure' | 'eraser_line_with_pressure'
@@ -615,6 +628,7 @@ export const getKonvaNodeDebugAttrs = (node: Konva.Node) => {
    isCached: node.isCached(),
    visible: node.visible(),
    listening: node.listening(),
    zIndex: node.zIndex(),
  };
};
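For reference, the addCoords helper added above is a plain component-wise sum; a tiny usage sketch:

import { addCoords } from 'features/controlLayers/konva/util';

// Component-wise sum of two coordinates.
const moved = addCoords({ x: 3, y: 4 }, { x: 10, y: -2 }); // { x: 13, y: 2 }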
@@ -48,9 +48,9 @@ type CanvasSettingsState = {
   */
  outputOnlyMaskedRegions: boolean;
  /**
   * Whether to automatically process the filter when the filter configuration changes.
   * Whether to automatically process the operations like filtering and auto-masking.
   */
  autoProcessFilter: boolean;
  autoProcess: boolean;
  /**
   * The snap-to-grid setting for the canvas.
   */
@@ -72,13 +72,9 @@ type CanvasSettingsState = {
   */
  isolatedStagingPreview: boolean;
  /**
   * Whether to show only the selected layer while filtering.
   * Whether to show only the selected layer while filtering, transforming, or doing other operations.
   */
  isolatedFilteringPreview: boolean;
  /**
   * Whether to show only the selected layer while transforming.
   */
  isolatedTransformingPreview: boolean;
  isolatedLayerPreview: boolean;
  /**
   * Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
   */
@@ -95,14 +91,13 @@ const initialState: CanvasSettingsState = {
  color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
  sendToCanvas: false,
  outputOnlyMaskedRegions: false,
  autoProcessFilter: true,
  autoProcess: true,
  snapToGrid: true,
  showProgressOnCanvas: true,
  bboxOverlay: false,
  preserveMask: false,
  isolatedStagingPreview: true,
  isolatedFilteringPreview: true,
  isolatedTransformingPreview: true,
  isolatedLayerPreview: true,
  pressureSensitivity: true,
};

@@ -137,8 +132,8 @@ export const canvasSettingsSlice = createSlice({
    settingsOutputOnlyMaskedRegionsToggled: (state) => {
      state.outputOnlyMaskedRegions = !state.outputOnlyMaskedRegions;
    },
    settingsAutoProcessFilterToggled: (state) => {
      state.autoProcessFilter = !state.autoProcessFilter;
    settingsAutoProcessToggled: (state) => {
      state.autoProcess = !state.autoProcess;
    },
    settingsSnapToGridToggled: (state) => {
      state.snapToGrid = !state.snapToGrid;
@@ -155,11 +150,8 @@ export const canvasSettingsSlice = createSlice({
    settingsIsolatedStagingPreviewToggled: (state) => {
      state.isolatedStagingPreview = !state.isolatedStagingPreview;
    },
    settingsIsolatedFilteringPreviewToggled: (state) => {
      state.isolatedFilteringPreview = !state.isolatedFilteringPreview;
    },
    settingsIsolatedTransformingPreviewToggled: (state) => {
      state.isolatedTransformingPreview = !state.isolatedTransformingPreview;
    settingsIsolatedLayerPreviewToggled: (state) => {
      state.isolatedLayerPreview = !state.isolatedLayerPreview;
    },
    settingsPressureSensitivityToggled: (state) => {
      state.pressureSensitivity = !state.pressureSensitivity;
@@ -185,14 +177,13 @@ export const {
  settingsInvertScrollForToolWidthChanged,
  settingsSendToCanvasChanged,
  settingsOutputOnlyMaskedRegionsToggled,
  settingsAutoProcessFilterToggled,
  settingsAutoProcessToggled,
  settingsSnapToGridToggled,
  settingsShowProgressOnCanvasToggled,
  settingsBboxOverlayToggled,
  settingsPreserveMaskToggled,
  settingsIsolatedStagingPreviewToggled,
  settingsIsolatedFilteringPreviewToggled,
  settingsIsolatedTransformingPreviewToggled,
  settingsIsolatedLayerPreviewToggled,
  settingsPressureSensitivityToggled,
} = canvasSettingsSlice.actions;

@@ -219,17 +210,12 @@ export const selectOutputOnlyMaskedRegions = createCanvasSettingsSelector(
export const selectDynamicGrid = createCanvasSettingsSelector((settings) => settings.dynamicGrid);
export const selectBboxOverlay = createCanvasSettingsSelector((settings) => settings.bboxOverlay);
export const selectShowHUD = createCanvasSettingsSelector((settings) => settings.showHUD);
export const selectAutoProcessFilter = createCanvasSettingsSelector((settings) => settings.autoProcessFilter);
export const selectAutoProcess = createCanvasSettingsSelector((settings) => settings.autoProcess);
export const selectSnapToGrid = createCanvasSettingsSelector((settings) => settings.snapToGrid);
export const selectSendToCanvas = createCanvasSettingsSelector((canvasSettings) => canvasSettings.sendToCanvas);
export const selectShowProgressOnCanvas = createCanvasSettingsSelector(
  (canvasSettings) => canvasSettings.showProgressOnCanvas
);
export const selectIsolatedStagingPreview = createCanvasSettingsSelector((settings) => settings.isolatedStagingPreview);
export const selectIsolatedFilteringPreview = createCanvasSettingsSelector(
  (settings) => settings.isolatedFilteringPreview
);
export const selectIsolatedTransformingPreview = createCanvasSettingsSelector(
  (settings) => settings.isolatedTransformingPreview
);
export const selectIsolatedLayerPreview = createCanvasSettingsSelector((settings) => settings.isolatedLayerPreview);
export const selectPressureSensitivity = createCanvasSettingsSelector((settings) => settings.pressureSensitivity);
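The settings rename above (autoProcessFilter becoming autoProcess, and the three isolated*Preview flags collapsing into isolatedLayerPreview) keeps the usual slice mechanics. A small sketch of toggling the renamed flag through the slice reducer, assuming the slice and action are exported from canvasSettingsSlice as shown in this diff:

import { canvasSettingsSlice, settingsAutoProcessToggled } from 'features/controlLayers/store/canvasSettingsSlice';

// Start from the slice's initial state (autoProcess: true) and toggle it twice.
let state = canvasSettingsSlice.getInitialState();
state = canvasSettingsSlice.reducer(state, settingsAutoProcessToggled()); // autoProcess: false
state = canvasSettingsSlice.reducer(state, settingsAutoProcessToggled()); // autoProcess: true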
Some files were not shown because too many files have changed in this diff.