mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 17:28:04 -05:00
Compare commits
128 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75acece1f1 | ||
|
|
a9db2ffefd | ||
|
|
cdd148b4d1 | ||
|
|
730fabe2de | ||
|
|
6c59790a7f | ||
|
|
c37251d6f7 | ||
|
|
2854210162 | ||
|
|
5545b980af | ||
|
|
0c9434c464 | ||
|
|
8771de917d | ||
|
|
122946ef4c | ||
|
|
2d974f670c | ||
|
|
75f0da9c35 | ||
|
|
5df3c00e28 | ||
|
|
b049880502 | ||
|
|
e5293fdd1a | ||
|
|
8883775762 | ||
|
|
cfadb313d2 | ||
|
|
b5cadd9a1a | ||
|
|
5361b6e014 | ||
|
|
ff346172af | ||
|
|
92f660018b | ||
|
|
1afc2cba4e | ||
|
|
ee8359242c | ||
|
|
f0c80a8d7a | ||
|
|
8da9e7c1f6 | ||
|
|
6d7a486e5b | ||
|
|
57122c6aa3 | ||
|
|
54abd8d4d1 | ||
|
|
06283cffed | ||
|
|
27fa0e1140 | ||
|
|
533d48abdb | ||
|
|
6845cae4c9 | ||
|
|
31c9acb1fa | ||
|
|
fb5e462300 | ||
|
|
2f3abc29b1 | ||
|
|
c5c071f285 | ||
|
|
93a3ed56e7 | ||
|
|
406fc58889 | ||
|
|
cf67d084fd | ||
|
|
d4a95af14f | ||
|
|
8c8e7102c2 | ||
|
|
b6b9ea9d70 | ||
|
|
63126950bc | ||
|
|
29d63d5dea | ||
|
|
2f6b035138 | ||
|
|
4f9ae44472 | ||
|
|
c682330852 | ||
|
|
c064257759 | ||
|
|
8a4c629576 | ||
|
|
a01d44f813 | ||
|
|
63fb3a15e9 | ||
|
|
4d0837541b | ||
|
|
999809b4c7 | ||
|
|
c452edfb9f | ||
|
|
ad2cdbd8a2 | ||
|
|
f15c24bfa7 | ||
|
|
d1f653f28c | ||
|
|
244465d3a6 | ||
|
|
c6236ab70c | ||
|
|
644d5cb411 | ||
|
|
bb0a630416 | ||
|
|
2148ae9287 | ||
|
|
42d242609c | ||
|
|
fd0a52392b | ||
|
|
e64415d59a | ||
|
|
1871e0bdbf | ||
|
|
3ae9a965c2 | ||
|
|
85932e35a7 | ||
|
|
41b07a56cc | ||
|
|
54064c0cb8 | ||
|
|
68284b37fa | ||
|
|
ae5bc6f5d6 | ||
|
|
6dc16c9f54 | ||
|
|
faa9ac4e15 | ||
|
|
d0460849b0 | ||
|
|
bed3c2dd77 | ||
|
|
916ddd17d7 | ||
|
|
accfa7407f | ||
|
|
908db31e48 | ||
|
|
b70f632b26 | ||
|
|
d07a6385ab | ||
|
|
68df612fa1 | ||
|
|
3b96c79461 | ||
|
|
89bda5b983 | ||
|
|
22bff1fb22 | ||
|
|
55ba6488d1 | ||
|
|
2d78859171 | ||
|
|
3a661bac34 | ||
|
|
bb8a02de18 | ||
|
|
78155344f6 | ||
|
|
391a24b0f6 | ||
|
|
e75903389f | ||
|
|
27567052f2 | ||
|
|
6f447f7169 | ||
|
|
8b370cc182 | ||
|
|
af583d2971 | ||
|
|
0ebe8fb1bd | ||
|
|
befb629f46 | ||
|
|
874d67cb37 | ||
|
|
19f7a1295a | ||
|
|
78bd605617 | ||
|
|
b87f4e59a5 | ||
|
|
1eca4f12c8 | ||
|
|
f1de11d6bf | ||
|
|
9361ed9d70 | ||
|
|
ebabf4f7a8 | ||
|
|
606f3321f5 | ||
|
|
3970aa30fb | ||
|
|
678436e07c | ||
|
|
c620581699 | ||
|
|
c331d42ce4 | ||
|
|
1ac9b502f1 | ||
|
|
3fa478a12f | ||
|
|
2d86298b7f | ||
|
|
009cdb714c | ||
|
|
9d3f5427b4 | ||
|
|
e4b17f019a | ||
|
|
586c00bc02 | ||
|
|
0f11fda65a | ||
|
|
3e75331ef7 | ||
|
|
be133408ac | ||
|
|
7e1e0d6928 | ||
|
|
cd3d8df5a8 | ||
|
|
24d3c22017 | ||
|
|
b0d37f4e51 | ||
|
|
3559124674 | ||
|
|
6c33e02141 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -19,3 +19,4 @@
|
||||
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
|
||||
- [ ] _Tests added / updated (if applicable)_
|
||||
- [ ] _Documentation added / updated (if applicable)_
|
||||
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
|
||||
|
||||
@@ -40,6 +40,8 @@ class AppVersion(BaseModel):
|
||||
|
||||
version: str = Field(description="App version")
|
||||
|
||||
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
|
||||
|
||||
|
||||
class AppDependencyVersions(BaseModel):
|
||||
"""App depencency Versions Response"""
|
||||
|
||||
@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -131,8 +134,10 @@ class FieldDescriptions:
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
clip_g_model = "CLIP-G Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
mmditx = "MMDiTX"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
@@ -140,6 +145,7 @@ class FieldDescriptions:
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.0",
|
||||
version="3.2.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -81,6 +81,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
@@ -207,9 +208,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
if self.add_noise:
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
x = init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
|
||||
89
invokeai/app/invocations/flux_model_loader.py
Normal file
89
invokeai/app/invocations/flux_model_loader.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
max_seq_len: Literal[256, 512] = OutputField(
|
||||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||
title="Max Seq Length",
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
max_seq_len: Literal[256, 512] = OutputField(
|
||||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||
title="Max Seq Length",
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
|
||||
@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
SD3ConditioningField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("sd3_conditioning_output")
|
||||
class SD3ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single SD3 conditioning tensor"""
|
||||
|
||||
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
|
||||
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
260
invokeai/app/invocations/sd3_denoise.py
Normal file
260
invokeai/app/invocations/sd3_denoise.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
SD3ConditioningField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_denoise",
|
||||
title="SD3 Denoise",
|
||||
tags=["image", "sd3"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a SD3 model."""
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.sd3_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_conditioning: SD3ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: SD3ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
conditioning_name: str,
|
||||
joint_attention_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
sd3_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
|
||||
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
|
||||
|
||||
t5_embeds = sd3_conditioning.t5_embeds
|
||||
if t5_embeds is None:
|
||||
t5_embeds = torch.zeros(
|
||||
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat(
|
||||
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def _get_noise(
|
||||
self,
|
||||
num_samples: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
num_channels_latents,
|
||||
int(height) // LATENT_SCALE_FACTOR,
|
||||
int(width) // LATENT_SCALE_FACTOR,
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
|
||||
"""Prepare the CFG scale list.
|
||||
|
||||
Args:
|
||||
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
|
||||
on the scheduler used (e.g. higher order schedulers).
|
||||
|
||||
Returns:
|
||||
list[float]: _description_
|
||||
"""
|
||||
if isinstance(self.cfg_scale, float):
|
||||
cfg_scale = [self.cfg_scale] * num_timesteps
|
||||
elif isinstance(self.cfg_scale, list):
|
||||
assert len(self.cfg_scale) == num_timesteps
|
||||
cfg_scale = self.cfg_scale
|
||||
else:
|
||||
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
|
||||
|
||||
return cfg_scale
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Load/process the conditioning data.
|
||||
# TODO(ryand): Make CFG optional.
|
||||
do_classifier_free_guidance = True
|
||||
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.positive_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.negative_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
# TODO(ryand): Support both sequential and batched CFG inference.
|
||||
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
|
||||
|
||||
# Prepare the scheduler.
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
|
||||
timesteps = scheduler.timesteps
|
||||
assert isinstance(timesteps, torch.Tensor)
|
||||
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(len(timesteps))
|
||||
|
||||
# Generate initial latent noise.
|
||||
num_channels_latents = transformer_info.model.config.in_channels
|
||||
assert isinstance(num_channels_latents, int)
|
||||
noise = self._get_noise(
|
||||
num_samples=1,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
seed=self.seed,
|
||||
)
|
||||
latents: torch.Tensor = noise
|
||||
|
||||
total_steps = len(timesteps)
|
||||
step_callback = self._build_step_callback(context)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=0,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(timesteps[0]),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
with transformer_info.model_on_device() as (cached_weights, transformer):
|
||||
assert isinstance(transformer, SD3Transformer2DModel)
|
||||
|
||||
# 6. Denoising loop
|
||||
for step_idx, t in tqdm(list(enumerate(timesteps))):
|
||||
# Expand the latents if we are doing CFG.
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# Expand the timestep to match the latent model input.
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=None,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Apply CFG.
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1.
|
||||
latents_dtype = latents.dtype
|
||||
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
|
||||
|
||||
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
|
||||
# needed.
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step_idx + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
|
||||
|
||||
return step_callback
|
||||
73
invokeai/app/invocations/sd3_latents_to_image.py
Normal file
73
invokeai/app/invocations/sd3_latents_to_image.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_l2i",
|
||||
title="SD3 Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "sd3"],
|
||||
category="latents",
|
||||
version="1.3.0",
|
||||
)
|
||||
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL))
|
||||
latents = latents.to(vae.device)
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode(), tiling_context:
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
img = vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=img_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
108
invokeai/app/invocations/sd3_model_loader.py
Normal file
108
invokeai/app/invocations/sd3_model_loader.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
class Sd3ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SD3 base model loader output."""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
|
||||
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_model_loader",
|
||||
title="SD3 Main Model",
|
||||
tags=["model", "sd3"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a SD3 base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sd3_model,
|
||||
ui_type=UIType.SD3MainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
clip_l_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPLEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP L Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
clip_g_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_g_model,
|
||||
ui_type=UIType.CLIPGEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP G Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
vae_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = (
|
||||
self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
if self.vae_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
)
|
||||
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder_l = (
|
||||
self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
if self.clip_l_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
)
|
||||
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
clip_encoder_g = (
|
||||
self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
if self.clip_g_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
)
|
||||
tokenizer_t5 = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
)
|
||||
t5_encoder = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
)
|
||||
|
||||
return Sd3ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
|
||||
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
199
invokeai/app/invocations/sd3_text_encoder.py
Normal file
199
invokeai/app/invocations/sd3_text_encoder.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_text_encoder",
|
||||
title="SD3 Text Encoding",
|
||||
tags=["prompt", "conditioning", "sd3"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a SD3 image."""
|
||||
|
||||
clip_l: CLIPField = InputField(
|
||||
title="CLIP L",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
clip_g: CLIPField = InputField(
|
||||
title="CLIP G",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
title="T5Encoder",
|
||||
default=None,
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
|
||||
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
|
||||
|
||||
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
|
||||
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
|
||||
|
||||
t5_embeddings: torch.Tensor | None = None
|
||||
if self.t5_encoder is not None:
|
||||
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SD3ConditioningInfo(
|
||||
clip_l_embeds=clip_l_embeddings,
|
||||
clip_l_pooled_embeds=clip_l_pooled_embeddings,
|
||||
clip_g_embeds=clip_g_embeddings,
|
||||
clip_g_pooled_embeds=clip_g_pooled_embeddings,
|
||||
t5_embeds=t5_embeddings,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return SD3ConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
|
||||
assert self.t5_encoder is not None
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
|
||||
text_inputs = t5_tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
assert isinstance(text_input_ids, torch.Tensor)
|
||||
assert isinstance(untruncated_ids, torch.Tensor)
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
|
||||
context.logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_seq_len} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(
|
||||
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_text_encoder_config = clip_text_encoder_info.config
|
||||
assert clip_text_encoder_config is not None
|
||||
|
||||
# Apply LoRA models to the CLIP encoder.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context, clip_model),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# There are currently no supported CLIP quantized models. Add support here if needed.
|
||||
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
|
||||
|
||||
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
|
||||
|
||||
text_inputs = clip_tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
assert isinstance(text_input_ids, torch.Tensor)
|
||||
assert isinstance(untruncated_ids, torch.Tensor)
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
|
||||
context.logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = clip_text_encoder(
|
||||
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def _clip_lora_iterator(
|
||||
self, context: InvocationContext, clip_model: CLIPField
|
||||
) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_model.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@@ -5,7 +5,7 @@ from typing import Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
default="all",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_point_lists_or_bounding_box(self):
|
||||
if self.point_lists is None and self.bounding_boxes is None:
|
||||
raise ValueError("Either point_lists or bounding_box must be provided.")
|
||||
elif self.point_lists is not None and self.bounding_boxes is not None:
|
||||
raise ValueError("Only one of point_lists or bounding_box can be provided.")
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
# The models expect a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
if self.point_lists is not None and self.bounding_boxes is not None:
|
||||
raise ValueError("Only one of point_lists or bounding_box can be provided.")
|
||||
|
||||
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
|
||||
not self.point_lists or len(self.point_lists) == 0
|
||||
):
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
{
|
||||
"name": "SD3.5 Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample text to image workflow for Stable Diffusion 3.5",
|
||||
"version": "1.0.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, SD3.5, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"fieldName": "prompt"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "sd3_model_loader",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
|
||||
"hash": "placeholder",
|
||||
"name": "stable-diffusion-3.5-medium",
|
||||
"base": "sd-3",
|
||||
"type": "main"
|
||||
}
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_l_model": {
|
||||
"name": "clip_l_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g_model": {
|
||||
"name": "clip_g_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": -55.58689609637031,
|
||||
"y": -111.53602444662268
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 470.45870147220353,
|
||||
"y": 350.3141781644303
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "sd3_l2i",
|
||||
"version": "1.3.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1192.3097009334897,
|
||||
"y": -366.0994675072209
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 408.16054647924784,
|
||||
"y": 65.06415352118786
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 378.9283412440941,
|
||||
"y": -302.65777497352553
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "sd3_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_conditioning": {
|
||||
"name": "positive_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"name": "negative_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"cfg_scale": {
|
||||
"name": "cfg_scale",
|
||||
"label": "",
|
||||
"value": 3.5
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"steps": {
|
||||
"name": "steps",
|
||||
"label": "",
|
||||
"value": 30
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 813.7814762740603,
|
||||
"y": -142.20529727605867
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
|
||||
"type": "default",
|
||||
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
|
||||
"type": "default",
|
||||
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
|
||||
"type": "default",
|
||||
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
|
||||
"type": "default",
|
||||
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
SD3_5_LATENT_RGB_FACTORS = [
|
||||
[-0.05240681, 0.03251581, 0.0749016],
|
||||
[-0.0580572, 0.00759826, 0.05729818],
|
||||
[0.16144888, 0.01270368, -0.03768577],
|
||||
[0.14418615, 0.08460266, 0.15941818],
|
||||
[0.04894035, 0.0056485, -0.06686988],
|
||||
[0.05187166, 0.19222395, 0.06261094],
|
||||
[0.1539433, 0.04818359, 0.07103094],
|
||||
[-0.08601796, 0.09013458, 0.10893912],
|
||||
[-0.12398469, -0.06766567, 0.0033688],
|
||||
[-0.0439737, 0.07825329, 0.02258823],
|
||||
[0.03101129, 0.06382551, 0.07753657],
|
||||
[-0.01315361, 0.08554491, -0.08772475],
|
||||
[0.06464487, 0.05914605, 0.13262741],
|
||||
[-0.07863674, -0.02261737, -0.12761454],
|
||||
[-0.09923835, -0.08010759, -0.06264447],
|
||||
[-0.03392309, -0.0804029, -0.06078822],
|
||||
]
|
||||
|
||||
FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.0412, 0.0149, 0.0521],
|
||||
[0.0056, 0.0291, 0.0768],
|
||||
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
elif base_model == BaseModelType.StableDiffusion3:
|
||||
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
|
||||
else:
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusion3 = "sd-3"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
@@ -92,6 +95,13 @@ class SubModelType(str, Enum):
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ClipVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
L = "large"
|
||||
G = "gigantic"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
@@ -147,6 +157,17 @@ class ModelSourceType(str, Enum):
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
variant: AnyVariant = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class MainModelDefaultSettings(BaseModel):
|
||||
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
|
||||
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
|
||||
@@ -193,6 +214,9 @@ class ModelConfigBase(BaseModel):
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
||||
description="Loadable submodels in this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
@@ -335,7 +359,7 @@ class MainConfigBase(ModelConfigBase):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
variant: AnyVariant = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
@@ -419,12 +443,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.G
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
|
||||
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
"""Model config for CLIP-L Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@@ -501,6 +546,8 @@ AnyModelConfig = Annotated[
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
raise ValueError("Only T5EncoderConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
|
||||
|
||||
raise ValueError(
|
||||
|
||||
@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ControlAdapterDefaultSettings,
|
||||
InvalidModelConfigException,
|
||||
@@ -33,8 +34,15 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubmodelDefinition,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
get_clip_variant_type,
|
||||
lora_token_vector_length,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
@@ -112,6 +120,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusion3Pipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
@@ -122,8 +131,12 @@ class ModelProbe(object):
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
"SD3Transformer2DModel": ModelType.Main,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
}
|
||||
|
||||
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||
@@ -170,7 +183,10 @@ class ModelProbe(object):
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
|
||||
fields["variant"] = (
|
||||
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
|
||||
)
|
||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
@@ -217,6 +233,10 @@ class ModelProbe(object):
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
get_submodels = getattr(probe, "get_submodels", None)
|
||||
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
|
||||
fields["submodels"] = get_submodels()
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
return model_info
|
||||
|
||||
@@ -747,18 +767,33 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
|
||||
config_path = self.model_path / "unet" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
# Handle pipelines with a transformer (i.e. SD3).
|
||||
config_path = self.model_path / "transformer" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
transformer_conf = json.load(file)
|
||||
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
|
||||
return BaseModelType.StableDiffusion3
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
@@ -770,6 +805,23 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
||||
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
|
||||
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
|
||||
submodels: Dict[SubModelType, SubmodelDefinition] = {}
|
||||
for key, value in config.items():
|
||||
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
|
||||
continue
|
||||
model_loader = str(value[1])
|
||||
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
|
||||
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
|
||||
submodels[SubModelType(key)] = SubmodelDefinition(
|
||||
path_or_prefix=(self.model_path / key).resolve().as_posix(),
|
||||
model_type=model_type,
|
||||
variant=variant_func and variant_func((self.model_path / key).as_posix()),
|
||||
)
|
||||
|
||||
return submodels
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
|
||||
@@ -140,6 +140,22 @@ flux_dev = StarterModel(
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-medium",
|
||||
description="Medium SD3.5 Model: ~15GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
sd35_large = StarterModel(
|
||||
name="SD3.5 Large",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-large",
|
||||
description="Large SD3.5 Model: ~19G",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
cyberrealistic_sd1 = StarterModel(
|
||||
name="CyberRealistic v4.1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -570,6 +586,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
flux_dev,
|
||||
sd35_medium,
|
||||
sd35_large,
|
||||
cyberrealistic_sd1,
|
||||
rev_animated_sd1,
|
||||
dreamshaper_8_sd1,
|
||||
|
||||
@@ -8,6 +8,7 @@ import safetensors
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_manager.config import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
|
||||
|
||||
@@ -165,3 +166,25 @@ def convert_bundle_to_flux_transformer_checkpoint(
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
|
||||
try:
|
||||
path = Path(location)
|
||||
config_path = path / "config.json"
|
||||
if not config_path.exists():
|
||||
config_path = path / "text_encoder" / "config.json"
|
||||
if not config_path.exists():
|
||||
return ClipVariantType.L
|
||||
with open(config_path) as file:
|
||||
clip_conf = json.load(file)
|
||||
hidden_size = clip_conf.get("hidden_size", -1)
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return ClipVariantType.L
|
||||
except Exception:
|
||||
return ClipVariantType.L
|
||||
|
||||
@@ -85,6 +85,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result: set[Path] = set()
|
||||
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
|
||||
safetensors_detected = False
|
||||
for path in files:
|
||||
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
@@ -119,19 +120,27 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
|
||||
# variant and format and select the best one.
|
||||
|
||||
if safetensors_detected and path.suffix == ".bin":
|
||||
continue
|
||||
|
||||
parent = path.parent
|
||||
score = 0
|
||||
|
||||
if path.suffix == ".safetensors":
|
||||
safetensors_detected = True
|
||||
if parent in subfolder_weights:
|
||||
subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
if (
|
||||
variant is not ModelRepoVariant.Default
|
||||
and candidate_variant_label
|
||||
and candidate_variant_label.startswith(f".{variant.value}")
|
||||
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
@@ -146,7 +155,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
@@ -162,7 +171,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
|
||||
match = re.match(pattern, highest_score_candidate.path.as_posix())
|
||||
if match:
|
||||
for candidate in candidate_list:
|
||||
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
|
||||
match.group(2)
|
||||
):
|
||||
result.add(candidate.path)
|
||||
else:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
|
||||
@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SD3ConditioningInfo:
|
||||
clip_l_pooled_embeds: torch.Tensor
|
||||
clip_l_embeds: torch.Tensor
|
||||
clip_g_pooled_embeds: torch.Tensor
|
||||
clip_g_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor | None
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
|
||||
if self.t5_embeds is not None:
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
conditionings: (
|
||||
List[BasicConditioningInfo]
|
||||
| List[SDXLConditioningInfo]
|
||||
| List[FLUXConditioningInfo]
|
||||
| List[SD3ConditioningInfo]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,6 +9,7 @@ const config: KnipConfig = {
|
||||
'src/services/api/schema.ts',
|
||||
'src/features/nodes/types/v1/**',
|
||||
'src/features/nodes/types/v2/**',
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// TODO(psyche): restore HRF functionality?
|
||||
|
||||
@@ -52,11 +52,11 @@
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
|
||||
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
|
||||
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@dagrejs/graphlib": "^2.2.4",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@fontsource-variable/inter": "^5.1.0",
|
||||
"@invoke-ai/ui-library": "^0.0.43",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
|
||||
91
invokeai/frontend/web/pnpm-lock.yaml
generated
91
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -5,21 +5,21 @@ settings:
|
||||
excludeLinksFromLockfile: false
|
||||
|
||||
dependencies:
|
||||
'@atlaskit/pragmatic-drag-and-drop':
|
||||
specifier: ^1.4.0
|
||||
version: 1.4.0
|
||||
'@atlaskit/pragmatic-drag-and-drop-auto-scroll':
|
||||
specifier: ^1.4.0
|
||||
version: 1.4.0
|
||||
'@atlaskit/pragmatic-drag-and-drop-hitbox':
|
||||
specifier: ^1.0.3
|
||||
version: 1.0.3
|
||||
'@dagrejs/dagre':
|
||||
specifier: ^1.1.4
|
||||
version: 1.1.4
|
||||
'@dagrejs/graphlib':
|
||||
specifier: ^2.2.4
|
||||
version: 2.2.4
|
||||
'@dnd-kit/core':
|
||||
specifier: ^6.1.0
|
||||
version: 6.1.0(react-dom@18.3.1)(react@18.3.1)
|
||||
'@dnd-kit/sortable':
|
||||
specifier: ^8.0.0
|
||||
version: 8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1)
|
||||
'@dnd-kit/utilities':
|
||||
specifier: ^3.2.2
|
||||
version: 3.2.2(react@18.3.1)
|
||||
'@fontsource-variable/inter':
|
||||
specifier: ^5.1.0
|
||||
version: 5.1.0
|
||||
@@ -319,6 +319,28 @@ packages:
|
||||
'@jridgewell/trace-mapping': 0.3.25
|
||||
dev: true
|
||||
|
||||
/@atlaskit/pragmatic-drag-and-drop-auto-scroll@1.4.0:
|
||||
resolution: {integrity: sha512-5GoikoTSW13UX76F9TDeWB8x3jbbGlp/Y+3aRkHe1MOBMkrWkwNpJ42MIVhhX/6NSeaZiPumP0KbGJVs2tOWSQ==}
|
||||
dependencies:
|
||||
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
|
||||
'@babel/runtime': 7.25.7
|
||||
dev: false
|
||||
|
||||
/@atlaskit/pragmatic-drag-and-drop-hitbox@1.0.3:
|
||||
resolution: {integrity: sha512-/Sbu/HqN2VGLYBhnsG7SbRNg98XKkbF6L7XDdBi+izRybfaK1FeMfodPpm/xnBHPJzwYMdkE0qtLyv6afhgMUA==}
|
||||
dependencies:
|
||||
'@atlaskit/pragmatic-drag-and-drop': 1.4.0
|
||||
'@babel/runtime': 7.25.7
|
||||
dev: false
|
||||
|
||||
/@atlaskit/pragmatic-drag-and-drop@1.4.0:
|
||||
resolution: {integrity: sha512-qRY3PTJIcxfl/QB8Gwswz+BRvlmgAC5pB+J2hL6dkIxgqAgVwOhAamMUKsrOcFU/axG2Q7RbNs1xfoLKDuhoPg==}
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.7
|
||||
bind-event-listener: 3.0.0
|
||||
raf-schd: 4.0.3
|
||||
dev: false
|
||||
|
||||
/@babel/code-frame@7.25.7:
|
||||
resolution: {integrity: sha512-0xZJFNE5XMpENsgfHYTw8FbX4kv53mFLn2i3XPoq69LyhYSCBJtitaHx9QnsVTrsogI4Z3+HtEfZ2/GFPOtf5g==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
@@ -980,49 +1002,6 @@ packages:
|
||||
engines: {node: '>17.0.0'}
|
||||
dev: false
|
||||
|
||||
/@dnd-kit/accessibility@3.1.0(react@18.3.1):
|
||||
resolution: {integrity: sha512-ea7IkhKvlJUv9iSHJOnxinBcoOI3ppGnnL+VDJ75O45Nss6HtZd8IdN8touXPDtASfeI2T2LImb8VOZcL47wjQ==}
|
||||
peerDependencies:
|
||||
react: '>=16.8.0'
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/@dnd-kit/core@6.1.0(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-J3cQBClB4TVxwGo3KEjssGEXNJqGVWx17aRTZ1ob0FliR5IjYgTxl5YJbKTzA6IzrtelotH19v6y7uoIRUZPSg==}
|
||||
peerDependencies:
|
||||
react: '>=16.8.0'
|
||||
react-dom: '>=16.8.0'
|
||||
dependencies:
|
||||
'@dnd-kit/accessibility': 3.1.0(react@18.3.1)
|
||||
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/@dnd-kit/sortable@8.0.0(@dnd-kit/core@6.1.0)(react@18.3.1):
|
||||
resolution: {integrity: sha512-U3jk5ebVXe1Lr7c2wU7SBZjcWdQP+j7peHJfCspnA81enlu88Mgd7CC8Q+pub9ubP7eKVETzJW+IBAhsqbSu/g==}
|
||||
peerDependencies:
|
||||
'@dnd-kit/core': ^6.1.0
|
||||
react: '>=16.8.0'
|
||||
dependencies:
|
||||
'@dnd-kit/core': 6.1.0(react-dom@18.3.1)(react@18.3.1)
|
||||
'@dnd-kit/utilities': 3.2.2(react@18.3.1)
|
||||
react: 18.3.1
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/@dnd-kit/utilities@3.2.2(react@18.3.1):
|
||||
resolution: {integrity: sha512-+MKAJEOfaBe5SmV6t34p80MMKhjvUz0vRrvVJbPT0WElzaOJ/1xs+D+KDv+tD/NE5ujfrChEcshd4fLn0wpiqg==}
|
||||
peerDependencies:
|
||||
react: '>=16.8.0'
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/@emotion/babel-plugin@11.12.0:
|
||||
resolution: {integrity: sha512-y2WQb+oP8Jqvvclh8Q55gLUyb7UFvgv7eJfsj7td5TToBrIUtPay2kMrZi4xjq9qw2vD0ZR5fSho0yqoFgX7Rw==}
|
||||
dependencies:
|
||||
@@ -4313,6 +4292,10 @@ packages:
|
||||
open: 8.4.2
|
||||
dev: true
|
||||
|
||||
/bind-event-listener@3.0.0:
|
||||
resolution: {integrity: sha512-PJvH288AWQhKs2v9zyfYdPzlPqf5bXbGMmhmUIY9x4dAUGIWgomO771oBQNwJnMQSnUIXhKu6sgzpBRXTlvb8Q==}
|
||||
dev: false
|
||||
|
||||
/bl@4.1.0:
|
||||
resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
|
||||
dependencies:
|
||||
@@ -7557,6 +7540,10 @@ packages:
|
||||
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
|
||||
dev: true
|
||||
|
||||
/raf-schd@4.0.3:
|
||||
resolution: {integrity: sha512-tQkJl2GRWh83ui2DiPTJz9wEiMN20syf+5oKfB03yYP7ioZcJwsIK8FjrtLwH1m7C7e+Tt2yYBlrOpdT+dyeIQ==}
|
||||
dev: false
|
||||
|
||||
/raf-throttle@2.0.6:
|
||||
resolution: {integrity: sha512-C7W6hy78A+vMmk5a/B6C5szjBHrUzWJkVyakjKCK59Uy2CcA7KhO1JUvvH32IXYFIcyJ3FMKP3ZzCc2/71I6Vg==}
|
||||
dev: false
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 895 KiB |
@@ -997,6 +997,7 @@
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"copyImage": "Copy Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"noRasterLayers": "No Raster Layers",
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"guidance": "Guidance",
|
||||
@@ -1412,8 +1413,9 @@
|
||||
"paramDenoisingStrength": {
|
||||
"heading": "Denoising Strength",
|
||||
"paragraphs": [
|
||||
"How much noise is added to the input image.",
|
||||
"0 will result in an identical image, while 1 will result in a completely new image."
|
||||
"Controls how much the generated image varies from the raster layer(s).",
|
||||
"Lower strength stays closer to the combined visible raster layers. Higher strength relies more on the global prompt.",
|
||||
"When there are no raster layers with visible content, this setting is ignored."
|
||||
]
|
||||
},
|
||||
"paramHeight": {
|
||||
@@ -1662,6 +1664,7 @@
|
||||
"mergeDown": "Merge Down",
|
||||
"mergeVisibleOk": "Merged layers",
|
||||
"mergeVisibleError": "Error merging layers",
|
||||
"mergingLayers": "Merging layers",
|
||||
"clearHistory": "Clear History",
|
||||
"bboxOverlay": "Show Bbox Overlay",
|
||||
"resetCanvas": "Reset Canvas",
|
||||
@@ -1774,9 +1777,10 @@
|
||||
"newCanvasSession": "New Canvas Session",
|
||||
"newCanvasSessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be staged on the canvas.",
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
|
||||
"controlMode": {
|
||||
"controlMode": "Control Mode",
|
||||
"balanced": "Balanced",
|
||||
"balanced": "Balanced (recommended)",
|
||||
"prompt": "Prompt",
|
||||
"control": "Control",
|
||||
"megaControl": "Mega Control"
|
||||
@@ -1815,6 +1819,9 @@
|
||||
"process": "Process",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel",
|
||||
"advanced": "Advanced",
|
||||
"processingLayerWith": "Processing layer with the {{type}} filter.",
|
||||
"forMoreControl": "For more control, click Advanced below.",
|
||||
"spandrel_filter": {
|
||||
"label": "Image-to-Image Model",
|
||||
"description": "Run an image-to-image model on the selected layer.",
|
||||
@@ -2095,9 +2102,10 @@
|
||||
},
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"line1": "<ItalicComponent>Select Object</ItalicComponent> tool for precise object selection and editing",
|
||||
"line2": "Expanded Flux support, now with Global Reference Images",
|
||||
"line3": "Improved tooltips and context menus",
|
||||
"items": [
|
||||
"<StrongComponent>SD 3.5</StrongComponent>: Support for Text-to-Image in Workflows with SD 3.5 Medium and Large.",
|
||||
"<StrongComponent>Canvas</StrongComponent>: Streamlined Control Layer processing and improved default Control settings."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
"watchUiUpdatesOverview": "Watch UI Updates Overview"
|
||||
|
||||
@@ -8,10 +8,8 @@ import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
|
||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
|
||||
import { useFocusRegionWatcher } from 'common/hooks/focus';
|
||||
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
|
||||
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import {
|
||||
@@ -19,6 +17,7 @@ import {
|
||||
NewGallerySessionDialog,
|
||||
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
|
||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
@@ -62,8 +61,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
useGetOpenAPISchemaQuery();
|
||||
useSyncLoggingConfig();
|
||||
|
||||
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
clearStorage();
|
||||
location.reload();
|
||||
@@ -92,19 +89,8 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box
|
||||
id="invoke-app-wrapper"
|
||||
w="100dvw"
|
||||
h="100dvh"
|
||||
position="relative"
|
||||
overflow="hidden"
|
||||
{...dropzone.getRootProps()}
|
||||
>
|
||||
<input {...dropzone.getInputProps()} />
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{dropzone.isDragActive && isHandlingUpload && (
|
||||
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
|
||||
)}
|
||||
</Box>
|
||||
<DeleteImageModal />
|
||||
<ChangeBoardModal />
|
||||
@@ -121,6 +107,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
<NewGallerySessionDialog />
|
||||
<NewCanvasSessionDialog />
|
||||
<ImageContextMenu />
|
||||
<FullscreenDropzone />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
@@ -8,13 +7,11 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const GlobalImageHotkeys = memo(() => {
|
||||
useAssertSingleton('GlobalImageHotkeys');
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken);
|
||||
const imageDTO = useAppSelector(selectLastSelectedImage);
|
||||
|
||||
if (!imageDTO) {
|
||||
return null;
|
||||
|
||||
@@ -19,7 +19,6 @@ import { $workflowCategories } from 'app/store/nanostores/workflowCategories';
|
||||
import { createStore } from 'app/store/store';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import AppDndContext from 'features/dnd/components/AppDndContext';
|
||||
import type { WorkflowCategory } from 'features/nodes/types/workflow';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
|
||||
@@ -237,9 +236,7 @@ const InvokeAIUI = ({
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<AppDndContext>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</AppDndContext>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
|
||||
@@ -17,6 +17,7 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
|
||||
export const zLogNamespace = z.enum([
|
||||
'canvas',
|
||||
'config',
|
||||
'dnd',
|
||||
'events',
|
||||
'gallery',
|
||||
'generation',
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
export const STORAGE_PREFIX = '@@invokeai-';
|
||||
export const EMPTY_ARRAY = [];
|
||||
/** @knipignore */
|
||||
export const EMPTY_OBJECT = {};
|
||||
|
||||
@@ -16,7 +16,6 @@ import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMi
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
|
||||
import { addImageDroppedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
|
||||
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
|
||||
@@ -93,9 +92,6 @@ addGetOpenAPISchemaListener(startAppListening);
|
||||
addWorkflowLoadRequestedListener(startAppListening);
|
||||
addUpdateAllNodesRequestedListener(startAppListening);
|
||||
|
||||
// DND
|
||||
addImageDroppedListener(startAppListening);
|
||||
|
||||
// Models
|
||||
addModelSelectedListener(startAppListening);
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('queue');
|
||||
|
||||
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
|
||||
log.debug({ enqueueResult } as JsonObject, t('queue.graphQueued'));
|
||||
} catch (error) {
|
||||
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
|
||||
log.error({ enqueueBatchArg } as JsonObject, t('queue.graphFailedToQueue'));
|
||||
|
||||
if (error instanceof Object && 'status' in error && error.status === 403) {
|
||||
return;
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { truncate, upperFirst } from 'lodash-es';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('queue');
|
||||
|
||||
@@ -17,7 +17,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
|
||||
effect: (action) => {
|
||||
const enqueueResult = action.payload;
|
||||
const arg = action.meta.arg.originalArgs;
|
||||
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
|
||||
log.debug({ enqueueResult } as JsonObject, 'Batch enqueued');
|
||||
|
||||
toast({
|
||||
id: 'QUEUE_BATCH_SUCCEEDED',
|
||||
@@ -45,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
|
||||
status: 'error',
|
||||
description: t('common.unknownError'),
|
||||
});
|
||||
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
|
||||
log.error({ batchConfig } as JsonObject, t('queue.batchFailedToQueue'));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
|
||||
description: t('common.unknownError'),
|
||||
});
|
||||
}
|
||||
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
|
||||
log.error({ batchConfig, error: serializeError(response) } as JsonObject, t('queue.batchFailedToQueue'));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
|
||||
@@ -10,10 +10,12 @@ import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGr
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
@@ -57,7 +59,17 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
}
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
|
||||
let description: string | null = null;
|
||||
if (buildGraphResult.error instanceof AssertionError) {
|
||||
description = extractMessageFromAssertionError(buildGraphResult.error);
|
||||
}
|
||||
const error = serializeError(buildGraphResult.error);
|
||||
log.error({ error }, 'Failed to build graph');
|
||||
toast({
|
||||
status: 'error',
|
||||
title: 'Failed to build graph',
|
||||
description,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -88,7 +100,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
return;
|
||||
}
|
||||
|
||||
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
|
||||
log.debug({ batchConfig: prepareBatchResult.value } as JsonObject, 'Enqueued batch');
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
@@ -16,12 +16,12 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
|
||||
effect: (action, { getState }) => {
|
||||
const schemaJSON = action.payload;
|
||||
|
||||
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
|
||||
log.debug({ schemaJSON: parseify(schemaJSON) } as JsonObject, 'Received OpenAPI schema');
|
||||
const { nodesAllowlist, nodesDenylist } = getState().config;
|
||||
|
||||
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
|
||||
|
||||
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
|
||||
log.debug({ nodeTemplates } as JsonObject, `Built ${size(nodeTemplates)} node templates`);
|
||||
|
||||
$templates.set(nodeTemplates);
|
||||
},
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
entityRasterized,
|
||||
entitySelected,
|
||||
inpaintMaskAdded,
|
||||
rasterLayerAdded,
|
||||
referenceImageAdded,
|
||||
referenceImageIPAdapterImageChanged,
|
||||
rgAdded,
|
||||
rgIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasInpaintMaskState,
|
||||
CanvasRasterLayerState,
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const dndDropped = createAction<{
|
||||
overData: TypesafeDroppableData;
|
||||
activeData: TypesafeDraggableData;
|
||||
}>('dnd/dndDropped');
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: dndDropped,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { activeData, overData } = action.payload;
|
||||
if (!isValidDrop(overData, activeData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||
log.debug({ activeData, overData }, 'Image dropped');
|
||||
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
|
||||
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
|
||||
} else if (activeData.payloadType === 'NODE_FIELD') {
|
||||
log.debug({ activeData, overData }, 'Node field dropped');
|
||||
} else {
|
||||
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on IP Adapter Layer
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SET_IPA_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { id } = overData.context;
|
||||
dispatch(
|
||||
referenceImageIPAdapterImageChanged({
|
||||
entityIdentifier: { id, type: 'reference_image' },
|
||||
imageDTO: activeData.payload.imageDTO,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on RG Layer IP Adapter
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { id, referenceImageId } = overData.context;
|
||||
dispatch(
|
||||
rgIPAdapterImageChanged({
|
||||
entityIdentifier: { id, type: 'regional_guidance' },
|
||||
referenceImageId,
|
||||
imageDTO: activeData.payload.imageDTO,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on Raster layer
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
};
|
||||
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
/**
|
||||
* Image dropped on Inpaint Mask
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'ADD_INPAINT_MASK_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const overrides: Partial<CanvasInpaintMaskState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
};
|
||||
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
/**
|
||||
* Image dropped on Regional Guidance
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'ADD_REGIONAL_GUIDANCE_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const overrides: Partial<CanvasRegionalGuidanceState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
};
|
||||
dispatch(rgAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on Raster layer
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
const defaultControlAdapter = selectDefaultControlAdapter(state);
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
controlAdapter: defaultControlAdapter,
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
overData.actionType === 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
|
||||
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
|
||||
const overrides: Partial<CanvasRegionalGuidanceState> = {
|
||||
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }],
|
||||
};
|
||||
dispatch(rgAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
overData.actionType === 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const ipAdapter = deepClone(selectDefaultIPAdapter(state));
|
||||
ipAdapter.image = imageDTOToImageWithDims(activeData.payload.imageDTO);
|
||||
const overrides: Partial<CanvasReferenceImageState> = {
|
||||
ipAdapter,
|
||||
};
|
||||
dispatch(referenceImageAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on Raster layer
|
||||
*/
|
||||
if (overData.actionType === 'REPLACE_LAYER_WITH_IMAGE' && activeData.payloadType === 'IMAGE_DTO') {
|
||||
const state = getState();
|
||||
const { entityIdentifier } = overData.context;
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on node image field
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SET_NODES_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { fieldName, nodeId } = overData.context;
|
||||
dispatch(
|
||||
fieldImageValueChanged({
|
||||
nodeId,
|
||||
fieldName,
|
||||
value: activeData.payload.imageDTO,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image selected for compare
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SELECT_FOR_COMPARE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on user board
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'ADD_TO_BOARD' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
const { boardId } = overData.context;
|
||||
dispatch(
|
||||
imagesApi.endpoints.addImageToBoard.initiate({
|
||||
imageDTO,
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on 'none' board
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'REMOVE_FROM_BOARD' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(
|
||||
imagesApi.endpoints.removeImageFromBoard.initiate({
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on upscale initial image
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple images dropped on user board
|
||||
*/
|
||||
if (overData.actionType === 'ADD_TO_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
|
||||
const imageDTOs = getState().gallery.selection;
|
||||
const { boardId } = overData.context;
|
||||
dispatch(
|
||||
imagesApi.endpoints.addImagesToBoard.initiate({
|
||||
imageDTOs,
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple images dropped on 'none' board
|
||||
*/
|
||||
if (overData.actionType === 'REMOVE_FROM_BOARD' && activeData.payloadType === 'GALLERY_SELECTION') {
|
||||
const imageDTOs = getState().gallery.selection;
|
||||
dispatch(
|
||||
imagesApi.endpoints.removeImagesFromBoard.initiate({
|
||||
imageDTOs,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,18 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
entityRasterized,
|
||||
entitySelected,
|
||||
referenceImageIPAdapterImageChanged,
|
||||
rgIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { omit } from 'lodash-es';
|
||||
@@ -51,93 +41,45 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
|
||||
log.debug({ imageDTO }, 'Image uploaded');
|
||||
|
||||
const { postUploadAction } = action.meta.arg.originalArgs;
|
||||
const boardId = imageDTO.board_id ?? 'none';
|
||||
|
||||
if (!postUploadAction) {
|
||||
return;
|
||||
}
|
||||
if (action.meta.arg.originalArgs.withToast) {
|
||||
const DEFAULT_UPLOADED_TOAST = {
|
||||
id: 'IMAGE_UPLOADED',
|
||||
title: t('toast.imageUploaded'),
|
||||
status: 'success',
|
||||
} as const;
|
||||
|
||||
const DEFAULT_UPLOADED_TOAST = {
|
||||
id: 'IMAGE_UPLOADED',
|
||||
title: t('toast.imageUploaded'),
|
||||
status: 'success',
|
||||
} as const;
|
||||
|
||||
// default action - just upload and alert user
|
||||
if (postUploadAction.type === 'TOAST') {
|
||||
const boardId = imageDTO.board_id ?? 'none';
|
||||
// default action - just upload and alert user
|
||||
if (lastUploadedToastTimeout !== null) {
|
||||
window.clearTimeout(lastUploadedToastTimeout);
|
||||
}
|
||||
const toastApi = toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
title: postUploadAction.title || DEFAULT_UPLOADED_TOAST.title,
|
||||
title: DEFAULT_UPLOADED_TOAST.title,
|
||||
description: getUploadedToastDescription(boardId, state),
|
||||
duration: null, // we will close the toast manually
|
||||
});
|
||||
lastUploadedToastTimeout = window.setTimeout(() => {
|
||||
toastApi.close();
|
||||
}, 3000);
|
||||
/**
|
||||
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
|
||||
* the user's gallery board and view selection:
|
||||
* - User uploads multiple images
|
||||
* - A couple uploads finish, but others are pending still
|
||||
* - User changes the board selection
|
||||
* - Pending uploads finish and change the board back to the original board
|
||||
* - User is confused as to why the board changed
|
||||
*
|
||||
* Default to true to not require _all_ image upload handlers to set this value
|
||||
*/
|
||||
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
|
||||
if (isFirstUploadOfBatch) {
|
||||
dispatch(boardIdSelected({ boardId }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction.type === 'SET_UPSCALE_INITIAL_IMAGE') {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: 'set as upscale initial image',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction.type === 'SET_IPA_IMAGE') {
|
||||
const { id } = postUploadAction;
|
||||
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: { id, type: 'reference_image' }, imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction.type === 'SET_RG_IP_ADAPTER_IMAGE') {
|
||||
const { id, referenceImageId } = postUploadAction;
|
||||
dispatch(
|
||||
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, referenceImageId, imageDTO })
|
||||
);
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction.type === 'SET_NODES_IMAGE') {
|
||||
const { nodeId, fieldName } = postUploadAction;
|
||||
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction.type === 'REPLACE_LAYER_WITH_IMAGE') {
|
||||
const { entityIdentifier } = postUploadAction;
|
||||
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
dispatch(entityRasterized({ entityIdentifier, imageObject, position: { x, y }, replaceObjects: true }));
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
return;
|
||||
/**
|
||||
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
|
||||
* the user's gallery board and view selection:
|
||||
* - User uploads multiple images
|
||||
* - A couple uploads finish, but others are pending still
|
||||
* - User changes the board selection
|
||||
* - Pending uploads finish and change the board back to the original board
|
||||
* - User is confused as to why the board changed
|
||||
*
|
||||
* Default to true to not require _all_ image upload handlers to set this value
|
||||
*/
|
||||
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
|
||||
if (isFirstUploadOfBatch) {
|
||||
dispatch(boardIdSelected({ boardId }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import {
|
||||
controlLayerModelChanged,
|
||||
referenceImageIPAdapterModelChanged,
|
||||
@@ -41,6 +40,7 @@ import {
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
} from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
@@ -85,7 +85,7 @@ type ModelHandler = (
|
||||
models: AnyModelConfig[],
|
||||
state: RootState,
|
||||
dispatch: AppDispatch,
|
||||
log: Logger<SerializableObject>
|
||||
log: Logger<JsonObject>
|
||||
) => undefined;
|
||||
|
||||
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
@@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
// We have a VAE selected, need to check if it is available
|
||||
|
||||
// Grab just the VAE models
|
||||
const vaeModels = models.filter(isNonFluxVAEModelConfig);
|
||||
const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m));
|
||||
|
||||
// If the current VAE model is available, we don't need to do anything
|
||||
if (vaeModels.some((m) => m.key === selectedVAEModel.key)) {
|
||||
@@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedT5EncoderModel = state.params.t5EncoderModel;
|
||||
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
|
||||
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
|
||||
@@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
|
||||
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
|
||||
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
|
||||
@@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedFLUXVAEModel = state.params.fluxVAE;
|
||||
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
|
||||
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) {
|
||||
|
||||
@@ -4,8 +4,10 @@ import { atom } from 'nanostores';
|
||||
/**
|
||||
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
|
||||
* in a hook or component.
|
||||
*
|
||||
* @knipignore
|
||||
*/
|
||||
// export const $false: ReadableAtom<boolean> = atom(false);
|
||||
export const $false: ReadableAtom<boolean> = atom(false);
|
||||
/**
|
||||
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
|
||||
* in a hook or component.
|
||||
|
||||
@@ -3,7 +3,6 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
@@ -37,6 +36,7 @@ import undoable from 'redux-undo';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { api } from 'services/api';
|
||||
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
import { STORAGE_PREFIX } from './constants';
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
@@ -139,7 +139,7 @@ const unserialize: UnserializeFunction = (data, key) => {
|
||||
{
|
||||
persistedData: parsed,
|
||||
rehydratedData: transformed,
|
||||
diff: diff(parsed, transformed) as SerializableObject, // this is always serializable
|
||||
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
|
||||
},
|
||||
`Rehydrated slice "${key}"`
|
||||
);
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
import type { ChakraProps, FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
|
||||
import IAIDraggable from './IAIDraggable';
|
||||
import IAIDroppable from './IAIDroppable';
|
||||
|
||||
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
|
||||
|
||||
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
|
||||
|
||||
const baseStyles: SystemStyleObject = {
|
||||
touchAction: 'none',
|
||||
userSelect: 'none',
|
||||
webkitUserSelect: 'none',
|
||||
};
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
...baseStyles,
|
||||
'.gallery-image-container::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
borderRadius: 'base',
|
||||
},
|
||||
'&[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
'&:hover>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
};
|
||||
|
||||
type IAIDndImageProps = FlexProps & {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
||||
withMetadataOverlay?: boolean;
|
||||
isDragDisabled?: boolean;
|
||||
isDropDisabled?: boolean;
|
||||
isUploadDisabled?: boolean;
|
||||
minSize?: number;
|
||||
postUploadAction?: PostUploadAction;
|
||||
imageSx?: ChakraProps['sx'];
|
||||
fitContainer?: boolean;
|
||||
droppableData?: TypesafeDroppableData;
|
||||
draggableData?: TypesafeDraggableData;
|
||||
dropLabel?: string;
|
||||
isSelected?: boolean;
|
||||
isSelectedForCompare?: boolean;
|
||||
thumbnail?: boolean;
|
||||
noContentFallback?: ReactElement;
|
||||
useThumbailFallback?: boolean;
|
||||
withHoverOverlay?: boolean;
|
||||
children?: JSX.Element;
|
||||
uploadElement?: ReactNode;
|
||||
dataTestId?: string;
|
||||
};
|
||||
|
||||
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
const {
|
||||
imageDTO,
|
||||
onError,
|
||||
onClick,
|
||||
withMetadataOverlay = false,
|
||||
isDropDisabled = false,
|
||||
isDragDisabled = false,
|
||||
isUploadDisabled = false,
|
||||
minSize = 24,
|
||||
postUploadAction,
|
||||
imageSx,
|
||||
fitContainer = false,
|
||||
droppableData,
|
||||
draggableData,
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
isSelectedForCompare = false,
|
||||
thumbnail = false,
|
||||
noContentFallback = defaultNoContentFallback,
|
||||
uploadElement = defaultUploadElement,
|
||||
useThumbailFallback,
|
||||
withHoverOverlay = false,
|
||||
children,
|
||||
dataTestId,
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const openInNewTab = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (e.button !== 1) {
|
||||
return;
|
||||
}
|
||||
window.open(imageDTO.image_url, '_blank');
|
||||
},
|
||||
[imageDTO]
|
||||
);
|
||||
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useImageContextMenu(imageDTO, ref);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
width="full"
|
||||
height="full"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
position="relative"
|
||||
minW={minSize ? minSize : undefined}
|
||||
minH={minSize ? minSize : undefined}
|
||||
userSelect="none"
|
||||
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
|
||||
sx={withHoverOverlay ? sx : baseStyles}
|
||||
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
|
||||
{...rest}
|
||||
>
|
||||
{imageDTO && (
|
||||
<Flex
|
||||
className="gallery-image-container"
|
||||
w="full"
|
||||
h="full"
|
||||
position={fitContainer ? 'absolute' : 'relative'}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Image
|
||||
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
|
||||
fallbackStrategy="beforeLoadOrError"
|
||||
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
|
||||
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
|
||||
onError={onError}
|
||||
draggable={false}
|
||||
w={imageDTO.width}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
borderRadius="base"
|
||||
sx={imageSx}
|
||||
data-testid={dataTestId}
|
||||
/>
|
||||
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && !isUploadDisabled && (
|
||||
<UploadButton
|
||||
isUploadDisabled={isUploadDisabled}
|
||||
postUploadAction={postUploadAction}
|
||||
uploadElement={uploadElement}
|
||||
minSize={minSize}
|
||||
/>
|
||||
)}
|
||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||
{imageDTO && !isDragDisabled && (
|
||||
<IAIDraggable
|
||||
data={draggableData}
|
||||
disabled={isDragDisabled || !imageDTO}
|
||||
onClick={onClick}
|
||||
onAuxClick={openInNewTab}
|
||||
/>
|
||||
)}
|
||||
{children}
|
||||
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIDndImage);
|
||||
|
||||
const UploadButton = memo(
|
||||
({
|
||||
isUploadDisabled,
|
||||
postUploadAction,
|
||||
uploadElement,
|
||||
minSize,
|
||||
}: {
|
||||
isUploadDisabled: boolean;
|
||||
postUploadAction?: PostUploadAction;
|
||||
uploadElement: ReactNode;
|
||||
minSize: number;
|
||||
}) => {
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction,
|
||||
isDisabled: isUploadDisabled,
|
||||
});
|
||||
|
||||
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
|
||||
const styles: SystemStyleObject = {
|
||||
minH: minSize,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
color: 'base.500',
|
||||
};
|
||||
if (!isUploadDisabled) {
|
||||
Object.assign(styles, {
|
||||
cursor: 'pointer',
|
||||
bg: 'base.700',
|
||||
_hover: {
|
||||
bg: 'base.650',
|
||||
color: 'base.300',
|
||||
},
|
||||
});
|
||||
}
|
||||
return styles;
|
||||
}, [isUploadDisabled, minSize]);
|
||||
|
||||
return (
|
||||
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
|
||||
<input {...getUploadInputProps()} />
|
||||
{uploadElement}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
UploadButton.displayName = 'UploadButton';
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { BoxProps } from '@invoke-ai/ui-library';
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||
import type { TypesafeDraggableData } from 'features/dnd/types';
|
||||
import { memo, useRef } from 'react';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type IAIDraggableProps = BoxProps & {
|
||||
disabled?: boolean;
|
||||
data?: TypesafeDraggableData;
|
||||
};
|
||||
|
||||
const IAIDraggable = (props: IAIDraggableProps) => {
|
||||
const { data, disabled, ...rest } = props;
|
||||
const dndId = useRef(uuidv4());
|
||||
|
||||
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
|
||||
id: dndId.current,
|
||||
disabled,
|
||||
data,
|
||||
});
|
||||
|
||||
return (
|
||||
<Box
|
||||
ref={setNodeRef}
|
||||
position="absolute"
|
||||
w="full"
|
||||
h="full"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
{...attributes}
|
||||
{...listeners}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIDraggable);
|
||||
@@ -1,64 +0,0 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isOver: boolean;
|
||||
label?: string;
|
||||
withBackdrop?: boolean;
|
||||
};
|
||||
|
||||
const IAIDropOverlay = (props: Props) => {
|
||||
const { isOver, label, withBackdrop = true } = props;
|
||||
return (
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0}>
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
w="full"
|
||||
h="full"
|
||||
bg={withBackdrop ? 'base.900' : 'transparent'}
|
||||
opacity={0.7}
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
/>
|
||||
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0.5}
|
||||
right={0.5}
|
||||
bottom={0.5}
|
||||
left={0.5}
|
||||
opacity={1}
|
||||
borderWidth={1.5}
|
||||
borderColor={isOver ? 'invokeYellow.300' : 'base.500'}
|
||||
borderRadius="base"
|
||||
borderStyle="dashed"
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
{label && (
|
||||
<Text
|
||||
fontSize="lg"
|
||||
fontWeight="semibold"
|
||||
color={isOver ? 'invokeYellow.300' : 'base.500'}
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
textAlign="center"
|
||||
>
|
||||
{label}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIDropOverlay);
|
||||
@@ -1,46 +0,0 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||
import type { TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { memo, useRef } from 'react';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import IAIDropOverlay from './IAIDropOverlay';
|
||||
|
||||
type IAIDroppableProps = {
|
||||
dropLabel?: string;
|
||||
disabled?: boolean;
|
||||
data?: TypesafeDroppableData;
|
||||
};
|
||||
|
||||
const IAIDroppable = (props: IAIDroppableProps) => {
|
||||
const { dropLabel, data, disabled } = props;
|
||||
const dndId = useRef(uuidv4());
|
||||
|
||||
const { isOver, setNodeRef, active } = useDroppableTypesafe({
|
||||
id: dndId.current,
|
||||
disabled,
|
||||
data,
|
||||
});
|
||||
|
||||
return (
|
||||
<Box
|
||||
ref={setNodeRef}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
w="full"
|
||||
h="full"
|
||||
pointerEvents={active ? 'auto' : 'none'}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||
</AnimatePresence>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIDroppable);
|
||||
@@ -1,24 +0,0 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Skeleton } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
const skeletonStyles: SystemStyleObject = {
|
||||
position: 'relative',
|
||||
height: 'full',
|
||||
width: 'full',
|
||||
'::before': {
|
||||
content: "''",
|
||||
display: 'block',
|
||||
pt: '100%',
|
||||
},
|
||||
};
|
||||
|
||||
const IAIFillSkeleton = () => {
|
||||
return (
|
||||
<Skeleton sx={skeletonStyles}>
|
||||
<Box position="absolute" top={0} insetInlineStart={0} height="full" width="full" />
|
||||
</Skeleton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIFillSkeleton);
|
||||
@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
type Props = { image: ImageDTO | undefined };
|
||||
|
||||
export const IAILoadingImageFallback = memo((props: Props) => {
|
||||
const IAILoadingImageFallback = memo((props: Props) => {
|
||||
if (props.image) {
|
||||
return (
|
||||
<Skeleton
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import { Badge, Flex } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
type ImageMetadataOverlayProps = {
|
||||
imageDTO: ImageDTO;
|
||||
};
|
||||
|
||||
const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
|
||||
return (
|
||||
<Flex
|
||||
pointerEvents="none"
|
||||
flexDirection="column"
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
p={2}
|
||||
alignItems="flex-start"
|
||||
gap={2}
|
||||
>
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{imageDTO.width} × {imageDTO.height}
|
||||
</Badge>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageMetadataOverlay);
|
||||
@@ -1,89 +0,0 @@
|
||||
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { memo } from 'react';
|
||||
import type { DropzoneState } from 'react-dropzone';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
type ImageUploadOverlayProps = {
|
||||
dropzone: DropzoneState;
|
||||
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
|
||||
};
|
||||
|
||||
const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
const { dropzone, setIsHandlingUpload } = props;
|
||||
|
||||
useHotkeys(
|
||||
'esc',
|
||||
() => {
|
||||
setIsHandlingUpload(false);
|
||||
},
|
||||
[setIsHandlingUpload]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box position="absolute" top={0} right={0} bottom={0} left={0} zIndex={999} backdropFilter="blur(20px)">
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0} bg="base.900" opacity={0.7} />
|
||||
<Flex
|
||||
position="absolute"
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
top={2}
|
||||
right={2}
|
||||
bottom={2}
|
||||
left={2}
|
||||
opacity={1}
|
||||
borderWidth={2}
|
||||
borderColor={dropzone.isDragAccept ? 'invokeYellow.300' : 'error.500'}
|
||||
borderRadius="base"
|
||||
borderStyle="dashed"
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
color={dropzone.isDragReject ? 'error.300' : undefined}
|
||||
>
|
||||
{dropzone.isDragAccept && <DragAcceptMessage />}
|
||||
{!dropzone.isDragAccept && <DragRejectMessage />}
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
export default memo(ImageUploadOverlay);
|
||||
|
||||
const DragAcceptMessage = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector(selectSelectedBoardId);
|
||||
const boardName = useBoardName(selectedBoardId);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Heading size="lg">{t('gallery.dropToUpload')}</Heading>
|
||||
<Heading size="md">{t('toast.imagesWillBeAddedTo', { boardName })}</Heading>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const DragRejectMessage = () => {
|
||||
const { t } = useTranslation();
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
|
||||
if (maxImageUploadCount === undefined) {
|
||||
return (
|
||||
<>
|
||||
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
|
||||
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc')}</Heading>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Heading size="lg">{t('toast.invalidUpload')}</Heading>
|
||||
<Heading size="md">{t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount })}</Heading>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { PopoverProps } from '@invoke-ai/ui-library';
|
||||
import commercialLicenseBg from 'public/assets/images/commercial-license-bg.png';
|
||||
import denoisingStrength from 'public/assets/images/denoising-strength.png';
|
||||
|
||||
export type Feature =
|
||||
| 'clipSkip'
|
||||
@@ -125,7 +126,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
|
||||
},
|
||||
infillMethod: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158841-infill-and-scaling',
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
|
||||
},
|
||||
scaleBeforeProcessing: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000158841',
|
||||
@@ -138,6 +139,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
|
||||
},
|
||||
paramDenoisingStrength: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000094998-image-to-image',
|
||||
image: denoisingStrength,
|
||||
},
|
||||
paramHrf: {
|
||||
href: 'https://support.invoke.ai/support/solutions/articles/151000096700-how-can-i-get-larger-images-what-does-upscaling-do-',
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { autoScrollForElements } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/element';
|
||||
import { autoScrollForExternal } from '@atlaskit/pragmatic-drag-and-drop-auto-scroll/external';
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import type { OverlayScrollbarsComponentRef } from 'overlayscrollbars-react';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties, PropsWithChildren } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
maxHeight?: ChakraProps['maxHeight'];
|
||||
@@ -11,17 +15,38 @@ type Props = PropsWithChildren & {
|
||||
overflowY?: 'hidden' | 'scroll';
|
||||
};
|
||||
|
||||
const styles: CSSProperties = { height: '100%', width: '100%' };
|
||||
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
|
||||
|
||||
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
|
||||
const overlayscrollbarsOptions = useMemo(
|
||||
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
|
||||
[overflowX, overflowY]
|
||||
);
|
||||
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);
|
||||
useEffect(() => {
|
||||
const osInstance = os?.osInstance();
|
||||
|
||||
if (!osInstance) {
|
||||
return;
|
||||
}
|
||||
|
||||
const element = osInstance.elements().viewport;
|
||||
|
||||
// `pragmatic-drag-and-drop-auto-scroll` requires the element to have `overflow-y: scroll` or `overflow-y: auto`
|
||||
// else it logs an ugly warning. In our case, using a custom scrollbar library, it will be 'hidden' by default.
|
||||
// To prevent the erroneous warning, we temporarily set the overflow-y to 'scroll' and then revert it back.
|
||||
const overflowY = element.style.overflowY; // starts 'hidden'
|
||||
element.style.setProperty('overflow-y', 'scroll', 'important');
|
||||
const cleanup = combine(autoScrollForElements({ element }), autoScrollForExternal({ element }));
|
||||
element.style.setProperty('overflow-y', overflowY);
|
||||
|
||||
return cleanup;
|
||||
}, [os]);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
|
||||
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
|
||||
<OverlayScrollbarsComponent defer style={styles} options={overlayscrollbarsOptions}>
|
||||
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
|
||||
{children}
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
|
||||
57
invokeai/frontend/web/src/common/components/WavyLine.tsx
Normal file
57
invokeai/frontend/web/src/common/components/WavyLine.tsx
Normal file
@@ -0,0 +1,57 @@
|
||||
type Props = {
|
||||
/**
|
||||
* The amplitude of the wave. 0 is a straight line, higher values create more pronounced waves.
|
||||
*/
|
||||
amplitude: number;
|
||||
/**
|
||||
* The number of segments in the line. More segments create a smoother wave.
|
||||
*/
|
||||
segments?: number;
|
||||
/**
|
||||
* The color of the wave.
|
||||
*/
|
||||
stroke: string;
|
||||
/**
|
||||
* The width of the wave.
|
||||
*/
|
||||
strokeWidth: number;
|
||||
/**
|
||||
* The width of the SVG.
|
||||
*/
|
||||
width: number;
|
||||
/**
|
||||
* The height of the SVG.
|
||||
*/
|
||||
height: number;
|
||||
};
|
||||
|
||||
const WavyLine = ({ amplitude, stroke, strokeWidth, width, height, segments = 5 }: Props) => {
|
||||
// Calculate the path dynamically based on waviness
|
||||
const generatePath = () => {
|
||||
if (amplitude === 0) {
|
||||
// If waviness is 0, return a straight line
|
||||
return `M0,${height / 2} L${width},${height / 2}`;
|
||||
}
|
||||
|
||||
const clampedAmplitude = Math.min(height / 2, amplitude); // Cap amplitude to half the height
|
||||
const segmentWidth = width / segments;
|
||||
let path = `M0,${height / 2}`; // Start in the middle of the left edge
|
||||
|
||||
// Loop through each segment and alternate the y position to create waves
|
||||
for (let i = 1; i <= segments; i++) {
|
||||
const x = i * segmentWidth;
|
||||
const y = height / 2 + (i % 2 === 0 ? clampedAmplitude : -clampedAmplitude);
|
||||
path += ` Q${x - segmentWidth / 2},${y} ${x},${height / 2}`;
|
||||
}
|
||||
|
||||
return path;
|
||||
};
|
||||
|
||||
return (
|
||||
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} xmlns="http://www.w3.org/2000/svg">
|
||||
<path d={generatePath()} fill="none" stroke={stroke} strokeWidth={strokeWidth} />
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export default WavyLine;
|
||||
@@ -1,124 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import type { Accept, FileRejection } from 'react-dropzone';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import type { PostUploadAction } from 'services/api/types';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
const accept: Accept = {
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
};
|
||||
|
||||
export const useFullscreenDropzone = () => {
|
||||
useAssertSingleton('useFullscreenDropzone');
|
||||
const { t } = useTranslation();
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
|
||||
const getPostUploadAction = useCallback((): PostUploadAction => {
|
||||
if (activeTabName === 'upscaling') {
|
||||
return { type: 'SET_UPSCALE_INITIAL_IMAGE' };
|
||||
} else {
|
||||
return { type: 'TOAST' };
|
||||
}
|
||||
}, [activeTabName]);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||
if (fileRejections.length > 0) {
|
||||
const errors = fileRejections.map((rejection) => ({
|
||||
errors: rejection.errors.map(({ message }) => message),
|
||||
file: rejection.file.path,
|
||||
}));
|
||||
log.error({ errors }, 'Invalid upload');
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.uploadFailed'),
|
||||
description,
|
||||
status: 'error',
|
||||
});
|
||||
|
||||
setIsHandlingUpload(false);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const [i, file] of acceptedFiles.entries()) {
|
||||
uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
postUploadAction: getPostUploadAction(),
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
// The `imageUploaded` listener does some extra logic, like switching to the asset view on upload on the
|
||||
// first upload of a "batch".
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
});
|
||||
}
|
||||
|
||||
setIsHandlingUpload(false);
|
||||
},
|
||||
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId]
|
||||
);
|
||||
|
||||
const onDragOver = useCallback(() => {
|
||||
setIsHandlingUpload(true);
|
||||
}, []);
|
||||
|
||||
const onDragLeave = useCallback(() => {
|
||||
setIsHandlingUpload(false);
|
||||
}, []);
|
||||
|
||||
const dropzone = useDropzone({
|
||||
accept,
|
||||
noClick: true,
|
||||
onDrop,
|
||||
onDragOver,
|
||||
onDragLeave,
|
||||
noKeyboard: true,
|
||||
multiple: maxImageUploadCount === undefined || maxImageUploadCount > 1,
|
||||
maxFiles: maxImageUploadCount,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
// This is a hack to allow pasting images into the uploader
|
||||
const handlePaste = (e: ClipboardEvent) => {
|
||||
if (!dropzone.inputRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (e.clipboardData?.files) {
|
||||
// Set the files on the dropzone.inputRef
|
||||
dropzone.inputRef.current.files = e.clipboardData.files;
|
||||
// Dispatch the change event, dropzone catches this and we get to use its own validation
|
||||
dropzone.inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
|
||||
}
|
||||
};
|
||||
|
||||
// Add the paste event listener
|
||||
document.addEventListener('paste', handlePaste);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('paste', handlePaste);
|
||||
};
|
||||
}, [dropzone.inputRef]);
|
||||
|
||||
return { dropzone, isHandlingUpload, setIsHandlingUpload };
|
||||
};
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
@@ -7,14 +9,23 @@ import { useCallback } from 'react';
|
||||
import type { FileRejection } from 'react-dropzone';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import type { PostUploadAction } from 'services/api/types';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
import { uploadImages, useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import type { SetOptional } from 'type-fest';
|
||||
|
||||
type UseImageUploadButtonArgs = {
|
||||
postUploadAction?: PostUploadAction;
|
||||
isDisabled?: boolean;
|
||||
allowMultiple?: boolean;
|
||||
};
|
||||
type UseImageUploadButtonArgs =
|
||||
| {
|
||||
isDisabled?: boolean;
|
||||
allowMultiple: false;
|
||||
onUpload?: (imageDTO: ImageDTO) => void;
|
||||
}
|
||||
| {
|
||||
isDisabled?: boolean;
|
||||
allowMultiple: true;
|
||||
onUpload?: (imageDTOs: ImageDTO[]) => void;
|
||||
};
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
@@ -37,30 +48,46 @@ const log = logger('gallery');
|
||||
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
|
||||
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
|
||||
*/
|
||||
export const useImageUploadButton = ({
|
||||
postUploadAction,
|
||||
isDisabled,
|
||||
allowMultiple = false,
|
||||
}: UseImageUploadButtonArgs) => {
|
||||
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
const [uploadImage, request] = useUploadImageMutation();
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
for (const [i, file] of files.entries()) {
|
||||
uploadImage({
|
||||
async (files: File[]) => {
|
||||
if (!allowMultiple) {
|
||||
if (files.length > 1) {
|
||||
log.warn('Multiple files dropped but only one allowed');
|
||||
return;
|
||||
}
|
||||
const file = files[0];
|
||||
assert(file !== undefined); // should never happen
|
||||
const imageDTO = await uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
postUploadAction: postUploadAction ?? { type: 'TOAST' },
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
});
|
||||
}).unwrap();
|
||||
if (onUpload) {
|
||||
onUpload(imageDTO);
|
||||
}
|
||||
} else {
|
||||
//
|
||||
const imageDTOs = await uploadImages(
|
||||
files.map((file) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}))
|
||||
);
|
||||
if (onUpload) {
|
||||
onUpload(imageDTOs);
|
||||
}
|
||||
}
|
||||
},
|
||||
[autoAddBoardId, postUploadAction, uploadImage]
|
||||
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
|
||||
);
|
||||
|
||||
const onDropRejected = useCallback(
|
||||
@@ -103,5 +130,42 @@ export const useImageUploadButton = ({
|
||||
maxFiles: maxImageUploadCount,
|
||||
});
|
||||
|
||||
return { getUploadButtonProps, getUploadInputProps, openUploader };
|
||||
return { getUploadButtonProps, getUploadInputProps, openUploader, request };
|
||||
};
|
||||
|
||||
const sx = {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
borderWidth: 0,
|
||||
borderRadius: 'base',
|
||||
'&[data-error=true]': {
|
||||
borderWidth: 1,
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const UploadImageButton = ({
|
||||
isDisabled = false,
|
||||
onUpload,
|
||||
isError = false,
|
||||
...rest
|
||||
}: {
|
||||
onUpload?: (imageDTO: ImageDTO) => void;
|
||||
isError?: boolean;
|
||||
} & SetOptional<IconButtonProps, 'aria-label'>) => {
|
||||
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
aria-label="Upload image"
|
||||
variant="ghost"
|
||||
sx={sx}
|
||||
data-error={isError}
|
||||
icon={<PiUploadBold />}
|
||||
isLoading={uploadApi.request.isLoading}
|
||||
{...rest}
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
import { getPrefixedId, nanoid } from 'features/controlLayers/konva/util';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNanoid = (prefix?: string) => {
|
||||
const id = useMemo(() => {
|
||||
if (prefix) {
|
||||
return getPrefixedId(prefix);
|
||||
} else {
|
||||
return nanoid();
|
||||
}
|
||||
}, [prefix]);
|
||||
|
||||
return id;
|
||||
};
|
||||
@@ -1,12 +0,0 @@
|
||||
type SerializableValue =
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| null
|
||||
| undefined
|
||||
| SerializableValue[]
|
||||
| readonly SerializableValue[]
|
||||
| SerializableObject;
|
||||
export type SerializableObject = {
|
||||
[k: string | number]: SerializableValue;
|
||||
};
|
||||
@@ -0,0 +1,6 @@
|
||||
import type { AssertionError } from 'tsafe';
|
||||
|
||||
export function extractMessageFromAssertionError(error: AssertionError): string | null {
|
||||
const match = error.message.match(/Wrong assertion encountered: "(.*)"/);
|
||||
return match ? (match[1] ?? null) : null;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
import type { CSSProperties } from 'react';
|
||||
|
||||
/**
|
||||
* Chakra's Tooltip's method of finding the nearest scroll parent has a problem - it assumes the first parent with
|
||||
* `overflow: hidden` is the scroll parent. In this case, the Collapse component has that style, but isn't scrollable
|
||||
* itself. The result is that the tooltip does not close on scroll, because the scrolling happens higher up in the DOM.
|
||||
*
|
||||
* As a hacky workaround, we can set the overflow to `visible`, which allows the scroll parent search to continue up to
|
||||
* the actual scroll parent (in this case, the OverlayScrollbarsComponent in BoardsListWrapper).
|
||||
*
|
||||
* See: https://github.com/chakra-ui/chakra-ui/issues/7871#issuecomment-2453780958
|
||||
*/
|
||||
export const fixTooltipCloseOnScrollStyles: CSSProperties = {
|
||||
overflow: 'visible',
|
||||
};
|
||||
@@ -1,38 +1,26 @@
|
||||
import { Grid, GridItem } from '@invoke-ai/ui-library';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type {
|
||||
AddControlLayerFromImageDropData,
|
||||
AddGlobalReferenceImageFromImageDropData,
|
||||
AddRasterLayerFromImageDropData,
|
||||
AddRegionalReferenceImageFromImageDropData,
|
||||
} from 'features/dnd/types';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { newCanvasEntityFromImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
|
||||
id: 'add-raster-layer-from-image-drop-data',
|
||||
actionType: 'ADD_RASTER_LAYER_FROM_IMAGE',
|
||||
};
|
||||
|
||||
const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
|
||||
id: 'add-control-layer-from-image-drop-data',
|
||||
actionType: 'ADD_CONTROL_LAYER_FROM_IMAGE',
|
||||
};
|
||||
|
||||
const addRegionalReferenceImageFromImageDropData: AddRegionalReferenceImageFromImageDropData = {
|
||||
id: 'add-control-layer-from-image-drop-data',
|
||||
actionType: 'ADD_REGIONAL_REFERENCE_IMAGE_FROM_IMAGE',
|
||||
};
|
||||
|
||||
const addGlobalReferenceImageFromImageDropData: AddGlobalReferenceImageFromImageDropData = {
|
||||
id: 'add-control-layer-from-image-drop-data',
|
||||
actionType: 'ADD_GLOBAL_REFERENCE_IMAGE_FROM_IMAGE',
|
||||
};
|
||||
const addRasterLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({ type: 'raster_layer' });
|
||||
const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
|
||||
type: 'control_layer',
|
||||
});
|
||||
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
|
||||
type: 'regional_guidance_with_reference_image',
|
||||
});
|
||||
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
|
||||
type: 'reference_image',
|
||||
});
|
||||
|
||||
export const CanvasDropArea = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const imageViewer = useImageViewer();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
if (imageViewer.isOpen) {
|
||||
return null;
|
||||
@@ -51,28 +39,36 @@ export const CanvasDropArea = memo(() => {
|
||||
pointerEvents="none"
|
||||
>
|
||||
<GridItem position="relative">
|
||||
<IAIDroppable
|
||||
dropLabel={t('controlLayers.canvasContextMenu.newRasterLayer')}
|
||||
data={addRasterLayerFromImageDropData}
|
||||
<DndDropTarget
|
||||
dndTarget={newCanvasEntityFromImageDndTarget}
|
||||
dndTargetData={addRasterLayerFromImageDndTargetData}
|
||||
label={t('controlLayers.canvasContextMenu.newRasterLayer')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem position="relative">
|
||||
<IAIDroppable
|
||||
dropLabel={t('controlLayers.canvasContextMenu.newControlLayer')}
|
||||
data={addControlLayerFromImageDropData}
|
||||
<DndDropTarget
|
||||
dndTarget={newCanvasEntityFromImageDndTarget}
|
||||
dndTargetData={addControlLayerFromImageDndTargetData}
|
||||
label={t('controlLayers.canvasContextMenu.newControlLayer')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</GridItem>
|
||||
|
||||
<GridItem position="relative">
|
||||
<IAIDroppable
|
||||
dropLabel={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
|
||||
data={addRegionalReferenceImageFromImageDropData}
|
||||
<DndDropTarget
|
||||
dndTarget={newCanvasEntityFromImageDndTarget}
|
||||
dndTargetData={addRegionalGuidanceReferenceImageFromImageDndTargetData}
|
||||
label={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem position="relative">
|
||||
<IAIDroppable
|
||||
dropLabel={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
|
||||
data={addGlobalReferenceImageFromImageDropData}
|
||||
<DndDropTarget
|
||||
dndTarget={newCanvasEntityFromImageDndTarget}
|
||||
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
|
||||
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCanvasEntityListDnd } from 'features/controlLayers/components/CanvasEntityList/useCanvasEntityListDnd';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
|
||||
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
|
||||
import { DndListDropIndicator } from 'features/dnd/DndListDropIndicator';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
|
||||
const sx = {
|
||||
position: 'relative',
|
||||
flexDir: 'column',
|
||||
w: 'full',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
'&[data-selected=true]': {
|
||||
bg: 'base.800',
|
||||
},
|
||||
'&[data-is-dragging=true]': {
|
||||
opacity: 0.3,
|
||||
},
|
||||
transitionProperty: 'common',
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const isSelected = useEntityIsSelected(entityIdentifier);
|
||||
const onClick = useCallback(() => {
|
||||
if (isSelected) {
|
||||
return;
|
||||
}
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier, isSelected]);
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
|
||||
const [dndListState, isDragging] = useCanvasEntityListDnd(ref, entityIdentifier);
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<Flex
|
||||
// This is used to trigger the post-move flash animation
|
||||
data-entity-id={entityIdentifier.id}
|
||||
data-selected={isSelected}
|
||||
data-is-dragging={isDragging}
|
||||
ref={ref}
|
||||
onClick={onClick}
|
||||
sx={sx}
|
||||
>
|
||||
{props.children}
|
||||
</Flex>
|
||||
<DndListDropIndicator dndState={dndListState} />
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityContainer.displayName = 'CanvasEntityContainer';
|
||||
@@ -0,0 +1,181 @@
|
||||
import { monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import { extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
|
||||
import { reorderWithEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/util/reorder-with-edge';
|
||||
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
|
||||
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
|
||||
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
|
||||
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
|
||||
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
|
||||
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
|
||||
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { isRenderableEntityType } from 'features/controlLayers/store/types';
|
||||
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
|
||||
import { triggerPostMoveFlash } from 'features/dnd/util';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
isSelected: boolean;
|
||||
type: CanvasEntityIdentifier['type'];
|
||||
entityIdentifiers: CanvasEntityIdentifier[];
|
||||
}>;
|
||||
|
||||
export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityIdentifiers }: Props) => {
|
||||
const title = useEntityTypeTitle(type);
|
||||
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
|
||||
const collapse = useBoolean(true);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
return monitorForElements({
|
||||
canMonitor({ source }) {
|
||||
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
|
||||
return false;
|
||||
}
|
||||
if (source.data.payload.entityIdentifier.type !== type) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
onDrop({ location, source }) {
|
||||
const target = location.current.dropTargets[0];
|
||||
if (!target) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceData = source.data;
|
||||
const targetData = target.data;
|
||||
|
||||
if (!singleCanvasEntityDndSource.typeGuard(sourceData) || !singleCanvasEntityDndSource.typeGuard(targetData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const indexOfSource = entityIdentifiers.findIndex(
|
||||
(entityIdentifier) => entityIdentifier.id === sourceData.payload.entityIdentifier.id
|
||||
);
|
||||
const indexOfTarget = entityIdentifiers.findIndex(
|
||||
(entityIdentifier) => entityIdentifier.id === targetData.payload.entityIdentifier.id
|
||||
);
|
||||
|
||||
if (indexOfTarget < 0 || indexOfSource < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Don't move if the source and target are the same index, meaning same position in the list
|
||||
if (indexOfSource === indexOfTarget) {
|
||||
return;
|
||||
}
|
||||
|
||||
const closestEdgeOfTarget = extractClosestEdge(targetData);
|
||||
|
||||
// It's possible that the indices are different, but refer to the same position. For example, if the source is
|
||||
// at 2 and the target is at 3, but the target edge is 'top', then the entity is already in the correct position.
|
||||
// We should bail if this is the case.
|
||||
let edgeIndexDelta = 0;
|
||||
|
||||
if (closestEdgeOfTarget === 'bottom') {
|
||||
edgeIndexDelta = 1;
|
||||
} else if (closestEdgeOfTarget === 'top') {
|
||||
edgeIndexDelta = -1;
|
||||
}
|
||||
|
||||
// If the source is already in the correct position, we don't need to move it.
|
||||
if (indexOfSource === indexOfTarget + edgeIndexDelta) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Using `flushSync` so we can query the DOM straight after this line
|
||||
flushSync(() => {
|
||||
dispatch(
|
||||
entitiesReordered({
|
||||
type,
|
||||
entityIdentifiers: reorderWithEdge({
|
||||
list: entityIdentifiers,
|
||||
startIndex: indexOfSource,
|
||||
indexOfTarget,
|
||||
closestEdgeOfTarget,
|
||||
axis: 'vertical',
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Flash the element that was moved
|
||||
const element = document.querySelector(`[data-entity-id="${sourceData.payload.entityIdentifier.id}"]`);
|
||||
if (element instanceof HTMLElement) {
|
||||
triggerPostMoveFlash(element, colorTokenToCssVar('base.700'));
|
||||
}
|
||||
},
|
||||
});
|
||||
}, [dispatch, entityIdentifiers, type]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex w="full">
|
||||
<Flex
|
||||
flexGrow={1}
|
||||
as={Button}
|
||||
onClick={collapse.toggle}
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
gap={3}
|
||||
variant="unstyled"
|
||||
p={0}
|
||||
h={8}
|
||||
>
|
||||
<Icon
|
||||
boxSize={4}
|
||||
as={PiCaretDownBold}
|
||||
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
|
||||
fill={isSelected ? 'base.200' : 'base.500'}
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
/>
|
||||
{informationalPopoverFeature ? (
|
||||
<InformationalPopover feature={informationalPopoverFeature}>
|
||||
<Text
|
||||
fontWeight="semibold"
|
||||
color={isSelected ? 'base.200' : 'base.500'}
|
||||
userSelect="none"
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
</InformationalPopover>
|
||||
) : (
|
||||
<Text
|
||||
fontWeight="semibold"
|
||||
color={isSelected ? 'base.200' : 'base.500'}
|
||||
userSelect="none"
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<Spacer />
|
||||
</Flex>
|
||||
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
|
||||
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
|
||||
<CanvasEntityAddOfTypeButton type={type} />
|
||||
</Flex>
|
||||
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
|
||||
<Flex flexDir="column" gap={2} pt={2}>
|
||||
{children}
|
||||
</Flex>
|
||||
</Collapse>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';
|
||||
@@ -0,0 +1,83 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { draggable, dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-and-drop-hitbox/closest-edge';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
|
||||
import { type DndListTargetState, idle } from 'features/dnd/types';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdentifier: CanvasEntityIdentifier) => {
|
||||
const [dndListState, setDndListState] = useState<DndListTargetState>(idle);
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const element = ref.current;
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
return combine(
|
||||
draggable({
|
||||
element,
|
||||
getInitialData() {
|
||||
return singleCanvasEntityDndSource.getData({ entityIdentifier });
|
||||
},
|
||||
onDragStart() {
|
||||
setDndListState({ type: 'is-dragging' });
|
||||
setIsDragging(true);
|
||||
},
|
||||
onDrop() {
|
||||
setDndListState(idle);
|
||||
setIsDragging(false);
|
||||
},
|
||||
}),
|
||||
dropTargetForElements({
|
||||
element,
|
||||
canDrop({ source }) {
|
||||
if (!singleCanvasEntityDndSource.typeGuard(source.data)) {
|
||||
return false;
|
||||
}
|
||||
if (source.data.payload.entityIdentifier.type !== entityIdentifier.type) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
getData({ input }) {
|
||||
const data = singleCanvasEntityDndSource.getData({ entityIdentifier });
|
||||
return attachClosestEdge(data, {
|
||||
element,
|
||||
input,
|
||||
allowedEdges: ['top', 'bottom'],
|
||||
});
|
||||
},
|
||||
getIsSticky() {
|
||||
return true;
|
||||
},
|
||||
onDragEnter({ self }) {
|
||||
const closestEdge = extractClosestEdge(self.data);
|
||||
setDndListState({ type: 'is-dragging-over', closestEdge });
|
||||
},
|
||||
onDrag({ self }) {
|
||||
const closestEdge = extractClosestEdge(self.data);
|
||||
|
||||
// Only need to update react state if nothing has changed.
|
||||
// Prevents re-rendering.
|
||||
setDndListState((current) => {
|
||||
if (current.type === 'is-dragging-over' && current.closestEdge === closestEdge) {
|
||||
return current;
|
||||
}
|
||||
return { type: 'is-dragging-over', closestEdge };
|
||||
});
|
||||
},
|
||||
onDragLeave() {
|
||||
setDndListState(idle);
|
||||
},
|
||||
onDrop() {
|
||||
setDndListState(idle);
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [entityIdentifier, ref]);
|
||||
|
||||
return [dndListState, isDragging] as const;
|
||||
};
|
||||
@@ -7,6 +7,8 @@ import { EntityListSelectedEntityActionBar } from 'features/controlLayers/compon
|
||||
import { selectHasEntities } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useRef } from 'react';
|
||||
|
||||
import { ParamDenoisingStrength } from './ParamDenoisingStrength';
|
||||
|
||||
export const CanvasLayersPanelContent = memo(() => {
|
||||
const hasEntities = useAppSelector(selectHasEntities);
|
||||
const layersPanelFocusRef = useRef<HTMLDivElement>(null);
|
||||
@@ -16,6 +18,8 @@ export const CanvasLayersPanelContent = memo(() => {
|
||||
<Flex ref={layersPanelFocusRef} flexDir="column" gap={2} w="full" h="full">
|
||||
<EntityListSelectedEntityActionBar />
|
||||
<Divider py={0} />
|
||||
<ParamDenoisingStrength />
|
||||
<Divider py={0} />
|
||||
{!hasEntities && <CanvasAddEntityButtons />}
|
||||
{hasEntities && <CanvasEntityList />}
|
||||
</Flex>
|
||||
|
||||
@@ -109,7 +109,9 @@ export const CanvasMainPanelContent = memo(() => {
|
||||
<SelectObject />
|
||||
</CanvasManagerProviderGate>
|
||||
</Flex>
|
||||
<CanvasDropArea />
|
||||
<CanvasManagerProviderGate>
|
||||
<CanvasDropArea />
|
||||
</CanvasManagerProviderGate>
|
||||
<GatedImageViewer />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { useDndContext } from '@dnd-kit/core';
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import { dropTargetForExternal, monitorForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
|
||||
import { Box, Button, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { CanvasLayersPanelContent } from 'features/controlLayers/components/CanvasLayersPanelContent';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectEntityCountActive } from 'features/controlLayers/store/selectors';
|
||||
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
|
||||
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
|
||||
import type { DndTargetState } from 'features/dnd/types';
|
||||
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasRightPanel = memo(() => {
|
||||
@@ -79,37 +83,13 @@ CanvasRightPanel.displayName = 'CanvasRightPanel';
|
||||
|
||||
const PanelTabs = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const activeTab = useAppSelector(selectActiveTabCanvasRightPanel);
|
||||
const store = useAppStore();
|
||||
const activeEntityCount = useAppSelector(selectEntityCountActive);
|
||||
const tabTimeout = useRef<number | null>(null);
|
||||
const dndCtx = useDndContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const [mouseOverTab, setMouseOverTab] = useState<'layers' | 'gallery' | null>(null);
|
||||
|
||||
const onOnMouseOverLayersTab = useCallback(() => {
|
||||
setMouseOverTab('layers');
|
||||
tabTimeout.current = window.setTimeout(() => {
|
||||
if (dndCtx.active) {
|
||||
dispatch(activeTabCanvasRightPanelChanged('layers'));
|
||||
}
|
||||
}, 300);
|
||||
}, [dndCtx.active, dispatch]);
|
||||
|
||||
const onOnMouseOverGalleryTab = useCallback(() => {
|
||||
setMouseOverTab('gallery');
|
||||
tabTimeout.current = window.setTimeout(() => {
|
||||
if (dndCtx.active) {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}
|
||||
}, 300);
|
||||
}, [dndCtx.active, dispatch]);
|
||||
|
||||
const onMouseOut = useCallback(() => {
|
||||
setMouseOverTab(null);
|
||||
if (tabTimeout.current) {
|
||||
clearTimeout(tabTimeout.current);
|
||||
}
|
||||
}, []);
|
||||
const [layersTabDndState, setLayersTabDndState] = useState<DndTargetState>('idle');
|
||||
const [galleryTabDndState, setGalleryTabDndState] = useState<DndTargetState>('idle');
|
||||
const layersTabRef = useRef<HTMLDivElement>(null);
|
||||
const galleryTabRef = useRef<HTMLDivElement>(null);
|
||||
const timeoutRef = useRef<number | null>(null);
|
||||
|
||||
const layersTabLabel = useMemo(() => {
|
||||
if (activeEntityCount === 0) {
|
||||
@@ -118,23 +98,172 @@ const PanelTabs = memo(() => {
|
||||
return `${t('controlLayers.layer_other')} (${activeEntityCount})`;
|
||||
}, [activeEntityCount, t]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!layersTabRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const getIsOnLayersTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'layers';
|
||||
|
||||
const onDragEnter = () => {
|
||||
// If we are already on the layers tab, do nothing
|
||||
if (getIsOnLayersTab()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Else set the state to active and switch to the layers tab after a timeout
|
||||
setLayersTabDndState('over');
|
||||
timeoutRef.current = window.setTimeout(() => {
|
||||
timeoutRef.current = null;
|
||||
store.dispatch(activeTabCanvasRightPanelChanged('layers'));
|
||||
// When we switch tabs, the other tab should be pending
|
||||
setLayersTabDndState('idle');
|
||||
setGalleryTabDndState('potential');
|
||||
}, 300);
|
||||
};
|
||||
const onDragLeave = () => {
|
||||
// Set the state to idle or pending depending on the current tab
|
||||
if (getIsOnLayersTab()) {
|
||||
setLayersTabDndState('idle');
|
||||
} else {
|
||||
setLayersTabDndState('potential');
|
||||
}
|
||||
// Abort the tab switch if it hasn't happened yet
|
||||
if (timeoutRef.current !== null) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
const onDragStart = () => {
|
||||
// Set the state to pending when a drag starts
|
||||
setLayersTabDndState('potential');
|
||||
};
|
||||
return combine(
|
||||
dropTargetForElements({
|
||||
element: layersTabRef.current,
|
||||
onDragEnter,
|
||||
onDragLeave,
|
||||
}),
|
||||
monitorForElements({
|
||||
canMonitor: ({ source }) => {
|
||||
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
|
||||
return false;
|
||||
}
|
||||
// Only monitor if we are not already on the gallery tab
|
||||
return !getIsOnLayersTab();
|
||||
},
|
||||
onDragStart,
|
||||
}),
|
||||
dropTargetForExternal({
|
||||
element: layersTabRef.current,
|
||||
onDragEnter,
|
||||
onDragLeave,
|
||||
}),
|
||||
monitorForExternal({
|
||||
canMonitor: () => !getIsOnLayersTab(),
|
||||
onDragStart,
|
||||
})
|
||||
);
|
||||
}, [store]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!galleryTabRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const getIsOnGalleryTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'gallery';
|
||||
|
||||
const onDragEnter = () => {
|
||||
// If we are already on the gallery tab, do nothing
|
||||
if (getIsOnGalleryTab()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Else set the state to active and switch to the gallery tab after a timeout
|
||||
setGalleryTabDndState('over');
|
||||
timeoutRef.current = window.setTimeout(() => {
|
||||
timeoutRef.current = null;
|
||||
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
// When we switch tabs, the other tab should be pending
|
||||
setGalleryTabDndState('idle');
|
||||
setLayersTabDndState('potential');
|
||||
}, 300);
|
||||
};
|
||||
|
||||
const onDragLeave = () => {
|
||||
// Set the state to idle or pending depending on the current tab
|
||||
if (getIsOnGalleryTab()) {
|
||||
setGalleryTabDndState('idle');
|
||||
} else {
|
||||
setGalleryTabDndState('potential');
|
||||
}
|
||||
// Abort the tab switch if it hasn't happened yet
|
||||
if (timeoutRef.current !== null) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
|
||||
const onDragStart = () => {
|
||||
// Set the state to pending when a drag starts
|
||||
setGalleryTabDndState('potential');
|
||||
};
|
||||
|
||||
return combine(
|
||||
dropTargetForElements({
|
||||
element: galleryTabRef.current,
|
||||
onDragEnter,
|
||||
onDragLeave,
|
||||
}),
|
||||
monitorForElements({
|
||||
canMonitor: ({ source }) => {
|
||||
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
|
||||
return false;
|
||||
}
|
||||
// Only monitor if we are not already on the gallery tab
|
||||
return !getIsOnGalleryTab();
|
||||
},
|
||||
onDragStart,
|
||||
}),
|
||||
dropTargetForExternal({
|
||||
element: galleryTabRef.current,
|
||||
onDragEnter,
|
||||
onDragLeave,
|
||||
}),
|
||||
monitorForExternal({
|
||||
canMonitor: () => !getIsOnGalleryTab(),
|
||||
onDragStart,
|
||||
})
|
||||
);
|
||||
}, [store]);
|
||||
|
||||
useEffect(() => {
|
||||
const onDrop = () => {
|
||||
// Reset the dnd state when a drop happens
|
||||
setGalleryTabDndState('idle');
|
||||
setLayersTabDndState('idle');
|
||||
};
|
||||
const cleanup = combine(monitorForElements({ onDrop }), monitorForExternal({ onDrop }));
|
||||
|
||||
return () => {
|
||||
cleanup();
|
||||
if (timeoutRef.current !== null) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Tab position="relative" onMouseOver={onOnMouseOverLayersTab} onMouseOut={onMouseOut} w={32}>
|
||||
<Tab ref={layersTabRef} position="relative" w={32}>
|
||||
<Box as="span" w="full">
|
||||
{layersTabLabel}
|
||||
</Box>
|
||||
{dndCtx.active && activeTab !== 'layers' && (
|
||||
<IAIDropOverlay isOver={mouseOverTab === 'layers'} withBackdrop={false} />
|
||||
)}
|
||||
<DndDropOverlay dndState={layersTabDndState} withBackdrop={false} />
|
||||
</Tab>
|
||||
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut} w={32}>
|
||||
<Tab ref={galleryTabRef} position="relative" w={32}>
|
||||
<Box as="span" w="full">
|
||||
{t('gallery.gallery')}
|
||||
</Box>
|
||||
{dndCtx.active && activeTab !== 'gallery' && (
|
||||
<IAIDropOverlay isOver={mouseOverTab === 'gallery'} withBackdrop={false} />
|
||||
)}
|
||||
<DndDropOverlay dndState={galleryTabDndState} withBackdrop={false} />
|
||||
</Tab>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { ControlLayerBadges } from 'features/controlLayers/components/ControlLayer/ControlLayerBadges';
|
||||
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
|
||||
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
|
||||
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
|
||||
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -21,14 +23,16 @@ type Props = {
|
||||
|
||||
export const ControlLayer = memo(({ id }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const entityIdentifier = useMemo<CanvasEntityIdentifier<'control_layer'>>(
|
||||
() => ({ id, type: 'control_layer' }),
|
||||
[id]
|
||||
);
|
||||
const dropData = useMemo<ReplaceLayerImageDropData>(
|
||||
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
|
||||
[id, entityIdentifier]
|
||||
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
|
||||
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
|
||||
[entityIdentifier]
|
||||
);
|
||||
|
||||
return (
|
||||
<EntityIdentifierContext.Provider value={entityIdentifier}>
|
||||
<ControlLayerAdapterGate>
|
||||
@@ -41,9 +45,14 @@ export const ControlLayer = memo(({ id }: Props) => {
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<ControlLayerControlAdapter />
|
||||
<ControlLayerSettings />
|
||||
</CanvasEntitySettingsWrapper>
|
||||
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
|
||||
<DndDropTarget
|
||||
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
label={t('controlLayers.replaceLayer')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</CanvasEntityContainer>
|
||||
</ControlLayerAdapterGate>
|
||||
</EntityIdentifierContext.Provider>
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import { Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
|
||||
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
|
||||
import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
@@ -16,13 +18,15 @@ import {
|
||||
controlLayerModelChanged,
|
||||
controlLayerWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { getFilterForModel } from 'features/controlLayers/store/filters';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
|
||||
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
|
||||
const selectControlAdapter = useMemo(
|
||||
@@ -39,11 +43,12 @@ const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<
|
||||
|
||||
export const ControlLayerControlAdapter = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const entityIdentifier = useEntityIdentifierContext('control_layer');
|
||||
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
|
||||
const filter = useEntityFilter(entityIdentifier);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const adapter = useEntityAdapterContext('control_layer');
|
||||
|
||||
const onChangeBeginEndStepPct = useCallback(
|
||||
(beginEndStepPct: [number, number]) => {
|
||||
@@ -69,17 +74,58 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
|
||||
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
|
||||
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
|
||||
// filter config.
|
||||
const isFiltering = adapter.filterer.$isFiltering.get();
|
||||
const isSimple = adapter.filterer.$simple.get();
|
||||
// If we are filtering and _not_ in simple mode, that means the user has clicked Advanced. They want to be in control
|
||||
// of the settings. Bail early without doing anything else.
|
||||
if (isFiltering && !isSimple) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Else, we are in simple mode and will take care of some things for the user.
|
||||
|
||||
// First, check if the newly-selected model has a default filter. It may not - for example, Tile controlnet models
|
||||
// don't have a default filter.
|
||||
const defaultFilterForNewModel = getFilterForModel(modelConfig);
|
||||
|
||||
if (!defaultFilterForNewModel) {
|
||||
// The user has chosen a model that doesn't have a default filter - cancel any in-progress filtering and bail.
|
||||
if (isFiltering) {
|
||||
adapter.filterer.cancel();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, we know the user has selected a model that has a default filter. We need to either start filtering
|
||||
// with that default filter, or update the existing filter config to match the new model's default filter.
|
||||
const filterConfig = defaultFilterForNewModel.buildDefaults();
|
||||
if (isFiltering) {
|
||||
adapter.filterer.$filterConfig.set(filterConfig);
|
||||
} else {
|
||||
adapter.filterer.start(filterConfig);
|
||||
}
|
||||
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
|
||||
// no-op if auto-processing is already enabled, because the process method is debounced.
|
||||
adapter.filterer.process();
|
||||
},
|
||||
[dispatch, entityIdentifier]
|
||||
[adapter.filterer, dispatch, entityIdentifier]
|
||||
);
|
||||
|
||||
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const postUploadAction = useMemo<PostUploadAction>(
|
||||
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
|
||||
[entityIdentifier]
|
||||
const uploadOptions = useMemo(
|
||||
() =>
|
||||
({
|
||||
onUpload: (imageDTO: ImageDTO) => {
|
||||
replaceCanvasEntityObjectsWithImage({ entityIdentifier, imageDTO, dispatch, getState });
|
||||
},
|
||||
allowMultiple: false,
|
||||
}) as const,
|
||||
[dispatch, entityIdentifier, getState]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ postUploadAction });
|
||||
const uploadApi = useImageUploadButton(uploadOptions);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} position="relative" w="full">
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
|
||||
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
|
||||
import { mapId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.controlLayers.entities.map(mapId).reverse();
|
||||
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.controlLayers.entities.map(getEntityIdentifier).toReversed();
|
||||
});
|
||||
|
||||
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
|
||||
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
|
||||
|
||||
export const ControlLayerEntityList = memo(() => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const layerIds = useAppSelector(selectEntityIds);
|
||||
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
|
||||
|
||||
if (layerIds.length === 0) {
|
||||
if (entityIdentifiers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (layerIds.length > 0) {
|
||||
if (entityIdentifiers.length > 0) {
|
||||
return (
|
||||
<CanvasEntityGroupList type="control_layer" isSelected={isSelected}>
|
||||
{layerIds.map((id) => (
|
||||
<ControlLayer key={id} id={id} />
|
||||
<CanvasEntityGroupList type="control_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
|
||||
{entityIdentifiers.map((entityIdentifier) => (
|
||||
<ControlLayer key={entityIdentifier.id} id={entityIdentifier.id} />
|
||||
))}
|
||||
</CanvasEntityGroupList>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
|
||||
import { ControlLayerSettingsEmptyState } from 'features/controlLayers/components/ControlLayer/ControlLayerSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const ControlLayerSettings = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const isEmpty = useEntityIsEmpty(entityIdentifier);
|
||||
|
||||
if (isEmpty) {
|
||||
return <ControlLayerSettingsEmptyState />;
|
||||
}
|
||||
|
||||
return <ControlLayerControlAdapter />;
|
||||
});
|
||||
|
||||
ControlLayerSettings.displayName = 'ControlLayerSettings';
|
||||
@@ -0,0 +1,53 @@
|
||||
import { Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { Trans } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext('control_layer');
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
replaceCanvasEntityObjectsWithImage({ imageDTO, entityIdentifier, dispatch, getState });
|
||||
},
|
||||
[dispatch, entityIdentifier, getState]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
|
||||
<Text textAlign="center" color="base.300">
|
||||
<Trans
|
||||
i18nKey="controlLayers.controlLayerEmptyState"
|
||||
components={{
|
||||
UploadButton: (
|
||||
<Button
|
||||
isDisabled={isBusy}
|
||||
size="sm"
|
||||
variant="link"
|
||||
color="base.300"
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ControlLayerSettingsEmptyState.displayName = 'ControlLayerSettingsEmptyState';
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
MenuList,
|
||||
Spacer,
|
||||
Spinner,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -28,13 +29,10 @@ import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
const FilterContent = memo(
|
||||
const FilterContentAdvanced = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasImageState = useStore(adapter.filterer.$hasImageState);
|
||||
const autoProcess = useAppSelector(selectAutoProcess);
|
||||
@@ -73,36 +71,8 @@ const FilterContent = memo(
|
||||
adapter.filterer.saveAs('control_layer');
|
||||
}, [adapter.filterer]);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applyFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.apply,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'cancelFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.cancel,
|
||||
options: { enabled: !isProcessing && isCanvasFocused },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
p={4}
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
w={420}
|
||||
h="auto"
|
||||
shadow="dark-lg"
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
<>
|
||||
<Flex w="full" gap={4}>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
@@ -169,12 +139,67 @@ const FilterContent = memo(
|
||||
{t('controlLayers.filter.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContent.displayName = 'FilterContent';
|
||||
FilterContentAdvanced.displayName = 'FilterContentAdvanced';
|
||||
|
||||
const FilterContentSimple = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const { t } = useTranslation();
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasImageState = useStore(adapter.filterer.$hasImageState);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
|
||||
}, [config]);
|
||||
|
||||
const onClickAdvanced = useCallback(() => {
|
||||
adapter.filterer.$simple.set(false);
|
||||
}, [adapter.filterer.$simple]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex w="full" gap={4}>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
</Flex>
|
||||
<Flex flexDir="column" w="full" gap={2} pb={2}>
|
||||
<Text color="base.500" textAlign="center">
|
||||
{t('controlLayers.filter.processingLayerWith', { type: t(`controlLayers.filter.${config.type}.label`) })}
|
||||
</Text>
|
||||
<Text color="base.500" textAlign="center">
|
||||
{t('controlLayers.filter.forMoreControl')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button variant="ghost" onClick={onClickAdvanced}>
|
||||
{t('controlLayers.filter.advanced')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
onClick={adapter.filterer.apply}
|
||||
loadingText={t('controlLayers.filter.apply')}
|
||||
variant="ghost"
|
||||
isDisabled={isProcessing || !isValid || !hasImageState}
|
||||
>
|
||||
{t('controlLayers.filter.apply')}
|
||||
</Button>
|
||||
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
|
||||
{t('controlLayers.filter.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContentSimple.displayName = 'FilterContentSimple';
|
||||
|
||||
export const Filter = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
@@ -182,8 +207,54 @@ export const Filter = () => {
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <FilterContent adapter={adapter} />;
|
||||
};
|
||||
|
||||
Filter.displayName = 'Filter';
|
||||
|
||||
const FilterContent = memo(
|
||||
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
const simplified = useStore(adapter.filterer.$simple);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useFocusRegion('canvas', ref, { focusOnMount: true });
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'applyFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.apply,
|
||||
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'cancelFilter',
|
||||
category: 'canvas',
|
||||
callback: adapter.filterer.cancel,
|
||||
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
|
||||
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
p={4}
|
||||
flexDir="column"
|
||||
gap={4}
|
||||
w={420}
|
||||
h="auto"
|
||||
shadow="dark-lg"
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
{simplified && <FilterContentSimple adapter={adapter} />}
|
||||
{!simplified && <FilterContentAdvanced adapter={adapter} />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FilterContent.displayName = 'FilterContent';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
|
||||
@@ -1,82 +1,80 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
import { useNanoid } from 'common/hooks/useNanoid';
|
||||
import { UploadImageButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import type { ImageDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
type Props = {
|
||||
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
|
||||
image: ImageWithDims | null;
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
droppableData: TypesafeDroppableData;
|
||||
postUploadAction: PostUploadAction;
|
||||
dndTarget: T;
|
||||
dndTargetData: ReturnType<T['getData']>;
|
||||
};
|
||||
|
||||
export const IPAdapterImagePreview = memo(({ image, onChangeImage, droppableData, postUploadAction }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useStore($isConnected);
|
||||
const dndId = useNanoid('ip_adapter_image_preview');
|
||||
export const IPAdapterImagePreview = memo(
|
||||
<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget>({
|
||||
image,
|
||||
onChangeImage,
|
||||
dndTarget,
|
||||
dndTargetData,
|
||||
}: Props<T>) => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useStore($isConnected);
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
|
||||
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
|
||||
image?.image_name ?? skipToken
|
||||
);
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
useEffect(() => {
|
||||
if (isConnected && isError) {
|
||||
handleResetControlImage();
|
||||
}
|
||||
}, [handleResetControlImage, isError, isConnected]);
|
||||
|
||||
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
|
||||
if (controlImage) {
|
||||
return {
|
||||
id: dndId,
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO: controlImage },
|
||||
};
|
||||
}
|
||||
}, [controlImage, dndId]);
|
||||
const onUpload = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
onChangeImage(imageDTO);
|
||||
},
|
||||
[onChangeImage]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected && isErrorControlImage) {
|
||||
handleResetControlImage();
|
||||
}
|
||||
}, [handleResetControlImage, isConnected, isErrorControlImage]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
w="full"
|
||||
h="full"
|
||||
alignItems="center"
|
||||
borderColor="error.500"
|
||||
borderStyle="solid"
|
||||
borderWidth={controlImage ? 0 : 1}
|
||||
borderRadius="base"
|
||||
>
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
|
||||
{controlImage && (
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('common.reset')}
|
||||
return (
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
|
||||
{!imageDTO && (
|
||||
<UploadImageButton
|
||||
w="full"
|
||||
h="full"
|
||||
isError={!imageDTO && !image?.image_name}
|
||||
onUpload={onUpload}
|
||||
fontSize={36}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
)}
|
||||
{imageDTO && (
|
||||
<>
|
||||
<DndImage imageDTO={imageDTO} />
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
tooltip={t('common.reset')}
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
|
||||
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
|
||||
import { mapId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.referenceImages.entities.map(mapId).reverse();
|
||||
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
|
||||
});
|
||||
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
|
||||
return selectedEntityIdentifier?.type === 'reference_image';
|
||||
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
|
||||
|
||||
export const IPAdapterList = memo(() => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const ipaIds = useAppSelector(selectEntityIds);
|
||||
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
|
||||
|
||||
if (ipaIds.length === 0) {
|
||||
if (entityIdentifiers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (ipaIds.length > 0) {
|
||||
if (entityIdentifiers.length > 0) {
|
||||
return (
|
||||
<CanvasEntityGroupList type="reference_image" isSelected={isSelected}>
|
||||
{ipaIds.map((id) => (
|
||||
<IPAdapter key={id} id={id} />
|
||||
<CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
|
||||
{entityIdentifiers.map((entityIdentifiers) => (
|
||||
<IPAdapter key={entityIdentifiers.id} id={entityIdentifiers.id} />
|
||||
))}
|
||||
</CanvasEntityGroupList>
|
||||
);
|
||||
|
||||
@@ -19,11 +19,12 @@ import {
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
|
||||
import type { IPAImageDropData } from 'features/dnd/types';
|
||||
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, IPAdapterModelConfig, IPALayerImagePostUploadAction } from 'services/api/types';
|
||||
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
|
||||
import { IPAdapterModel } from './IPAdapterModel';
|
||||
@@ -80,13 +81,9 @@ export const IPAdapterSettings = memo(() => {
|
||||
[dispatch, entityIdentifier]
|
||||
);
|
||||
|
||||
const droppableData = useMemo<IPAImageDropData>(
|
||||
() => ({ actionType: 'SET_IPA_IMAGE', context: { id: entityIdentifier.id }, id: entityIdentifier.id }),
|
||||
[entityIdentifier.id]
|
||||
);
|
||||
const postUploadAction = useMemo<IPALayerImagePostUploadAction>(
|
||||
() => ({ type: 'SET_IPA_IMAGE', id: entityIdentifier.id }),
|
||||
[entityIdentifier.id]
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }, ipAdapter.image?.image_name),
|
||||
[entityIdentifier, ipAdapter.image?.image_name]
|
||||
);
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
@@ -122,10 +119,10 @@ export const IPAdapterSettings = memo(() => {
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
|
||||
<IPAdapterImagePreview
|
||||
image={ipAdapter.image ?? null}
|
||||
image={ipAdapter.image}
|
||||
onChangeImage={onChangeImage}
|
||||
droppableData={droppableData}
|
||||
postUploadAction={postUploadAction}
|
||||
dndTarget={setGlobalReferenceImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
|
||||
import { InpaintMask } from 'features/controlLayers/components/InpaintMask/InpaintMask';
|
||||
import { mapId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.inpaintMasks.entities.map(mapId).reverse();
|
||||
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.inpaintMasks.entities.map(getEntityIdentifier).toReversed();
|
||||
});
|
||||
|
||||
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
|
||||
@@ -17,17 +17,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
|
||||
|
||||
export const InpaintMaskList = memo(() => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const entityIds = useAppSelector(selectEntityIds);
|
||||
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
|
||||
|
||||
if (entityIds.length === 0) {
|
||||
if (entityIdentifiers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (entityIds.length > 0) {
|
||||
if (entityIdentifiers.length > 0) {
|
||||
return (
|
||||
<CanvasEntityGroupList type="inpaint_mask" isSelected={isSelected}>
|
||||
{entityIds.map((id) => (
|
||||
<InpaintMask key={id} id={id} />
|
||||
<CanvasEntityGroupList type="inpaint_mask" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
|
||||
{entityIdentifiers.map((entityIdentifier) => (
|
||||
<InpaintMask key={entityIdentifier.id} id={entityIdentifier.id} />
|
||||
))}
|
||||
</CanvasEntityGroupList>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
import {
|
||||
Badge,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
useToken,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import WavyLine from 'common/components/WavyLine';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
|
||||
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
|
||||
|
||||
export const ParamDenoisingStrength = memo(() => {
|
||||
const img2imgStrength = useAppSelector(selectImg2imgStrength);
|
||||
const dispatch = useAppDispatch();
|
||||
const isEnabled = useAppSelector(selectIsEnabled);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setImg2imgStrength(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const config = useAppSelector(selectImg2imgStrengthConfig);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
|
||||
<Flex gap={3} alignItems="center">
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
</InformationalPopover>
|
||||
{isEnabled && (
|
||||
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
|
||||
)}
|
||||
</Flex>
|
||||
{isEnabled ? (
|
||||
<>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChange}
|
||||
value={img2imgStrength}
|
||||
variant="outline"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Flex alignItems="center">
|
||||
<Badge opacity="0.6">
|
||||
{t('common.disabled')} - {t('parameters.noRasterLayers')}
|
||||
</Badge>
|
||||
</Flex>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
ParamDenoisingStrength.displayName = 'ParamDenoisingStrength';
|
||||
@@ -1,14 +1,16 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { ReplaceLayerImageDropData } from 'features/dnd/types';
|
||||
import type { ReplaceCanvasEntityObjectsWithImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { replaceCanvasEntityObjectsWithImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -18,10 +20,11 @@ type Props = {
|
||||
|
||||
export const RasterLayer = memo(({ id }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const entityIdentifier = useMemo<CanvasEntityIdentifier<'raster_layer'>>(() => ({ id, type: 'raster_layer' }), [id]);
|
||||
const dropData = useMemo<ReplaceLayerImageDropData>(
|
||||
() => ({ id, actionType: 'REPLACE_LAYER_WITH_IMAGE', context: { entityIdentifier } }),
|
||||
[id, entityIdentifier]
|
||||
const dndTargetData = useMemo<ReplaceCanvasEntityObjectsWithImageDndTargetData>(
|
||||
() => replaceCanvasEntityObjectsWithImageDndTarget.getData({ entityIdentifier }, entityIdentifier.id),
|
||||
[entityIdentifier]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -34,7 +37,12 @@ export const RasterLayer = memo(({ id }: Props) => {
|
||||
<Spacer />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
|
||||
<DndDropTarget
|
||||
dndTarget={replaceCanvasEntityObjectsWithImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
label={t('controlLayers.replaceLayer')}
|
||||
isDisabled={isBusy}
|
||||
/>
|
||||
</CanvasEntityContainer>
|
||||
</RasterLayerAdapterGate>
|
||||
</EntityIdentifierContext.Provider>
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
|
||||
import { RasterLayer } from 'features/controlLayers/components/RasterLayer/RasterLayer';
|
||||
import { mapId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.rasterLayers.entities.map(mapId).reverse();
|
||||
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.rasterLayers.entities.map(getEntityIdentifier).toReversed();
|
||||
});
|
||||
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
|
||||
return selectedEntityIdentifier?.type === 'raster_layer';
|
||||
@@ -16,17 +16,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
|
||||
|
||||
export const RasterLayerEntityList = memo(() => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const layerIds = useAppSelector(selectEntityIds);
|
||||
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
|
||||
|
||||
if (layerIds.length === 0) {
|
||||
if (entityIdentifiers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (layerIds.length > 0) {
|
||||
if (entityIdentifiers.length > 0) {
|
||||
return (
|
||||
<CanvasEntityGroupList type="raster_layer" isSelected={isSelected}>
|
||||
{layerIds.map((id) => (
|
||||
<RasterLayer key={id} id={id} />
|
||||
<CanvasEntityGroupList type="raster_layer" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
|
||||
{entityIdentifiers.map((entityIdentifier) => (
|
||||
<RasterLayer key={entityIdentifier.id} id={entityIdentifier.id} />
|
||||
))}
|
||||
</CanvasEntityGroupList>
|
||||
);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
|
||||
import {
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
rasterLayerConvertedToInpaintMask,
|
||||
rasterLayerConvertedToRegionalGuidance,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiSwapBold } from 'react-icons/pi';
|
||||
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('raster_layer');
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const isLocked = useEntityIsLocked(entityIdentifier);
|
||||
|
||||
@@ -37,10 +37,10 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
|
||||
rasterLayerConvertedToControlLayer({
|
||||
entityIdentifier,
|
||||
replace: true,
|
||||
overrides: { controlAdapter: defaultControlAdapter },
|
||||
overrides: { controlAdapter: deepClone(initialControlNet) },
|
||||
})
|
||||
);
|
||||
}, [defaultControlAdapter, dispatch, entityIdentifier]);
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />} isDisabled={isBusy || isLocked}>
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import {
|
||||
rasterLayerConvertedToControlLayer,
|
||||
rasterLayerConvertedToInpaintMask,
|
||||
rasterLayerConvertedToRegionalGuidance,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('raster_layer');
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const copyToInpaintMask = useCallback(() => {
|
||||
@@ -35,10 +35,10 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
|
||||
dispatch(
|
||||
rasterLayerConvertedToControlLayer({
|
||||
entityIdentifier,
|
||||
overrides: { controlAdapter: defaultControlAdapter },
|
||||
overrides: { controlAdapter: deepClone(initialControlNet) },
|
||||
})
|
||||
);
|
||||
}, [defaultControlAdapter, dispatch, entityIdentifier]);
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />} isDisabled={isBusy}>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/common/CanvasEntityGroupList';
|
||||
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
|
||||
import { RegionalGuidance } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidance';
|
||||
import { mapId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectEntityIds = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.regionalGuidance.entities.map(mapId).reverse();
|
||||
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
return canvas.regionalGuidance.entities.map(getEntityIdentifier).toReversed();
|
||||
});
|
||||
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
|
||||
return selectedEntityIdentifier?.type === 'regional_guidance';
|
||||
@@ -16,17 +16,17 @@ const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selecte
|
||||
|
||||
export const RegionalGuidanceEntityList = memo(() => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const rgIds = useAppSelector(selectEntityIds);
|
||||
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
|
||||
|
||||
if (rgIds.length === 0) {
|
||||
if (entityIdentifiers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (rgIds.length > 0) {
|
||||
if (entityIdentifiers.length > 0) {
|
||||
return (
|
||||
<CanvasEntityGroupList type="regional_guidance" isSelected={isSelected}>
|
||||
{rgIds.map((id) => (
|
||||
<RegionalGuidance key={id} id={id} />
|
||||
<CanvasEntityGroupList type="regional_guidance" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
|
||||
{entityIdentifiers.map((entityIdentifier) => (
|
||||
<RegionalGuidance key={entityIdentifier.id} id={entityIdentifier.id} />
|
||||
))}
|
||||
</CanvasEntityGroupList>
|
||||
);
|
||||
|
||||
@@ -20,11 +20,12 @@ import {
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
|
||||
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
|
||||
import type { RGIPAdapterImageDropData } from 'features/dnd/types';
|
||||
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiTrashSimpleFill } from 'react-icons/pi';
|
||||
import type { ImageDTO, IPAdapterModelConfig, RGIPAdapterImagePostUploadAction } from 'services/api/types';
|
||||
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
@@ -91,18 +92,15 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
|
||||
[dispatch, entityIdentifier, referenceImageId]
|
||||
);
|
||||
|
||||
const droppableData = useMemo<RGIPAdapterImageDropData>(
|
||||
() => ({
|
||||
actionType: 'SET_RG_IP_ADAPTER_IMAGE',
|
||||
context: { id: entityIdentifier.id, referenceImageId: referenceImageId },
|
||||
id: entityIdentifier.id,
|
||||
}),
|
||||
[entityIdentifier.id, referenceImageId]
|
||||
);
|
||||
const postUploadAction = useMemo<RGIPAdapterImagePostUploadAction>(
|
||||
() => ({ type: 'SET_RG_IP_ADAPTER_IMAGE', id: entityIdentifier.id, referenceImageId: referenceImageId }),
|
||||
[entityIdentifier.id, referenceImageId]
|
||||
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
|
||||
() =>
|
||||
setRegionalGuidanceReferenceImageDndTarget.getData(
|
||||
{ entityIdentifier, referenceImageId },
|
||||
ipAdapter.image?.image_name
|
||||
),
|
||||
[entityIdentifier, ipAdapter.image?.image_name, referenceImageId]
|
||||
);
|
||||
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
@@ -151,10 +149,10 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
|
||||
<IPAdapterImagePreview
|
||||
image={ipAdapter.image ?? null}
|
||||
image={ipAdapter.image}
|
||||
onChangeImage={onChangeImage}
|
||||
droppableData={droppableData}
|
||||
postUploadAction={postUploadAction}
|
||||
dndTarget={setRegionalGuidanceReferenceImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsSelected } from 'features/controlLayers/hooks/useEntityIsSelected';
|
||||
import { useEntitySelectionColor } from 'features/controlLayers/hooks/useEntitySelectionColor';
|
||||
import { entitySelected } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const CanvasEntityContainer = memo((props: PropsWithChildren) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const isSelected = useEntityIsSelected(entityIdentifier);
|
||||
const selectionColor = useEntitySelectionColor(entityIdentifier);
|
||||
const onClick = useCallback(() => {
|
||||
if (isSelected) {
|
||||
return;
|
||||
}
|
||||
dispatch(entitySelected({ entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier, isSelected]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
flexDir="column"
|
||||
w="full"
|
||||
bg={isSelected ? 'base.800' : 'base.850'}
|
||||
onClick={onClick}
|
||||
borderInlineStartWidth={5}
|
||||
borderColor={isSelected ? selectionColor : 'base.800'}
|
||||
borderRadius="base"
|
||||
>
|
||||
{props.children}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityContainer.displayName = 'CanvasEntityContainer';
|
||||
@@ -1,90 +0,0 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
|
||||
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
|
||||
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
|
||||
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
|
||||
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
|
||||
import { type CanvasEntityIdentifier, isRenderableEntityType } from 'features/controlLayers/store/types';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
isSelected: boolean;
|
||||
type: CanvasEntityIdentifier['type'];
|
||||
}>;
|
||||
|
||||
const _hover: SystemStyleObject = {
|
||||
opacity: 1,
|
||||
};
|
||||
|
||||
export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props) => {
|
||||
const title = useEntityTypeTitle(type);
|
||||
const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
|
||||
const collapse = useBoolean(true);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex w="full">
|
||||
<Flex
|
||||
flexGrow={1}
|
||||
as={Button}
|
||||
onClick={collapse.toggle}
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
gap={3}
|
||||
variant="unstyled"
|
||||
p={0}
|
||||
h={8}
|
||||
>
|
||||
<Icon
|
||||
boxSize={4}
|
||||
as={PiCaretDownBold}
|
||||
transform={collapse.isTrue ? undefined : 'rotate(-90deg)'}
|
||||
fill={isSelected ? 'base.200' : 'base.500'}
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
/>
|
||||
{informationalPopoverFeature ? (
|
||||
<InformationalPopover feature={informationalPopoverFeature}>
|
||||
<Text
|
||||
fontWeight="semibold"
|
||||
color={isSelected ? 'base.200' : 'base.500'}
|
||||
userSelect="none"
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
</InformationalPopover>
|
||||
) : (
|
||||
<Text
|
||||
fontWeight="semibold"
|
||||
color={isSelected ? 'base.200' : 'base.500'}
|
||||
userSelect="none"
|
||||
transitionProperty="common"
|
||||
transitionDuration="fast"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<Spacer />
|
||||
</Flex>
|
||||
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
|
||||
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
|
||||
<CanvasEntityAddOfTypeButton type={type} />
|
||||
</Flex>
|
||||
<Collapse in={collapse.isTrue}>
|
||||
<Flex flexDir="column" gap={2} pt={2}>
|
||||
{children}
|
||||
</Flex>
|
||||
</Collapse>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityGroupList.displayName = 'CanvasEntityGroupList';
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, chakra, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, chakra, Flex, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { rgbColorToString } from 'common/util/colorCodeTransformers';
|
||||
@@ -86,13 +86,63 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
|
||||
useEffect(updatePreview, [updatePreview, canvasCache, nodeRect, pixelRect]);
|
||||
|
||||
return (
|
||||
<Tooltip label={<TooltipContent canvasRef={canvasRef} />} p={2} closeOnScroll>
|
||||
<Flex
|
||||
position="relative"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
w={CONTAINER_WIDTH_PX}
|
||||
h={CONTAINER_WIDTH_PX}
|
||||
borderRadius="sm"
|
||||
borderWidth={1}
|
||||
bg="base.900"
|
||||
flexShrink={0}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
|
||||
bgSize="5px"
|
||||
/>
|
||||
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
|
||||
|
||||
const TooltipContent = ({ canvasRef }: { canvasRef: React.RefObject<HTMLCanvasElement> }) => {
|
||||
const canvasRef2 = useRef<HTMLCanvasElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!canvasRef2.current || !canvasRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ctx = canvasRef2.current.getContext('2d');
|
||||
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
canvasRef2.current.width = canvasRef.current.width;
|
||||
canvasRef2.current.height = canvasRef.current.height;
|
||||
ctx.clearRect(0, 0, canvasRef2.current.width, canvasRef2.current.height);
|
||||
ctx.drawImage(canvasRef.current, 0, 0);
|
||||
}, [canvasRef]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
w={CONTAINER_WIDTH_PX}
|
||||
h={CONTAINER_WIDTH_PX}
|
||||
w={150}
|
||||
h={150}
|
||||
borderRadius="sm"
|
||||
borderWidth={1}
|
||||
bg="base.900"
|
||||
@@ -105,11 +155,9 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
bottom={0}
|
||||
left={0}
|
||||
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
|
||||
bgSize="5px"
|
||||
bgSize="8px"
|
||||
/>
|
||||
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
|
||||
<ChakraCanvas position="relative" ref={canvasRef2} objectFit="contain" maxW="full" maxH="full" />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
|
||||
};
|
||||
|
||||
@@ -4,9 +4,10 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
|
||||
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import type { CanvasEntityIdentifier, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useMemo, useSyncExternalStore } from 'react';
|
||||
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const EntityAdapterContext = createContext<
|
||||
@@ -95,6 +96,17 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
|
||||
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
|
||||
});
|
||||
|
||||
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
|
||||
type?: T
|
||||
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
|
||||
const adapter = useContext(EntityAdapterContext);
|
||||
assert(adapter, 'useEntityIdentifier must be used within a EntityIdentifierProvider');
|
||||
if (type) {
|
||||
assert(adapter.entityIdentifier.type === type, 'useEntityIdentifier must be used with the correct type');
|
||||
}
|
||||
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
|
||||
};
|
||||
|
||||
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';
|
||||
|
||||
export const useEntityAdapterSafe = (
|
||||
|
||||
@@ -2,11 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import {
|
||||
bboxChangedFromCanvas,
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
rasterLayerAdded,
|
||||
@@ -17,38 +14,22 @@ import {
|
||||
rgPositivePromptChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
selectBboxModelBase,
|
||||
selectBboxRect,
|
||||
selectCanvasSlice,
|
||||
selectEntityOrThrow,
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasControlLayerState,
|
||||
CanvasEntityIdentifier,
|
||||
CanvasInpaintMaskState,
|
||||
CanvasRasterLayerState,
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import {
|
||||
imageDTOToImageObject,
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlLayers/store/util';
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { useCallback } from 'react';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/** @knipignore */
|
||||
export const selectDefaultControlAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
@@ -85,6 +66,9 @@ export const selectDefaultIPAdapter = createSelector(
|
||||
const ipAdapter = deepClone(initialIPAdapter);
|
||||
if (model) {
|
||||
ipAdapter.model = zModelIdentifierField.parse(model);
|
||||
if (model.base === 'flux') {
|
||||
ipAdapter.clipVisionModel = 'ViT-L';
|
||||
}
|
||||
}
|
||||
return ipAdapter;
|
||||
}
|
||||
@@ -92,11 +76,10 @@ export const selectDefaultIPAdapter = createSelector(
|
||||
|
||||
export const useAddControlLayer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const func = useCallback(() => {
|
||||
const overrides = { controlAdapter: defaultControlAdapter };
|
||||
const overrides = { controlAdapter: deepClone(initialControlNet) };
|
||||
dispatch(controlLayerAdded({ isSelected: true, overrides }));
|
||||
}, [defaultControlAdapter, dispatch]);
|
||||
}, [dispatch]);
|
||||
|
||||
return func;
|
||||
};
|
||||
@@ -110,150 +93,6 @@ export const useAddRasterLayer = () => {
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useNewRasterLayerFromImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const func = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
},
|
||||
[bboxRect.x, bboxRect.y, dispatch]
|
||||
);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useNewControlLayerFromImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const func = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
},
|
||||
[bboxRect.x, bboxRect.y, dispatch]
|
||||
);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useNewInpaintMaskFromImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const func = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const overrides: Partial<CanvasInpaintMaskState> = {
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
dispatch(inpaintMaskAdded({ overrides, isSelected: true }));
|
||||
},
|
||||
[bboxRect.x, bboxRect.y, dispatch]
|
||||
);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useNewRegionalGuidanceFromImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const func = useCallback(
|
||||
(imageDTO: ImageDTO) => {
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
const overrides: Partial<CanvasRegionalGuidanceState> = {
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
dispatch(rgAdded({ overrides, isSelected: true }));
|
||||
},
|
||||
[bboxRect.x, bboxRect.y, dispatch]
|
||||
);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
|
||||
* - Reset the canvas
|
||||
* - Resize the bbox to the image's aspect ratio at the optimal size for the selected model
|
||||
* - Add the image as a raster layer
|
||||
* - Resizes the layer to fit the bbox using the 'fill' strategy
|
||||
*
|
||||
* This allows the user to immediately generate a new image from the given image without any additional steps.
|
||||
*/
|
||||
export const useNewCanvasFromImage = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const base = useAppSelector(selectBboxModelBase);
|
||||
const func = useCallback(
|
||||
(imageDTO: ImageDTO, type: CanvasRasterLayerState['type'] | CanvasControlLayerState['type']) => {
|
||||
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
|
||||
const ratio = imageDTO.width / imageDTO.height;
|
||||
const optimalDimension = getOptimalDimension(base);
|
||||
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
|
||||
|
||||
// The overrides need to include the layer's ID so we can transform the layer it is initialized
|
||||
let overrides: Partial<CanvasRasterLayerState> | Partial<CanvasControlLayerState>;
|
||||
|
||||
if (type === 'raster_layer') {
|
||||
overrides = {
|
||||
id: getPrefixedId('raster_layer'),
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageDTOToImageObject(imageDTO)],
|
||||
} satisfies Partial<CanvasRasterLayerState>;
|
||||
} else if (type === 'control_layer') {
|
||||
overrides = {
|
||||
id: getPrefixedId('control_layer'),
|
||||
position: { x: bboxRect.x, y: bboxRect.y },
|
||||
objects: [imageDTOToImageObject(imageDTO)],
|
||||
} satisfies Partial<CanvasControlLayerState>;
|
||||
} else {
|
||||
// Catch unhandled types
|
||||
assert<Equals<typeof type, never>>(false);
|
||||
}
|
||||
|
||||
CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
|
||||
// Skip the callback if the adapter is not the one we are creating
|
||||
if (adapter.id !== overrides.id) {
|
||||
return false;
|
||||
}
|
||||
// Fit the layer to the bbox w/ fill strategy
|
||||
await adapter.transformer.startTransform({ silent: true });
|
||||
adapter.transformer.fitToBboxFill();
|
||||
await adapter.transformer.applyTransform();
|
||||
return true;
|
||||
});
|
||||
|
||||
dispatch(canvasReset());
|
||||
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
|
||||
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
|
||||
|
||||
// The type casts are safe because the type is checked above
|
||||
if (type === 'raster_layer') {
|
||||
dispatch(rasterLayerAdded({ overrides: overrides as Partial<CanvasRasterLayerState>, isSelected: true }));
|
||||
} else if (type === 'control_layer') {
|
||||
dispatch(controlLayerAdded({ overrides: overrides as Partial<CanvasControlLayerState>, isSelected: true }));
|
||||
} else {
|
||||
// Catch unhandled types
|
||||
assert<Equals<typeof type, never>>(false);
|
||||
}
|
||||
},
|
||||
[base, bboxRect.x, bboxRect.y, dispatch]
|
||||
);
|
||||
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useAddInpaintMask = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const func = useCallback(() => {
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
@@ -25,12 +24,13 @@ import type {
|
||||
Rect,
|
||||
RegionalGuidanceReferenceImageState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
@@ -64,7 +64,7 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
|
||||
return;
|
||||
}
|
||||
|
||||
let metadata: SerializableObject | undefined = undefined;
|
||||
let metadata: JsonObject | undefined = undefined;
|
||||
|
||||
if (withMetadata) {
|
||||
metadata = selectCanvasMetadata(store.getState());
|
||||
@@ -72,10 +72,16 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
|
||||
|
||||
const result = await withResultAsync(() => {
|
||||
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
return canvasManager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
|
||||
is_intermediate: !saveToGallery,
|
||||
metadata,
|
||||
});
|
||||
return canvasManager.compositor.getCompositeImageDTO(
|
||||
rasterAdapters,
|
||||
rect,
|
||||
{
|
||||
is_intermediate: !saveToGallery,
|
||||
metadata,
|
||||
},
|
||||
undefined,
|
||||
true // force upload the image to ensure it gets added to the gallery
|
||||
);
|
||||
});
|
||||
|
||||
if (result.isOk()) {
|
||||
@@ -223,13 +229,12 @@ export const useNewRasterLayerFromBbox = () => {
|
||||
export const useNewControlLayerFromBbox = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
|
||||
const arg = useMemo<UseSaveCanvasArg>(() => {
|
||||
const onSave = (imageDTO: ImageDTO, rect: Rect) => {
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageDTOToImageObject(imageDTO)],
|
||||
controlAdapter: deepClone(defaultControlAdapter),
|
||||
controlAdapter: deepClone(initialControlNet),
|
||||
position: { x: rect.x, y: rect.y },
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
@@ -242,7 +247,7 @@ export const useNewControlLayerFromBbox = () => {
|
||||
toastOk: t('controlLayers.newControlLayerOk'),
|
||||
toastError: t('controlLayers.newControlLayerError'),
|
||||
};
|
||||
}, [defaultControlAdapter, dispatch, t]);
|
||||
}, [dispatch, t]);
|
||||
const func = useSaveCanvas(arg);
|
||||
return func;
|
||||
};
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { rgbColorToString } from 'common/util/colorCodeTransformers';
|
||||
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useEntitySelectionColor = (entityIdentifier: CanvasEntityIdentifier) => {
|
||||
const selectSelectionColor = useMemo(
|
||||
() =>
|
||||
createSelector(selectCanvasSlice, (canvas) => {
|
||||
const entity = selectEntity(canvas, entityIdentifier);
|
||||
if (!entity) {
|
||||
return 'base.400';
|
||||
} else if (entity.type === 'inpaint_mask') {
|
||||
return rgbColorToString(entity.fill.color);
|
||||
} else if (entity.type === 'regional_guidance') {
|
||||
return rgbColorToString(entity.fill.color);
|
||||
} else {
|
||||
return 'base.400';
|
||||
}
|
||||
}),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const selectionColor = useAppSelector(selectSelectionColor);
|
||||
return selectionColor;
|
||||
};
|
||||
@@ -5,14 +5,10 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
|
||||
import { canvasToBlob } from 'features/controlLayers/konva/util';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import { uploadImage } from 'services/api/endpoints/images';
|
||||
|
||||
export const useSaveLayerToAssets = () => {
|
||||
const { t } = useTranslation();
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
|
||||
const saveLayerToAssets = useCallback(
|
||||
@@ -27,30 +23,17 @@ export const useSaveLayerToAssets = () => {
|
||||
if (!adapter) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const canvas = adapter.getCanvas();
|
||||
const blob = await canvasToBlob(canvas);
|
||||
const file = new File([blob], `layer-${adapter.id}.png`, { type: 'image/png' });
|
||||
await uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
postUploadAction: { type: 'TOAST' },
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
});
|
||||
|
||||
toast({
|
||||
status: 'info',
|
||||
title: t('toast.layerSavedToAssets'),
|
||||
});
|
||||
} catch (error) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('toast.problemSavingLayer'),
|
||||
});
|
||||
}
|
||||
const canvas = adapter.getCanvas();
|
||||
const blob = await canvasToBlob(canvas);
|
||||
const file = new File([blob], `layer-${adapter.id}.png`, { type: 'image/png' });
|
||||
uploadImage({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
});
|
||||
},
|
||||
[t, autoAddBoardId, uploadImage]
|
||||
[autoAddBoardId]
|
||||
);
|
||||
|
||||
return saveLayerToAssets;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
|
||||
import type { CanvasEntityAdapter, CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
@@ -35,12 +34,13 @@ import { t } from 'i18next';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { UploadOptions } from 'services/api/endpoints/images';
|
||||
import type { UploadImageArg } from 'services/api/endpoints/images';
|
||||
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import stableHash from 'stable-hash';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
type CompositingOptions = {
|
||||
/**
|
||||
@@ -173,14 +173,14 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
return adapters as CanvasEntityAdapterFromType<T>[];
|
||||
};
|
||||
|
||||
getCompositeHash = (adapters: CanvasEntityAdapter[], extra: SerializableObject): string => {
|
||||
const adapterHashes: SerializableObject[] = [];
|
||||
getCompositeHash = (adapters: CanvasEntityAdapter[], extra: JsonObject): string => {
|
||||
const adapterHashes: JsonObject[] = [];
|
||||
|
||||
for (const adapter of adapters) {
|
||||
adapterHashes.push(adapter.getHashableState());
|
||||
}
|
||||
|
||||
const data: SerializableObject = {
|
||||
const data: JsonObject = {
|
||||
extra,
|
||||
adapterHashes,
|
||||
};
|
||||
@@ -253,18 +253,20 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
* @param rect The region to include in the rasterized image
|
||||
* @param uploadOptions Options for uploading the image
|
||||
* @param compositingOptions Options for compositing the entities
|
||||
* @param forceUpload If true, the image is always re-uploaded, returning a new image DTO
|
||||
* @returns A promise that resolves to the image DTO
|
||||
*/
|
||||
getCompositeImageDTO = async (
|
||||
adapters: CanvasEntityAdapter[],
|
||||
rect: Rect,
|
||||
uploadOptions: Pick<UploadOptions, 'is_intermediate' | 'metadata'>,
|
||||
compositingOptions?: CompositingOptions
|
||||
uploadOptions: Pick<UploadImageArg, 'is_intermediate' | 'metadata'>,
|
||||
compositingOptions?: CompositingOptions,
|
||||
forceUpload?: boolean
|
||||
): Promise<ImageDTO> => {
|
||||
assert(rect.width > 0 && rect.height > 0, 'Unable to rasterize empty rect');
|
||||
|
||||
const hash = this.getCompositeHash(adapters, { rect });
|
||||
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
|
||||
const cachedImageName = forceUpload ? undefined : this.manager.cache.imageNameCache.get(hash);
|
||||
|
||||
let imageDTO: ImageDTO | null = null;
|
||||
|
||||
@@ -295,12 +297,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
this.$isUploading.set(true);
|
||||
const uploadResult = await withResultAsync(() =>
|
||||
uploadImage({
|
||||
blob,
|
||||
fileName: 'canvas-composite.png',
|
||||
file: new File([blob], 'canvas-composite.png', { type: 'image/png' }),
|
||||
image_category: 'general',
|
||||
is_intermediate: uploadOptions.is_intermediate,
|
||||
board_id: uploadOptions.is_intermediate ? undefined : selectAutoAddBoardId(this.manager.store.getState()),
|
||||
metadata: uploadOptions.metadata,
|
||||
withToast: false,
|
||||
})
|
||||
);
|
||||
this.$isUploading.set(false);
|
||||
@@ -327,6 +329,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
entityIdentifiers: T[],
|
||||
deleteMergedEntities: boolean
|
||||
): Promise<ImageDTO | null> => {
|
||||
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergingLayers'), withCount: false });
|
||||
if (entityIdentifiers.length <= 1) {
|
||||
this.log.warn({ entityIdentifiers }, 'Cannot merge less than 2 entities');
|
||||
return null;
|
||||
@@ -349,7 +352,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
|
||||
if (result.isErr()) {
|
||||
this.log.error({ error: serializeError(result.error) }, 'Failed to merge selected entities');
|
||||
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
|
||||
toast({
|
||||
id: 'MERGE_LAYERS_TOAST',
|
||||
title: t('controlLayers.mergeVisibleError'),
|
||||
status: 'error',
|
||||
withCount: false,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -381,7 +389,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
assert<Equals<typeof type, never>>(false, 'Unsupported type for merge');
|
||||
}
|
||||
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergeVisibleOk'), status: 'success', withCount: false });
|
||||
|
||||
return result.value;
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
|
||||
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
|
||||
@@ -37,6 +36,7 @@ import type { Logger } from 'roarr';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import stableHash from 'stable-hash';
|
||||
import { assert } from 'tsafe';
|
||||
import type { Jsonifiable, JsonObject } from 'type-fest';
|
||||
|
||||
// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
|
||||
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
|
||||
@@ -111,7 +111,7 @@ export abstract class CanvasEntityAdapterBase<
|
||||
*
|
||||
* This is used for caching.
|
||||
*/
|
||||
abstract getHashableState: () => SerializableObject;
|
||||
abstract getHashableState: () => JsonObject;
|
||||
|
||||
/**
|
||||
* Callbacks that are executed when the module is initialized.
|
||||
@@ -566,7 +566,7 @@ export abstract class CanvasEntityAdapterBase<
|
||||
* Gets a hash of the entity's state, as provided by `getHashableState`. If `extra` is provided, it will be included in
|
||||
* the hash.
|
||||
*/
|
||||
hash = (extra?: SerializableObject): string => {
|
||||
hash = (extra?: Jsonifiable): string => {
|
||||
const arg = {
|
||||
state: this.getHashableState(),
|
||||
extra,
|
||||
@@ -614,8 +614,8 @@ export abstract class CanvasEntityAdapterBase<
|
||||
transformer: this.transformer.repr(),
|
||||
renderer: this.renderer.repr(),
|
||||
bufferRenderer: this.bufferRenderer.repr(),
|
||||
segmentAnything: this.segmentAnything?.repr(),
|
||||
filterer: this.filterer?.repr(),
|
||||
segmentAnything: this.segmentAnything?.repr() ?? null,
|
||||
filterer: this.filterer?.repr() ?? null,
|
||||
hasCache: this.$canvasCache.get() !== null,
|
||||
isLocked: this.$isLocked.get(),
|
||||
isDisabled: this.$isDisabled.get(),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
|
||||
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
|
||||
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
|
||||
@@ -9,6 +8,7 @@ import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/Canvas
|
||||
import type { CanvasControlLayerState, CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
|
||||
import type { GroupConfig } from 'konva/lib/Group';
|
||||
import { omit } from 'lodash-es';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
|
||||
CanvasControlLayerState,
|
||||
@@ -77,7 +77,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
|
||||
return canvas;
|
||||
};
|
||||
|
||||
getHashableState = (): SerializableObject => {
|
||||
getHashableState = (): JsonObject => {
|
||||
const keysToOmit: (keyof CanvasControlLayerState)[] = [
|
||||
'name',
|
||||
'controlAdapter',
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
|
||||
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
|
||||
import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
|
||||
@@ -7,6 +6,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import type { CanvasEntityIdentifier, CanvasInpaintMaskState, Rect } from 'features/controlLayers/store/types';
|
||||
import type { GroupConfig } from 'konva/lib/Group';
|
||||
import { omit } from 'lodash-es';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
|
||||
CanvasInpaintMaskState,
|
||||
@@ -69,7 +69,7 @@ export class CanvasEntityAdapterInpaintMask extends CanvasEntityAdapterBase<
|
||||
}
|
||||
};
|
||||
|
||||
getHashableState = (): SerializableObject => {
|
||||
getHashableState = (): JsonObject => {
|
||||
const keysToOmit: (keyof CanvasInpaintMaskState)[] = ['fill', 'name', 'opacity', 'isLocked'];
|
||||
return omit(this.state, keysToOmit);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
|
||||
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
|
||||
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
|
||||
@@ -9,6 +8,7 @@ import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/Canvas
|
||||
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
|
||||
import type { GroupConfig } from 'konva/lib/Group';
|
||||
import { omit } from 'lodash-es';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
|
||||
CanvasRasterLayerState,
|
||||
@@ -70,7 +70,7 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
|
||||
return canvas;
|
||||
};
|
||||
|
||||
getHashableState = (): SerializableObject => {
|
||||
getHashableState = (): JsonObject => {
|
||||
const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked'];
|
||||
return omit(this.state, keysToOmit);
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user