Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-20 10:58:06 -05:00)

Compare commits: ebr/pin-py ... brandon/sd (16 commits)

596dc6b4e3
b4f93a0ff5
f4be52abb4
4e029331ba
40b4de5f77
c930807881
bec8b27429
ef4f466ccf
3c869ee5ab
16dc30fd5b
0c14192819
36dadba45b
f2a9c01d0e
1ca57ade4d
85c0e0db1e
59a2388585
@@ -808,11 +808,7 @@ def get_is_installed(
     for model in installed_models:
         if model.source == starter_model.source:
             return True
-        if (
-            (model.name == starter_model.name or model.name in starter_model.previous_names)
-            and model.base == starter_model.base
-            and model.type == starter_model.type
-        ):
+        if model.name == starter_model.name and model.base == starter_model.base and model.type == starter_model.type:
            return True
     return False
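Editorial note (not part of the diff): a minimal sketch of how the two membership checks above differ, using simplified stand-ins for the real model records:

```python
# Simplified stand-ins for the real installed-model and starter-model records.
from dataclasses import dataclass, field


@dataclass
class ModelRecord:
    name: str
    base: str
    type: str
    previous_names: list[str] = field(default_factory=list)


installed = ModelRecord(name="canny", base="sd-1", type="controlnet")
starter = ModelRecord(name="Hard Edge Detection (canny)", base="sd-1", type="controlnet", previous_names=["canny"])

# Removed check: a renamed starter model still matches an install recorded under an old name.
old_match = (
    (installed.name == starter.name or installed.name in starter.previous_names)
    and installed.base == starter.base
    and installed.type == starter.type
)

# Added check: only an exact name match counts.
new_match = installed.name == starter.name and installed.base == starter.base and installed.type == starter.type

print(old_match, new_match)  # True False
```

With `previous_names` gone, installs recorded under a starter model's old name no longer register as installed.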
@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     # region Model Field Types
     MainModel = "MainModelField"
     FluxMainModel = "FluxMainModelField"
+    SD3MainModel = "SD3MainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
@@ -133,6 +134,7 @@ class FieldDescriptions:
     clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
+    mmditx = "MMDiTX"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"
@@ -140,6 +142,7 @@ class FieldDescriptions:
     lora_model = "LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
+    sd3_model = "SD3 model (MMDiTX) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +249,12 @@ class FluxConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")


+class SD3ConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

invokeai/app/invocations/flux_model_loader.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+from typing import Literal
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.util import max_seq_lengths
+from invokeai.backend.model_manager.config import (
+    CheckpointConfigBase,
+    SubModelType,
+)
+
+
+@invocation_output("flux_model_loader_output")
+class FluxModelLoaderOutput(BaseInvocationOutput):
+    """Flux base model loader output"""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(
+        description="The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
+        title="Max Seq Length",
+    )
+
+
+@invocation(
+    "flux_model_loader",
+    title="Flux Main Model",
+    tags=["model", "flux"],
+    category="model",
+    version="1.0.4",
+    classification=Classification.Prototype,
+)
+class FluxModelLoaderInvocation(BaseInvocation):
+    """Loads a flux base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.flux_model,
+        ui_type=UIType.FluxMainModel,
+        input=Input.Direct,
+    )
+
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
+        input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+
+        transformer_config = context.models.get_config(transformer)
+        assert isinstance(transformer_config, CheckpointConfigBase)
+
+        return FluxModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
+            vae=VAEField(vae=vae),
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
+        )
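Editorial note (not part of the diff): the invoke() method above fans one selected model out into per-submodel identifiers via pydantic's model_copy, which returns a new instance with the given fields overridden. A minimal sketch with simplified stand-ins for the real classes:

```python
from enum import Enum

from pydantic import BaseModel


class SubModelType(str, Enum):  # simplified stand-in for the real enum
    Transformer = "transformer"
    VAE = "vae"


class ModelIdentifierField(BaseModel):  # simplified stand-in
    key: str
    submodel_type: SubModelType | None = None


model = ModelIdentifierField(key="flux-dev")
# model_copy(update=...) leaves the original untouched and overrides one field.
transformer = model.model_copy(update={"submodel_type": SubModelType.Transformer})

print(model.submodel_type, transformer.submodel_type)  # None SubModelType.Transformer
```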
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Literal, Optional
+from typing import List, Optional

 from pydantic import BaseModel, Field

@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
-    CheckpointConfigBase,
     ModelType,
     SubModelType,
 )
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
         return ModelIdentifierOutput(model=self.model)


-@invocation_output("flux_model_loader_output")
-class FluxModelLoaderOutput(BaseInvocationOutput):
-    """Flux base model loader output"""
-
-    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
-    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
-    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
-    max_seq_len: Literal[256, 512] = OutputField(
-        description="The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
-        title="Max Seq Length",
-    )
-
-
-@invocation(
-    "flux_model_loader",
-    title="Flux Main Model",
-    tags=["model", "flux"],
-    category="model",
-    version="1.0.4",
-    classification=Classification.Prototype,
-)
-class FluxModelLoaderInvocation(BaseInvocation):
-    """Loads a flux base model, outputting its submodels."""
-
-    model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.flux_model,
-        ui_type=UIType.FluxMainModel,
-        input=Input.Direct,
-    )
-
-    t5_encoder_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
-    )
-
-    clip_embed_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.clip_embed_model,
-        ui_type=UIType.CLIPEmbedModel,
-        input=Input.Direct,
-        title="CLIP Embed",
-    )
-
-    vae_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
-    )
-
-    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
-            if not context.models.exists(key):
-                raise ValueError(f"Unknown model: {key}")
-
-        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
-
-        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
-        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
-
-        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
-        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
-
-        transformer_config = context.models.get_config(transformer)
-        assert isinstance(transformer_config, CheckpointConfigBase)
-
-        return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer, loras=[]),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
-            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
-            vae=VAEField(vae=vae),
-            max_seq_len=max_seq_lengths[transformer_config.config_path],
-        )
-
-
 @invocation(
     "main_model_loader",
     title="Main Model",
@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
     InputField,
     LatentsField,
     OutputField,
+    SD3ConditioningField,
     TensorField,
     UIComponent,
 )
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
         return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))


+@invocation_output("sd3_conditioning_output")
+class SD3ConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single SD3 conditioning tensor"""
+
+    conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
+        return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""

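Editorial note (not part of the diff): the build() classmethod added above is a small convenience so invocations never construct the conditioning field inline. A minimal self-contained sketch (the real class derives from BaseInvocationOutput, simplified here):

```python
from pydantic import BaseModel, Field


class SD3ConditioningField(BaseModel):
    """A conditioning tensor primitive value"""

    conditioning_name: str = Field(description="The name of conditioning tensor")


class SD3ConditioningOutput(BaseModel):  # simplified stand-in for the invocation output
    conditioning: SD3ConditioningField

    @classmethod
    def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
        # Wrap a saved tensor's name in the field type in one call.
        return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))


output = SD3ConditioningOutput.build("cond_tensor_123")
print(output.conditioning.conditioning_name)  # cond_tensor_123
```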
invokeai/app/invocations/sd3_denoise.py (new file, 241 lines)
@@ -0,0 +1,241 @@
+from typing import Callable, Tuple
+
+import torch
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+from tqdm import tqdm
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    SD3ConditioningField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.model import TransformerField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "sd3_denoise",
+    title="SD3 Denoise",
+    tags=["image", "sd3"],
+    category="image",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Run the denoising process with an SD3 model."""
+
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.sd3_model,
+        input=Input.Connection,
+        title="Transformer",
+    )
+    positive_text_conditioning: SD3ConditioningField = InputField(
+        description=FieldDescriptions.positive_cond, input=Input.Connection
+    )
+    negative_text_conditioning: SD3ConditioningField = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection
+    )
+    cfg_scale: float | list[float] = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
+    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
+    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
+    num_steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
+    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = self._run_diffusion(context)
+        latents = latents.detach().to("cpu")
+
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+    def _load_text_conditioning(
+        self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype, device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        sd3_conditioning = cond_data.conditionings[0]
+        assert isinstance(sd3_conditioning, SD3ConditioningInfo)
+        sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
+
+        t5_embeds = sd3_conditioning.t5_embeds
+        if t5_embeds is None:
+            # TODO(ryand): Construct a zero tensor of the correct shape to use as the T5 conditioning.
+            raise NotImplementedError("SD3 inference without T5 conditioning is not yet supported.")
+
+        clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
+        clip_prompt_embeds = torch.nn.functional.pad(
+            clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
+        )
+
+        prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
+        pooled_prompt_embeds = torch.cat(
+            [sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
+        )
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def _get_noise(
+        self,
+        num_samples: int,
+        num_channels_latents: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        seed: int,
+    ) -> torch.Tensor:
+        # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
+        rand_device = "cpu"
+        rand_dtype = torch.float16
+
+        return torch.randn(
+            num_samples,
+            num_channels_latents,
+            int(height) // LATENT_SCALE_FACTOR,
+            int(width) // LATENT_SCALE_FACTOR,
+            device=rand_device,
+            dtype=rand_dtype,
+            generator=torch.Generator(device=rand_device).manual_seed(seed),
+        ).to(device=device, dtype=dtype)
+
+    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
+        """Prepare the CFG scale list.
+
+        Args:
+            num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
+                on the scheduler used (e.g. higher order schedulers).
+
+        Returns:
+            list[float]: The CFG scale to apply at each timestep.
+        """
+        if isinstance(self.cfg_scale, float):
+            cfg_scale = [self.cfg_scale] * num_timesteps
+        elif isinstance(self.cfg_scale, list):
+            assert len(self.cfg_scale) == num_timesteps
+            cfg_scale = self.cfg_scale
+        else:
+            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
+
+        return cfg_scale
+
+    def _run_diffusion(
+        self,
+        context: InvocationContext,
+    ):
+        inference_dtype = TorchDevice.choose_torch_dtype()
+        device = TorchDevice.choose_torch_device()
+
+        # Load/process the conditioning data.
+        # TODO(ryand): Make CFG optional.
+        do_classifier_free_guidance = True
+        pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
+            context, self.positive_text_conditioning.conditioning_name, inference_dtype, device
+        )
+        neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
+            context, self.negative_text_conditioning.conditioning_name, inference_dtype, device
+        )
+        # TODO(ryand): Support both sequential and batched CFG inference.
+        prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
+        pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
+
+        # Prepare the scheduler.
+        scheduler = FlowMatchEulerDiscreteScheduler()
+        scheduler.set_timesteps(num_inference_steps=self.num_steps, device=device)
+        timesteps = scheduler.timesteps
+        assert isinstance(timesteps, torch.Tensor)
+
+        # Prepare the CFG scale list.
+        cfg_scale = self._prepare_cfg_scale(len(timesteps))
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
+        # Generate initial latent noise.
+        num_channels_latents = transformer_info.model.config.in_channels
+        assert isinstance(num_channels_latents, int)
+        noise = self._get_noise(
+            num_samples=1,
+            num_channels_latents=num_channels_latents,
+            height=self.height,
+            width=self.width,
+            dtype=inference_dtype,
+            device=device,
+            seed=self.seed,
+        )
+        latents: torch.Tensor = noise
+
+        total_steps = len(timesteps)
+        step_callback = self._build_step_callback(context)
+
+        step_callback(
+            PipelineIntermediateState(
+                step=0,
+                order=1,
+                total_steps=total_steps,
+                timestep=int(timesteps[0]),
+                latents=latents,
+            ),
+        )
+
+        with transformer_info.model_on_device() as (cached_weights, transformer):
+            assert isinstance(transformer, SD3Transformer2DModel)
+
+            # Denoising loop.
+            for step_idx, t in tqdm(list(enumerate(timesteps))):
+                # Expand the latents if we are doing CFG.
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                # Expand the timestep to match the latent model input.
+                timestep = t.expand(latent_model_input.shape[0])
+
+                noise_pred = transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=prompt_embeds,
+                    pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=None,
+                    return_dict=False,
+                )[0]
+
+                # Apply CFG.
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
+
+                # Compute the previous noisy sample x_t -> x_t-1.
+                latents_dtype = latents.dtype
+                latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
+
+                # TODO(ryand): This MPS dtype handling was copied from diffusers. I haven't tested to see if it's
+                # needed.
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # Some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                step_callback(
+                    PipelineIntermediateState(
+                        step=step_idx + 1,
+                        order=1,
+                        total_steps=total_steps,
+                        timestep=int(t),
+                        latents=latents,
+                    ),
+                )
+
+        return latents
+
+    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+        def step_callback(state: PipelineIntermediateState) -> None: ...
+
+        return step_callback
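Editorial note (not part of the diff): two of the numeric steps above, sketched in isolation. The embedding dimensions (CLIP-L 768, CLIP-G 1280, T5 4096) are illustrative assumptions typical of SD3 checkpoints, not values taken from the diff:

```python
import torch

# 1. Conditioning assembly (_load_text_conditioning): concat CLIP-L/CLIP-G on the
#    feature dim, zero-pad to the T5 width, then concat with T5 on the sequence dim.
clip_l = torch.zeros(1, 77, 768)
clip_g = torch.zeros(1, 77, 1280)
t5 = torch.zeros(1, 256, 4096)

clip = torch.cat([clip_l, clip_g], dim=-1)  # (1, 77, 2048)
clip = torch.nn.functional.pad(clip, (0, t5.shape[-1] - clip.shape[-1]))  # (1, 77, 4096)
prompt_embeds = torch.cat([clip, t5], dim=-2)
print(prompt_embeds.shape)  # torch.Size([1, 333, 4096])

# 2. CFG update (denoising loop): move away from the unconditional prediction
#    toward the conditional one, scaled by the per-step CFG scale.
noise_pred_uncond = torch.tensor([0.1, 0.2])
noise_pred_cond = torch.tensor([0.3, 0.1])
cfg_scale = 7.0
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred)  # tensor([ 1.5000, -0.5000])
```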
invokeai/app/invocations/sd3_model_loader.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+from typing import Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import SubModelType
+
+
+@invocation_output("sd3_model_loader_output")
+class Sd3ModelLoaderOutput(BaseInvocationOutput):
+    """SD3 base model loader output."""
+
+    mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
+    clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
+    clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+
+
+@invocation(
+    "sd3_model_loader",
+    title="SD3 Main Model",
+    tags=["model", "sd3"],
+    category="model",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Sd3ModelLoaderInvocation(BaseInvocation):
+    """Loads an SD3 base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.sd3_model,
+        ui_type=UIType.SD3MainModel,
+        input=Input.Direct,
+    )
+
+    t5_encoder_model: Optional[ModelIdentifierField] = InputField(
+        description=FieldDescriptions.t5_encoder,
+        ui_type=UIType.T5EncoderModel,
+        input=Input.Direct,
+        title="T5 Encoder",
+        default=None,
+    )
+
+    # TODO(brandon): Set up UI updates to support selecting a CLIP L model.
+    # clip_l_model: ModelIdentifierField = InputField(
+    #     description=FieldDescriptions.clip_l_model,
+    #     ui_type=UIType.CLIPEmbedModel,
+    #     input=Input.Direct,
+    #     title="CLIP L Encoder",
+    # )
+
+    # TODO(brandon): Set up UI updates to support selecting a CLIP G model.
+    # clip_g_model: ModelIdentifierField = InputField(
+    #     description=FieldDescriptions.clip_g_model,
+    #     ui_type=UIType.CLIPGModel,
+    #     input=Input.Direct,
+    #     title="CLIP G Encoder",
+    # )
+
+    # TODO(brandon): Set up UI updates to support selecting an SD3 VAE model.
+    # vae_model: ModelIdentifierField = InputField(
+    #     description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE", default=None
+    # )
+
+    def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
+        mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+        tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder_l = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        clip_encoder_g = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+        tokenizer_t5 = (
+            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
+            if self.t5_encoder_model
+            else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
+        )
+        t5_encoder = (
+            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
+            if self.t5_encoder_model
+            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
+        )
+
+        return Sd3ModelLoaderOutput(
+            mmditx=TransformerField(transformer=mmditx, loras=[]),
+            clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
+            clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
+            vae=VAEField(vae=vae),
+        )
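Editorial note (not part of the diff): the T5 fallback above can be read as "use the standalone T5 encoder if supplied, otherwise take Tokenizer3/TextEncoder3 from the base SD3 pipeline, which bundles its own copies." A minimal sketch with a simplified stand-in class:

```python
from pydantic import BaseModel


class ModelIdentifierField(BaseModel):  # simplified stand-in
    key: str
    submodel_type: str | None = None


def t5_submodels(model: ModelIdentifierField, t5_encoder_model: ModelIdentifierField | None):
    # Prefer the standalone T5 model; fall back to the base pipeline's bundled copy.
    source = t5_encoder_model if t5_encoder_model else model
    tokenizer = source.model_copy(update={"submodel_type": "tokenizer_3"})
    encoder = source.model_copy(update={"submodel_type": "text_encoder_3"})
    return tokenizer, encoder


tok, enc = t5_submodels(ModelIdentifierField(key="sd3-medium"), None)
print(tok.key, tok.submodel_type)  # sd3-medium tokenizer_3
```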
invokeai/app/invocations/sd3_text_encoder.py (new file, 196 lines)
@@ -0,0 +1,196 @@
+from contextlib import ExitStack
+from typing import Iterator, Tuple
+
+import torch
+from transformers import (
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+    T5TokenizerFast,
+)
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.primitives import SD3ConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
+
+
+@invocation(
+    "sd3_text_encoder",
+    title="SD3 Text Encoding",
+    tags=["prompt", "conditioning", "sd3"],
+    category="conditioning",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Sd3TextEncoderInvocation(BaseInvocation):
+    """Encodes and preps a prompt for an SD3 image."""
+
+    clip_l: CLIPField = InputField(
+        title="CLIP L",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    clip_g: CLIPField = InputField(
+        title="CLIP G",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+
+    # The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
+    t5_encoder: T5EncoderField | None = InputField(
+        title="T5Encoder",
+        description=FieldDescriptions.t5_encoder,
+        input=Input.Connection,
+    )
+    prompt: str = InputField(description="Text prompt to encode.")
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
+        # Note: The text encoding models are run in separate functions to ensure that all model references are locally
+        # scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
+
+        clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
+        clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
+
+        t5_max_seq_len = 256
+        t5_embeddings: torch.Tensor | None = None
+        if self.t5_encoder is not None:
+            t5_embeddings = self._t5_encode(context, t5_max_seq_len)
+
+        conditioning_data = ConditioningFieldData(
+            conditionings=[
+                SD3ConditioningInfo(
+                    clip_l_embeds=clip_l_embeddings,
+                    clip_l_pooled_embeds=clip_l_pooled_embeddings,
+                    clip_g_embeds=clip_g_embeddings,
+                    clip_g_pooled_embeds=clip_g_pooled_embeddings,
+                    t5_embeds=t5_embeddings,
+                )
+            ]
+        )
+
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return SD3ConditioningOutput.build(conditioning_name)
+
+    def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
+        assert self.t5_encoder is not None
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+
+        prompt = [self.prompt]
+
+        with (
+            t5_text_encoder_info as t5_text_encoder,
+            t5_tokenizer_info as t5_tokenizer,
+        ):
+            assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
+
+            text_inputs = t5_tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=max_seq_len,
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            assert isinstance(text_input_ids, torch.Tensor)
+            assert isinstance(untruncated_ids, torch.Tensor)
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
+                context.logger.warning(
+                    "The following part of your input was truncated because `max_sequence_length` is set to "
+                    f"{max_seq_len} tokens: {removed_text}"
+                )
+
+            prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
+
+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(
+        self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        clip_tokenizer_info = context.models.load(clip_model.tokenizer)
+        clip_text_encoder_info = context.models.load(clip_model.text_encoder)
+
+        prompt = [self.prompt]
+
+        with (
+            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
+            clip_tokenizer_info as clip_tokenizer,
+            ExitStack() as exit_stack,
+        ):
+            assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+
+            clip_text_encoder_config = clip_text_encoder_info.config
+            assert clip_text_encoder_config is not None
+
+            # Apply LoRA models to the CLIP encoder.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=clip_text_encoder,
+                        patches=self._clip_lora_iterator(context, clip_model),
+                        prefix=FLUX_LORA_CLIP_PREFIX,
+                        cached_weights=cached_weights,
+                    )
+                )
+            else:
+                # There are currently no supported CLIP quantized models. Add support here if needed.
+                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
+
+            clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
+
+            text_inputs = clip_tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            assert isinstance(text_input_ids, torch.Tensor)
+            assert isinstance(untruncated_ids, torch.Tensor)
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
+                context.logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer_max_length} tokens: {removed_text}"
+                )
+            prompt_embeds = clip_text_encoder(
+                input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
+            )
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def _clip_lora_iterator(
+        self, context: InvocationContext, clip_model: CLIPField
+    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in clip_model.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
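Editorial note (not part of the diff): both encoder helpers above use the same truncation-warning idiom, tokenizing once truncated to the max length and once with "longest" padding. A minimal sketch of the check in isolation, with illustrative token IDs:

```python
import torch


def was_truncated(text_input_ids: torch.Tensor, untruncated_ids: torch.Tensor) -> bool:
    # True if the "longest"-padded tokenization is longer than (and differs from)
    # the max_length-truncated one, i.e. tokens were dropped.
    return untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
        text_input_ids, untruncated_ids
    )


truncated = torch.tensor([[49406, 320, 1125, 49407]])       # truncated to max_length=4 (illustrative IDs)
full = torch.tensor([[49406, 320, 1125, 539, 320, 49407]])  # "longest" padding, 6 tokens
print(was_truncated(truncated, full))  # True
```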
@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
     Any = "any"
     StableDiffusion1 = "sd-1"
     StableDiffusion2 = "sd-2"
+    StableDiffusion3 = "sd-3"
     StableDiffusionXL = "sdxl"
     StableDiffusionXLRefiner = "sdxl-refiner"
     Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
     Transformer = "transformer"
     TextEncoder = "text_encoder"
     TextEncoder2 = "text_encoder_2"
+    TextEncoder3 = "text_encoder_3"
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
+    Tokenizer3 = "tokenizer_3"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
@@ -147,6 +150,11 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]


+class SubmodelDefinition(BaseModel):
+    path_or_prefix: str
+    model_type: ModelType
+
+
 class MainModelDefaultSettings(BaseModel):
     vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
     vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +201,9 @@ class ModelConfigBase(BaseModel):
         schema["required"].extend(["key", "type", "format"])

     model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
+    submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
+        description="Loadable submodels in this model", default=None
+    )


 class CheckpointConfigBase(ModelConfigBase):
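Editorial note (not part of the diff): a minimal sketch of what a populated `submodels` mapping might look like for an SD3 diffusers folder. The paths and the pared-down enums are illustrative assumptions:

```python
from enum import Enum

from pydantic import BaseModel


class ModelType(str, Enum):  # simplified stand-ins for the real enums
    Main = "main"
    VAE = "vae"
    CLIPEmbed = "clip_embed"


class SubModelType(str, Enum):
    Transformer = "transformer"
    VAE = "vae"
    TextEncoder = "text_encoder"


class SubmodelDefinition(BaseModel):
    path_or_prefix: str
    model_type: ModelType


submodels = {
    SubModelType.Transformer: SubmodelDefinition(path_or_prefix="/models/sd3/transformer", model_type=ModelType.Main),
    SubModelType.VAE: SubmodelDefinition(path_or_prefix="/models/sd3/vae", model_type=ModelType.VAE),
    SubModelType.TextEncoder: SubmodelDefinition(path_or_prefix="/models/sd3/text_encoder", model_type=ModelType.CLIPEmbed),
}
print(submodels[SubModelType.VAE].path_or_prefix)  # /models/sd3/vae
```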
@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         match submodel_type:
-            case SubModelType.Tokenizer2:
+            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
-            case SubModelType.TextEncoder2:
+            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 te2_model_path = Path(config.path) / "text_encoder_2"
                 model_config = AutoConfig.from_pretrained(te2_model_path)
                 with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
             raise ValueError("Only T5EncoderConfig models are currently supported here.")

         match submodel_type:
-            case SubModelType.Tokenizer2:
+            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
-            case SubModelType.TextEncoder2:
+            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")

         raise ValueError(
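Editorial note (not part of the diff): the `case A | B` or-pattern used above lets either enum member select the same branch, so the new Tokenizer3/TextEncoder3 submodel types reuse the existing T5 loading paths. A minimal sketch:

```python
from enum import Enum


class SubModelType(str, Enum):  # simplified stand-in
    Tokenizer2 = "tokenizer_2"
    Tokenizer3 = "tokenizer_3"


def resolve(submodel_type: SubModelType) -> str:
    match submodel_type:
        case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
            return "tokenizer_2"  # both load from the same on-disk folder
    raise ValueError(submodel_type)


print(resolve(SubModelType.Tokenizer3))  # tokenizer_2
```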
@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 @ModelLoaderRegistry.register(
     base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
 )
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""

-    model_base_to_model_type = {
-        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
-        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
-        BaseModelType.StableDiffusionXL: "SDXL",
-        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
-    }
-
     def _load_model(
         self,
         config: AnyModelConfig,
@@ -20,7 +20,7 @@ from typing import Optional

 import requests
 from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
-from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
+from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
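Editorial note (not part of the diff): the hunk above simply moves the exception imports to huggingface_hub's public `errors` module. If code had to run against both old and new huggingface_hub releases, a hedged compatibility sketch would look like this:

```python
try:
    # Newer releases expose the exceptions in the public `errors` module.
    from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
except ImportError:
    # Older releases kept them in the private `utils._errors` module.
    from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
```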
@@ -19,7 +19,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
 )
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
-from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
+from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
@@ -33,7 +33,10 @@ from invokeai.backend.model_manager.config import (
     ModelType,
     ModelVariantType,
     SchedulerPredictionType,
+    SubmodelDefinition,
     SubModelType,
 )
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -112,6 +115,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
+        "StableDiffusion3Pipeline": ModelType.Main,
         "LatentConsistencyModelPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.VAE,
         "AutoencoderTiny": ModelType.VAE,
@@ -122,6 +126,8 @@ class ModelProbe(object):
         "CLIPTextModel": ModelType.CLIPEmbed,
         "T5EncoderModel": ModelType.T5Encoder,
         "FluxControlNetModel": ModelType.ControlNet,
+        "SD3Transformer2DModel": ModelType.Main,
+        "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
     }

     @classmethod
@@ -178,7 +184,7 @@ class ModelProbe(object):
             fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
         )
         fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
-        fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
+        fields["hash"] = "placeholder"  # fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

         fields["default_settings"] = fields.get("default_settings")
@@ -217,6 +223,10 @@ class ModelProbe(object):
             and fields["prediction_type"] == SchedulerPredictionType.VPrediction
         )

+        get_submodels = getattr(probe, "get_submodels", None)
+        if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
+            fields["submodels"] = get_submodels()
+
         model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
         return model_info
@@ -462,9 +472,8 @@ MODEL_NAME_TO_PREPROCESSOR = {
     "normal": "normalbae_image_processor",
     "sketch": "pidi_image_processor",
     "scribble": "lineart_image_processor",
-    "lineart anime": "lineart_anime_image_processor",
-    "lineart_anime": "lineart_anime_image_processor",
     "lineart": "lineart_image_processor",
+    "lineart_anime": "lineart_anime_image_processor",
     "softedge": "hed_image_processor",
     "hed": "hed_image_processor",
     "shuffle": "content_shuffle_image_processor",
@@ -747,18 +756,33 @@ class FolderProbeBase(ProbeBase):

 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
-        with open(self.model_path / "unet" / "config.json", "r") as file:
-            unet_conf = json.load(file)
-        if unet_conf["cross_attention_dim"] == 768:
-            return BaseModelType.StableDiffusion1
-        elif unet_conf["cross_attention_dim"] == 1024:
-            return BaseModelType.StableDiffusion2
-        elif unet_conf["cross_attention_dim"] == 1280:
-            return BaseModelType.StableDiffusionXLRefiner
-        elif unet_conf["cross_attention_dim"] == 2048:
-            return BaseModelType.StableDiffusionXL
-        else:
-            raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+        # Handle pipelines with a UNet (i.e. SD 1.x, SD2, SDXL).
+        config_path = self.model_path / "unet" / "config.json"
+        if config_path.exists():
+            with open(config_path) as file:
+                unet_conf = json.load(file)
+            if unet_conf["cross_attention_dim"] == 768:
+                return BaseModelType.StableDiffusion1
+            elif unet_conf["cross_attention_dim"] == 1024:
+                return BaseModelType.StableDiffusion2
+            elif unet_conf["cross_attention_dim"] == 1280:
+                return BaseModelType.StableDiffusionXLRefiner
+            elif unet_conf["cross_attention_dim"] == 2048:
+                return BaseModelType.StableDiffusionXL
+            else:
+                raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+
+        # Handle pipelines with a transformer (i.e. SD3).
+        config_path = self.model_path / "transformer" / "config.json"
+        if config_path.exists():
+            with open(config_path) as file:
+                transformer_conf = json.load(file)
+            if transformer_conf["_class_name"] == "SD3Transformer2DModel":
+                return BaseModelType.StableDiffusion3
+            else:
+                raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+
+        raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

     def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
         with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +794,21 @@ class PipelineFolderProbe(FolderProbeBase):
         else:
             raise InvalidModelConfigException(f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")

+    def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
+        config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
+        submodels: Dict[SubModelType, SubmodelDefinition] = {}
+        for key, value in config.items():
+            if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
+                continue
+            model_loader = str(value[1])
+            if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                submodels[SubModelType(key)] = SubmodelDefinition(
+                    path_or_prefix=(self.model_path / key).resolve().as_posix(),
+                    model_type=model_type,
+                )
+
+        return submodels
+
     def get_variant_type(self) -> ModelVariantType:
         # This only works for pipelines! Any kind of
         # exception results in our returning the
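Editorial note (not part of the diff): get_submodels() above walks a diffusers `model_index.json`, where keys starting with "_" are metadata and the remaining values are [library, class_name] pairs. A minimal sketch with an illustrative, abridged SD3 index:

```python
index = {  # abridged, illustrative model_index.json contents
    "_class_name": "StableDiffusion3Pipeline",
    "_diffusers_version": "0.30.0",
    "transformer": ["diffusers", "SD3Transformer2DModel"],
    "vae": ["diffusers", "AutoencoderKL"],
    "text_encoder": ["transformers", "CLIPTextModelWithProjection"],
}

CLASS2TYPE = {  # abridged stand-in for ModelProbe.CLASS2TYPE
    "SD3Transformer2DModel": "main",
    "AutoencoderKL": "vae",
    "CLIPTextModelWithProjection": "clip_embed",
}

submodels = {
    key: CLASS2TYPE[value[1]]
    for key, value in index.items()
    if not key.startswith("_") and isinstance(value, list) and len(value) == 2
}
print(submodels)  # {'transformer': 'main', 'vae': 'vae', 'text_encoder': 'clip_embed'}
```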
@@ -13,9 +13,6 @@ class StarterModelWithoutDependencies(BaseModel):
     type: ModelType
     format: Optional[ModelFormat] = None
     is_installed: bool = False
-    # allows us to track what models a user has installed across name changes within starter models
-    # if you update a starter model name, please add the old one to this list for that starter model
-    previous_names: list[str] = []


 class StarterModel(StarterModelWithoutDependencies):
@@ -246,49 +243,44 @@ easy_neg_sd1 = StarterModel(
 # endregion
 # region IP Adapter
 ip_adapter_sd1 = StarterModel(
-    name="Standard Reference (IP Adapter)",
+    name="IP Adapter",
     base=BaseModelType.StableDiffusion1,
     source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
-    description="References images with a more generalized/looser degree of precision.",
+    description="IP-Adapter for SD 1.5 models",
     type=ModelType.IPAdapter,
     dependencies=[ip_adapter_sd_image_encoder],
-    previous_names=["IP Adapter"],
 )
 ip_adapter_plus_sd1 = StarterModel(
-    name="Precise Reference (IP Adapter Plus)",
+    name="IP Adapter Plus",
     base=BaseModelType.StableDiffusion1,
     source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
-    description="References images with a higher degree of precision.",
+    description="Refined IP-Adapter for SD 1.5 models",
     type=ModelType.IPAdapter,
     dependencies=[ip_adapter_sd_image_encoder],
-    previous_names=["IP Adapter Plus"],
 )
 ip_adapter_plus_face_sd1 = StarterModel(
-    name="Face Reference (IP Adapter Plus Face)",
+    name="IP Adapter Plus Face",
     base=BaseModelType.StableDiffusion1,
     source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
-    description="References images with a higher degree of precision, adapted for faces",
+    description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
     type=ModelType.IPAdapter,
     dependencies=[ip_adapter_sd_image_encoder],
-    previous_names=["IP Adapter Plus Face"],
 )
 ip_adapter_sdxl = StarterModel(
-    name="Standard Reference (IP Adapter ViT-H)",
+    name="IP Adapter SDXL",
     base=BaseModelType.StableDiffusionXL,
     source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
-    description="References images with a higher degree of precision.",
+    description="IP-Adapter for SDXL models",
     type=ModelType.IPAdapter,
     dependencies=[ip_adapter_sdxl_image_encoder],
-    previous_names=["IP Adapter SDXL"],
 )
 ip_adapter_flux = StarterModel(
-    name="Standard Reference (XLabs FLUX IP-Adapter)",
+    name="XLabs FLUX IP-Adapter",
     base=BaseModelType.Flux,
     source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
-    description="References images with a more generalized/looser degree of precision.",
+    description="FLUX IP-Adapter",
     type=ModelType.IPAdapter,
     dependencies=[clip_vit_l_image_encoder],
-    previous_names=["XLabs FLUX IP-Adapter"],
 )
 # endregion
 # region ControlNet
@@ -307,162 +299,157 @@ qr_code_cnet_sdxl = StarterModel(
     type=ModelType.ControlNet,
 )
 canny_sd1 = StarterModel(
-    name="Hard Edge Detection (canny)",
+    name="canny",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_canny",
-    description="Uses detected edges in the image to control composition.",
+    description="ControlNet weights trained on sd-1.5 with canny conditioning.",
     type=ModelType.ControlNet,
-    previous_names=["canny"],
 )
 inpaint_cnet_sd1 = StarterModel(
-    name="Inpainting",
+    name="inpaint",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_inpaint",
     description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
     type=ModelType.ControlNet,
-    previous_names=["inpaint"],
 )
 mlsd_sd1 = StarterModel(
-    name="Line Drawing (mlsd)",
+    name="mlsd",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_mlsd",
-    description="Uses straight line detection for controlling the generation.",
+    description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
     type=ModelType.ControlNet,
-    previous_names=["mlsd"],
 )
 depth_sd1 = StarterModel(
-    name="Depth Map",
+    name="depth",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11f1p_sd15_depth",
-    description="Uses depth information in the image to control the depth in the generation.",
+    description="ControlNet weights trained on sd-1.5 with depth conditioning",
     type=ModelType.ControlNet,
-    previous_names=["depth"],
 )
 normal_bae_sd1 = StarterModel(
-    name="Lighting Detection (Normals)",
+    name="normal_bae",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_normalbae",
-    description="Uses detected lighting information to guide the lighting of the composition.",
+    description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
    type=ModelType.ControlNet,
-    previous_names=["normal_bae"],
 )
 seg_sd1 = StarterModel(
-    name="Segmentation Map",
+    name="seg",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_seg",
-    description="Uses segmentation maps to guide the structure of the composition.",
+    description="ControlNet weights trained on sd-1.5 with seg image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["seg"],
 )
 lineart_sd1 = StarterModel(
-    name="Lineart",
+    name="lineart",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_lineart",
-    description="Uses lineart detection to guide the lighting of the composition.",
+    description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["lineart"],
 )
 lineart_anime_sd1 = StarterModel(
-    name="Lineart Anime",
+    name="lineart_anime",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15s2_lineart_anime",
-    description="Uses anime lineart detection to guide the lighting of the composition.",
+    description="ControlNet weights trained on sd-1.5 with anime image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["lineart_anime"],
 )
 openpose_sd1 = StarterModel(
-    name="Pose Detection (openpose)",
+    name="openpose",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_openpose",
-    description="Uses pose information to control the pose of human characters in the generation.",
+    description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["openpose"],
 )
 scribble_sd1 = StarterModel(
-    name="Contour Detection (scribble)",
+    name="scribble",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_scribble",
-    description="Uses edges, contours, or line art in the image to control composition.",
+    description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["scribble"],
 )
 softedge_sd1 = StarterModel(
-    name="Soft Edge Detection (softedge)",
+    name="softedge",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11p_sd15_softedge",
-    description="Uses a soft edge detection map to control composition.",
+    description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
     type=ModelType.ControlNet,
-    previous_names=["softedge"],
 )
 shuffle_sd1 = StarterModel(
-    name="Remix (shuffle)",
+    name="shuffle",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11e_sd15_shuffle",
     description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["shuffle"],
 )
 tile_sd1 = StarterModel(
-    name="Tile",
+    name="tile",
     base=BaseModelType.StableDiffusion1,
     source="lllyasviel/control_v11f1e_sd15_tile",
-    description="Uses image data to replicate exact colors/structure in the resulting generation.",
+    description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["tile"],
 )
+ip2p_sd1 = StarterModel(
+    name="ip2p",
+    base=BaseModelType.StableDiffusion1,
+    source="lllyasviel/control_v11e_sd15_ip2p",
+    description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
+    type=ModelType.ControlNet,
+)
 canny_sdxl = StarterModel(
-    name="Hard Edge Detection (canny)",
+    name="canny-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="xinsir/controlnet-canny-sdxl-1.0",
-    description="Uses detected edges in the image to control composition.",
+    description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
     type=ModelType.ControlNet,
-    previous_names=["canny-sdxl"],
 )
 depth_sdxl = StarterModel(
-    name="Depth Map",
+    name="depth-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="diffusers/controlnet-depth-sdxl-1.0",
-    description="Uses depth information in the image to control the depth in the generation.",
+    description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
     type=ModelType.ControlNet,
-    previous_names=["depth-sdxl"],
 )
 softedge_sdxl = StarterModel(
-    name="Soft Edge Detection (softedge)",
+    name="softedge-dexined-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
-    description="Uses a soft edge detection map to control composition.",
+    description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
     type=ModelType.ControlNet,
-    previous_names=["softedge-dexined-sdxl"],
 )
+depth_zoe_16_sdxl = StarterModel(
+    name="depth-16bit-zoe-sdxl",
+    base=BaseModelType.StableDiffusionXL,
+    source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
+    description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
+    type=ModelType.ControlNet,
+)
+depth_zoe_32_sdxl = StarterModel(
+    name="depth-zoe-sdxl",
+    base=BaseModelType.StableDiffusionXL,
+    source="diffusers/controlnet-zoe-depth-sdxl-1.0",
+    description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
+    type=ModelType.ControlNet,
+)
 openpose_sdxl = StarterModel(
-    name="Pose Detection (openpose)",
+    name="openpose-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="xinsir/controlnet-openpose-sdxl-1.0",
-    description="Uses pose information to control the pose of human characters in the generation.",
+    description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
     type=ModelType.ControlNet,
-    previous_names=["openpose-sdxl", "controlnet-openpose-sdxl"],
 )
 scribble_sdxl = StarterModel(
-    name="Contour Detection (scribble)",
+    name="scribble-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="xinsir/controlnet-scribble-sdxl-1.0",
-    description="Uses edges, contours, or line art in the image to control composition.",
+    description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
     type=ModelType.ControlNet,
-    previous_names=["scribble-sdxl", "controlnet-scribble-sdxl"],
 )
 tile_sdxl = StarterModel(
-    name="Tile",
+    name="tile-sdxl",
     base=BaseModelType.StableDiffusionXL,
     source="xinsir/controlnet-tile-sdxl-1.0",
-    description="Uses image data to replicate exact colors/structure in the resulting generation.",
+    description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
     type=ModelType.ControlNet,
-    previous_names=["tile-sdxl"],
 )
-union_cnet_sdxl = StarterModel(
-    name="Multi-Guidance Detection (Union Pro)",
-    base=BaseModelType.StableDiffusionXL,
-    source="InvokeAI/Xinsir-SDXL_Controlnet_Union",
-    description="A unified ControlNet for SDXL model that supports 10+ control types",
-    type=ModelType.ControlNet,
-)
 union_cnet_flux = StarterModel(
@@ -475,52 +462,60 @@ union_cnet_flux = StarterModel(
 # endregion
 # region T2I Adapter
 t2i_canny_sd1 = StarterModel(
-    name="Hard Edge Detection (canny)",
+    name="canny-sd15",
     base=BaseModelType.StableDiffusion1,
     source="TencentARC/t2iadapter_canny_sd15v2",
-    description="Uses detected edges in the image to control composition
|
||||
description="T2I Adapter weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["canny-sd15"],
|
||||
)
|
||||
t2i_sketch_sd1 = StarterModel(
|
||||
name="Sketch",
|
||||
name="sketch-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_sketch_sd15v2",
|
||||
description="Uses a sketch to control composition",
|
||||
description="T2I Adapter weights trained on sd-1.5 with sketch conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["sketch-sd15"],
|
||||
)
|
||||
t2i_depth_sd1 = StarterModel(
|
||||
name="Depth Map",
|
||||
name="depth-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_depth_sd15v2",
|
||||
description="Uses depth information in the image to control the depth in the generation.",
|
||||
description="T2I Adapter weights trained on sd-1.5 with depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
)
|
||||
t2i_zoe_depth_sd1 = StarterModel(
|
||||
name="zoedepth-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_zoedepth_sd15v1",
|
||||
description="T2I Adapter weights trained on sd-1.5 with zoe depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["depth-sd15"],
|
||||
)
|
||||
t2i_canny_sdxl = StarterModel(
|
||||
name="Hard Edge Detection (canny)",
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
description="Uses detected edges in the image to control composition",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with canny conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
)
|
||||
t2i_zoe_depth_sdxl = StarterModel(
|
||||
name="zoedepth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with zoe depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["canny-sdxl"],
|
||||
)
|
||||
t2i_lineart_sdxl = StarterModel(
|
||||
name="Lineart",
|
||||
name="lineart-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
||||
description="Uses lineart detection to guide the lighting of the composition.",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with lineart conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["lineart-sdxl"],
|
||||
)
|
||||
t2i_sketch_sdxl = StarterModel(
|
||||
name="Sketch",
|
||||
name="sketch-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
||||
description="Uses a sketch to control composition",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with sketch conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["sketch-sdxl"],
|
||||
)
|
||||
# endregion
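
The renames above follow a single pattern: the human-readable label moves into name, and the old slug is preserved in previous_names so the installer can still match models installed under the old identifier. A minimal sketch of that pattern (the entry below is hypothetical and for illustration only; StarterModel, BaseModelType, and ModelType are the same names used throughout this file):

example_cnet = StarterModel(
    name="Example Detection (example)",  # new display name shown in the UI
    base=BaseModelType.StableDiffusion1,
    source="some-org/some-controlnet",  # hypothetical source repo
    description="ControlNet weights trained on sd-1.5 with example conditioning",
    type=ModelType.ControlNet,
    previous_names=["example"],  # old slug kept so existing installs are recognized
)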
# region SpandrelImageToImage
@@ -605,18 +600,22 @@ STARTER_MODELS: list[StarterModel] = [
    softedge_sd1,
    shuffle_sd1,
    tile_sd1,
    ip2p_sd1,
    canny_sdxl,
    depth_sdxl,
    softedge_sdxl,
    depth_zoe_16_sdxl,
    depth_zoe_32_sdxl,
    openpose_sdxl,
    scribble_sdxl,
    tile_sdxl,
    union_cnet_sdxl,
    union_cnet_flux,
    t2i_canny_sd1,
    t2i_sketch_sd1,
    t2i_depth_sd1,
    t2i_zoe_depth_sd1,
    t2i_canny_sdxl,
    t2i_zoe_depth_sdxl,
    t2i_lineart_sdxl,
    t2i_sketch_sdxl,
    realesrgan_x4,
@@ -647,6 +646,7 @@ sd1_bundle: list[StarterModel] = [
    softedge_sd1,
    shuffle_sd1,
    tile_sd1,
    ip2p_sd1,
    swinir,
]

@@ -657,6 +657,8 @@ sdxl_bundle: list[StarterModel] = [
    canny_sdxl,
    depth_sdxl,
    softedge_sdxl,
    depth_zoe_16_sdxl,
    depth_zoe_32_sdxl,
    openpose_sdxl,
    scribble_sdxl,
    tile_sdxl,

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
        return self


@dataclass
class SD3ConditioningInfo:
    clip_l_pooled_embeds: torch.Tensor
    clip_l_embeds: torch.Tensor
    clip_g_pooled_embeds: torch.Tensor
    clip_g_embeds: torch.Tensor
    t5_embeds: torch.Tensor | None

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
        self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
        self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
        self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
        if self.t5_embeds is not None:
            self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self
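
A quick usage sketch for the new dataclass (illustrative only; the tensor shapes below are placeholders, not values taken from this diff):

# Assume `torch` is imported and the embeds were produced by the SD3 text-encoding step.
info = SD3ConditioningInfo(
    clip_l_pooled_embeds=torch.zeros(1, 768),
    clip_l_embeds=torch.zeros(1, 77, 768),
    clip_g_pooled_embeds=torch.zeros(1, 1280),
    clip_g_embeds=torch.zeros(1, 77, 1280),
    t5_embeds=None,  # T5 is optional; to() skips it when None
)
# A single call moves every tensor, mirroring the other conditioning-info classes.
info = info.to(device=torch.device("cpu"), dtype=torch.float16)
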
@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
    conditionings: (
        List[BasicConditioningInfo]
        | List[SDXLConditioningInfo]
        | List[FLUXConditioningInfo]
        | List[SD3ConditioningInfo]
    )
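
Downstream code can narrow the widened union at runtime with isinstance checks; a hedged sketch of such a consumer (a hypothetical helper, not code from this diff):

def conditioning_kind(data: ConditioningFieldData) -> str:
    # Each list is homogeneous, so inspecting the first element is enough.
    first = data.conditionings[0]
    if isinstance(first, SD3ConditioningInfo):
        return "sd3"
    if isinstance(first, FLUXConditioningInfo):
        return "flux"
    return "sd-or-sdxl"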


@dataclass

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +32,9 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
# NOTE(ryand): I'm not the original author of this code, but for future reference, it appears that this class was copied
# from diffusers in order to add support for the encoder_attention_mask argument.
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """
    A ControlNet model.

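
Per the note above, this vendored copy exists so the ControlNet forward pass can accept an encoder_attention_mask and ignore padded prompt tokens. A hedged sketch of the intended call shape (argument names are assumptions based on the note and on diffusers' ControlNetModel conventions, not verified against this class):

# With return_dict=False the model returns (down_block_res_samples, mid_block_res_sample).
down_block_res_samples, mid_block_res_sample = controlnet(
    sample=noisy_latents,
    timestep=t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=control_image,
    encoder_attention_mask=prompt_mask,  # the argument this copy adds
    return_dict=False,
)
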
@@ -58,7 +58,7 @@
    "@dnd-kit/sortable": "^8.0.0",
    "@dnd-kit/utilities": "^3.2.2",
    "@fontsource-variable/inter": "^5.1.0",
    "@invoke-ai/ui-library": "^0.0.43",
    "@invoke-ai/ui-library": "^0.0.42",
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.2.3",
    "@roarr/browser-log-writer": "^1.3.0",

14
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -24,8 +24,8 @@ dependencies:
      specifier: ^5.1.0
      version: 5.1.0
    '@invoke-ai/ui-library':
      specifier: ^0.0.43
      version: 0.0.43(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1)
      specifier: ^0.0.42
      version: 0.0.42(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1)
    '@nanostores/react':
      specifier: ^0.7.3
      version: 0.7.3(nanostores@0.11.3)(react@18.3.1)
@@ -1696,20 +1696,20 @@ packages:
      prettier: 3.3.3
    dev: true

  /@invoke-ai/ui-library@0.0.43(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-t3fPYyks07ue3dEBPJuTHbeDLnDckDCOrtvc07mMDbLOnlPEZ0StaeiNGH+oO8qLzAuMAlSTdswgHfzTc2MmPw==}
  /@invoke-ai/ui-library@0.0.42(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.1.0)(@types/react@18.3.11)(i18next@23.15.1)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-OuDXRipBO5mu+Nv4qN8cd8MiwiGBdq6h4PirVgPI9/ltbdcIzePgUJ0dJns26lflHSTRWW38I16wl4YTw3mNWA==}
    peerDependencies:
      '@fontsource-variable/inter': ^5.0.16
      react: ^18.2.0
      react-dom: ^18.2.0
    dependencies:
      '@chakra-ui/anatomy': 2.3.4
      '@chakra-ui/anatomy': 2.2.2
      '@chakra-ui/icons': 2.2.4(@chakra-ui/react@2.10.2)(react@18.3.1)
      '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.3.1)
      '@chakra-ui/portal': 2.1.0(react-dom@18.3.1)(react@18.3.1)
      '@chakra-ui/react': 2.10.2(@emotion/react@11.13.3)(@emotion/styled@11.13.0)(@types/react@18.3.11)(framer-motion@11.10.0)(react-dom@18.3.1)(react@18.3.1)
      '@chakra-ui/styled-system': 2.11.2(react@18.3.1)
      '@chakra-ui/theme-tools': 2.2.6(@chakra-ui/styled-system@2.11.2)(react@18.3.1)
      '@chakra-ui/styled-system': 2.9.2
      '@chakra-ui/theme-tools': 2.1.2(@chakra-ui/styled-system@2.9.2)
      '@emotion/react': 11.13.3(@types/react@18.3.11)(react@18.3.1)
      '@emotion/styled': 11.13.0(@emotion/react@11.13.3)(@types/react@18.3.11)(react@18.3.1)
      '@fontsource-variable/inter': 5.1.0

@@ -94,7 +94,6 @@
    "close": "Close",
    "copy": "Copy",
    "copyError": "$t(gallery.copy) Error",
    "clipboard": "Clipboard",
    "on": "On",
    "off": "Off",
    "or": "or",
@@ -1252,33 +1251,6 @@
      "heading": "Mask Adjustments",
      "paragraphs": ["Adjust the mask."]
    },
    "inpainting": {
      "heading": "Inpainting",
      "paragraphs": ["Controls which area is modified, guided by Denoising Strength."]
    },
    "rasterLayer": {
      "heading": "Raster Layer",
      "paragraphs": ["Pixel-based content of your canvas, used during image generation."]
    },
    "regionalGuidance": {
      "heading": "Regional Guidance",
      "paragraphs": ["Brush to guide where elements from global prompts should appear."]
    },
    "regionalGuidanceAndReferenceImage": {
      "heading": "Regional Guidance and Regional Reference Image",
      "paragraphs": [
        "For Regional Guidance, brush to guide where elements from global prompts should appear.",
        "For Regional Reference Image, brush to apply a reference image to specific areas."
      ]
    },
    "globalReferenceImage": {
      "heading": "Global Reference Image",
      "paragraphs": ["Applies a reference image to influence the entire generation."]
    },
    "regionalReferenceImage": {
      "heading": "Regional Reference Image",
      "paragraphs": ["Brush to apply a reference image to specific areas."]
    },
    "controlNet": {
      "heading": "ControlNet",
      "paragraphs": [
@@ -1716,18 +1688,8 @@
    "layer_other": "Layers",
    "layer_withCount_one": "Layer ({{count}})",
    "layer_withCount_other": "Layers ({{count}})",
    "convertRasterLayerTo": "Convert $t(controlLayers.rasterLayer) To",
    "convertControlLayerTo": "Convert $t(controlLayers.controlLayer) To",
    "convertInpaintMaskTo": "Convert $t(controlLayers.inpaintMask) To",
    "convertRegionalGuidanceTo": "Convert $t(controlLayers.regionalGuidance) To",
    "copyRasterLayerTo": "Copy $t(controlLayers.rasterLayer) To",
    "copyControlLayerTo": "Copy $t(controlLayers.controlLayer) To",
    "copyInpaintMaskTo": "Copy $t(controlLayers.inpaintMask) To",
    "copyRegionalGuidanceTo": "Copy $t(controlLayers.regionalGuidance) To",
    "newRasterLayer": "New $t(controlLayers.rasterLayer)",
    "newControlLayer": "New $t(controlLayers.controlLayer)",
    "newInpaintMask": "New $t(controlLayers.inpaintMask)",
    "newRegionalGuidance": "New $t(controlLayers.regionalGuidance)",
    "convertToControlLayer": "Convert to Control Layer",
    "convertToRasterLayer": "Convert to Raster Layer",
    "transparency": "Transparency",
    "enableTransparencyEffect": "Enable Transparency Effect",
    "disableTransparencyEffect": "Disable Transparency Effect",
@@ -1883,11 +1845,11 @@
    "segment": {
      "autoMask": "Auto Mask",
      "pointType": "Point Type",
      "include": "Include",
      "exclude": "Exclude",
      "foreground": "Foreground",
      "background": "Background",
      "neutral": "Neutral",
      "reset": "Reset",
      "saveAs": "Save As",
      "apply": "Apply",
      "cancel": "Cancel",
      "process": "Process"
    },

@@ -26,9 +26,5 @@ export const IconMenuItem = ({ tooltip, icon, ...props }: Props) => {
};

export const IconMenuItemGroup = ({ children }: { children: ReactNode }) => {
  return (
    <Flex gap={2} justifyContent="space-between">
      {children}
    </Flex>
  );
  return <Flex gap={2}>{children}</Flex>;
};

@@ -23,10 +23,8 @@ export type Feature =
  | 'dynamicPrompts'
  | 'dynamicPromptsMaxPrompts'
  | 'dynamicPromptsSeedBehaviour'
  | 'globalReferenceImage'
  | 'imageFit'
  | 'infillMethod'
  | 'inpainting'
  | 'ipAdapterMethod'
  | 'lora'
  | 'loraWeight'
@@ -48,7 +46,6 @@ export type Feature =
  | 'paramVAEPrecision'
  | 'paramWidth'
  | 'patchmatchDownScaleSize'
  | 'rasterLayer'
  | 'refinerModel'
  | 'refinerNegativeAestheticScore'
  | 'refinerPositiveAestheticScore'
@@ -56,9 +53,6 @@ export type Feature =
  | 'refinerStart'
  | 'refinerSteps'
  | 'refinerCfgScale'
  | 'regionalGuidance'
  | 'regionalGuidanceAndReferenceImage'
  | 'regionalReferenceImage'
  | 'scaleBeforeProcessing'
  | 'seamlessTilingXAxis'
  | 'seamlessTilingYAxis'
@@ -82,24 +76,6 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
  clipSkip: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000178161-advanced-settings',
  },
  inpainting: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000096702-inpainting-outpainting-and-bounding-box',
  },
  rasterLayer: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000094998-raster-layers-and-initial-images',
  },
  regionalGuidance: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000165024-regional-guidance-layers',
  },
  regionalGuidanceAndReferenceImage: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000165024-regional-guidance-layers',
  },
  globalReferenceImage: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000159340-global-and-regional-reference-images-ip-adapters-',
  },
  regionalReferenceImage: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000159340-global-and-regional-reference-images-ip-adapters-',
  },
  controlNet: {
    href: 'https://support.invoke.ai/support/solutions/articles/151000105880',
  },

@@ -127,6 +127,8 @@ export const buildUseDisclosure = (defaultIsOpen: boolean): [() => UseDisclosure
 *
 * Hook to manage a boolean state. Use this for a local boolean state.
 * @param defaultIsOpen Initial state of the disclosure
 *
 * @knipignore
 */
export const useDisclosure = (defaultIsOpen: boolean): UseDisclosure => {
  const [isOpen, set] = useState(defaultIsOpen);

@@ -16,7 +16,6 @@ type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
  getIsDisabled?: (model: T) => boolean;
  isLoading?: boolean;
  groupByType?: boolean;
  showDescriptions?: boolean;
};

type UseGroupedModelComboboxReturn = {
@@ -38,15 +37,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
  const { t } = useTranslation();
  const base = useAppSelector(selectBaseWithSDXLFallback);
  const {
    modelConfigs,
    selectedModel,
    getIsDisabled,
    onChange,
    isLoading,
    groupByType = false,
    showDescriptions = false,
  } = arg;
  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
  const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
    if (!modelConfigs) {
      return [];
@@ -60,7 +51,6 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
        options: val.map((model) => ({
          label: model.name,
          value: model.key,
          description: (showDescriptions && model.description) || undefined,
          isDisabled: getIsDisabled ? getIsDisabled(model) : false,
        })),
      });
@@ -70,7 +60,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
  );
  _options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base) ? -1 : 1));
  return _options;
  }, [modelConfigs, groupByType, getIsDisabled, base, showDescriptions]);
  }, [modelConfigs, groupByType, getIsDisabled, base]);

  const value = useMemo(
    () =>

@@ -1,161 +0,0 @@
import type { MenuButtonProps, MenuItemProps, MenuListProps, MenuProps } from '@invoke-ai/ui-library';
import { Box, Flex, Icon, Text } from '@invoke-ai/ui-library';
import { useDisclosure } from 'common/hooks/useBoolean';
import type { FocusEventHandler, PointerEvent, RefObject } from 'react';
import { useCallback, useEffect, useRef } from 'react';
import { PiCaretRightBold } from 'react-icons/pi';
import { useDebouncedCallback } from 'use-debounce';

const offset: [number, number] = [0, 8];

type UseSubMenuReturn = {
  parentMenuItemProps: Partial<MenuItemProps>;
  menuProps: Partial<MenuProps>;
  menuButtonProps: Partial<MenuButtonProps>;
  menuListProps: Partial<MenuListProps> & { ref: RefObject<HTMLDivElement> };
};

/**
 * A hook that provides the necessary props to create a sub-menu within a menu.
 *
 * The sub-menu should be wrapped inside a parent `MenuItem` component.
 *
 * Use SubMenuButtonContent to render a button with a label and a right caret icon.
 *
 * TODO(psyche): Add keyboard handling for sub-menu.
 *
 * @example
 * ```tsx
 * const SubMenuExample = () => {
 *   const subMenu = useSubMenu();
 *   return (
 *     <Menu>
 *       <MenuButton>Open Parent Menu</MenuButton>
 *       <MenuList>
 *         <MenuItem>Parent Item 1</MenuItem>
 *         <MenuItem>Parent Item 2</MenuItem>
 *         <MenuItem>Parent Item 3</MenuItem>
 *         <MenuItem {...subMenu.parentMenuItemProps} icon={<PiImageBold />}>
 *           <Menu {...subMenu.menuProps}>
 *             <MenuButton {...subMenu.menuButtonProps}>
 *               <SubMenuButtonContent label="Open Sub Menu" />
 *             </MenuButton>
 *             <MenuList {...subMenu.menuListProps}>
 *               <MenuItem>Sub Item 1</MenuItem>
 *               <MenuItem>Sub Item 2</MenuItem>
 *               <MenuItem>Sub Item 3</MenuItem>
 *             </MenuList>
 *           </Menu>
 *         </MenuItem>
 *       </MenuList>
 *     </Menu>
 *   );
 * };
 * ```
 */
export const useSubMenu = (): UseSubMenuReturn => {
  const subMenu = useDisclosure(false);
  const menuListRef = useRef<HTMLDivElement>(null);
  const closeDebounced = useDebouncedCallback(subMenu.close, 300);
  const openAndCancelPendingClose = useCallback(() => {
    closeDebounced.cancel();
    subMenu.open();
  }, [closeDebounced, subMenu]);
  const toggleAndCancelPendingClose = useCallback(() => {
    if (subMenu.isOpen) {
      subMenu.close();
      return;
    } else {
      closeDebounced.cancel();
      subMenu.toggle();
    }
  }, [closeDebounced, subMenu]);
  const onBlurMenuList = useCallback<FocusEventHandler<HTMLDivElement>>(
    (e) => {
      // Don't trigger blur if focus is moving to a child element - e.g. from a sub-menu item to another sub-menu item
      if (e.currentTarget.contains(e.relatedTarget)) {
        closeDebounced.cancel();
        return;
      }
      subMenu.close();
    },
    [closeDebounced, subMenu]
  );

  const onParentMenuItemPointerLeave = useCallback(
    (e: PointerEvent<HTMLButtonElement>) => {
      /**
       * The pointerleave event is triggered when the pen or touch device is lifted, which would close the sub-menu.
       * However, we want to keep the sub-menu open until the pen or touch device pressed some other element. This
       * will be handled in the useEffect below - just ignore the pointerleave event for pen and touch devices.
       */
      if (e.pointerType === 'pen' || e.pointerType === 'touch') {
        return;
      }
      subMenu.close();
    },
    [subMenu]
  );

  /**
   * When using a mouse, the pointerleave events close the menu. But when using a pen or touch device, we need to close
   * the sub-menu when the user taps outside of the menu list. So we need to listen for clicks outside of the menu list
   * and close the menu accordingly.
   */
  useEffect(() => {
    const el = menuListRef.current;
    if (!el) {
      return;
    }
    const controller = new AbortController();
    window.addEventListener(
      'click',
      (e) => {
        if (menuListRef.current?.contains(e.target as Node)) {
          return;
        }
        subMenu.close();
      },
      { signal: controller.signal }
    );
    return () => {
      controller.abort();
    };
  }, [subMenu]);

  return {
    parentMenuItemProps: {
      onClick: toggleAndCancelPendingClose,
      onPointerEnter: openAndCancelPendingClose,
      onPointerLeave: onParentMenuItemPointerLeave,
      closeOnSelect: false,
    },
    menuProps: {
      isOpen: subMenu.isOpen,
      onClose: subMenu.close,
      placement: 'right',
      offset: offset,
      closeOnBlur: false,
    },
    menuButtonProps: {
      as: Box,
      width: 'full',
      height: 'full',
    },
    menuListProps: {
      ref: menuListRef,
      onPointerEnter: openAndCancelPendingClose,
      onPointerLeave: closeDebounced,
      onBlur: onBlurMenuList,
    },
  };
};

export const SubMenuButtonContent = ({ label }: { label: string }) => {
  return (
    <Flex w="full" h="full" flexDir="row" justifyContent="space-between" alignItems="center">
      <Text>{label}</Text>
      <Icon as={PiCaretRightBold} />
    </Flex>
  );
};
@@ -1,6 +1,5 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import {
  useAddControlLayer,
  useAddGlobalReferenceImage,
@@ -29,80 +28,69 @@ export const CanvasAddEntityButtons = memo(() => {
    <Flex position="relative" flexDir="column" gap={4} top="20%">
      <Flex flexDir="column" justifyContent="flex-start" gap={2}>
        <Heading size="xs">{t('controlLayers.global')}</Heading>
        <InformationalPopover feature="globalReferenceImage">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addGlobalReferenceImage}
          >
            {t('controlLayers.globalReferenceImage')}
          </Button>
        </InformationalPopover>
        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addGlobalReferenceImage}
        >
          {t('controlLayers.globalReferenceImage')}
        </Button>
      </Flex>
      <Flex flexDir="column" gap={2}>
        <Heading size="xs">{t('controlLayers.regional')}</Heading>
        <InformationalPopover feature="inpainting">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addInpaintMask}
          >
            {t('controlLayers.inpaintMask')}
          </Button>
        </InformationalPopover>
        <InformationalPopover feature="regionalGuidance">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addRegionalGuidance}
            isDisabled={isFLUX}
          >
            {t('controlLayers.regionalGuidance')}
          </Button>
        </InformationalPopover>
        <InformationalPopover feature="regionalReferenceImage">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addRegionalReferenceImage}
            isDisabled={isFLUX}
          >
            {t('controlLayers.regionalReferenceImage')}
          </Button>
        </InformationalPopover>
        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addInpaintMask}
        >
          {t('controlLayers.inpaintMask')}
        </Button>
        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addRegionalGuidance}
          isDisabled={isFLUX}
        >
          {t('controlLayers.regionalGuidance')}
        </Button>
        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addRegionalReferenceImage}
          isDisabled={isFLUX}
        >
          {t('controlLayers.regionalReferenceImage')}
        </Button>
      </Flex>
      <Flex flexDir="column" justifyContent="flex-start" gap={2}>
        <Heading size="xs">{t('controlLayers.layer_other')}</Heading>
        <InformationalPopover feature="controlNet">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addControlLayer}
          >
            {t('controlLayers.controlLayer')}
          </Button>
        </InformationalPopover>
        <InformationalPopover feature="rasterLayer">
          <Button
            size="sm"
            variant="ghost"
            justifyContent="flex-start"
            leftIcon={<PiPlusBold />}
            onClick={addRasterLayer}
          >
            {t('controlLayers.rasterLayer')}
          </Button>
        </InformationalPopover>

        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addControlLayer}
        >
          {t('controlLayers.controlLayer')}
        </Button>
        <Button
          size="sm"
          variant="ghost"
          justifyContent="flex-start"
          leftIcon={<PiPlusBold />}
          onClick={addRasterLayer}
        >
          {t('controlLayers.rasterLayer')}
        </Button>
      </Flex>
    </Flex>
  </Flex>

@@ -1,5 +1,4 @@
import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import {
@@ -17,8 +16,6 @@ import { PiFloppyDiskBold } from 'react-icons/pi';

export const CanvasContextMenuGlobalMenuItems = memo(() => {
  const { t } = useTranslation();
  const saveSubMenu = useSubMenu();
  const newSubMenu = useSubMenu();
  const isBusy = useCanvasIsBusy();
  const saveCanvasToGallery = useSaveCanvasToGallery();
  const saveBboxToGallery = useSaveBboxToGallery();
@@ -31,41 +28,27 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
    <>
      <MenuGroup title={t('controlLayers.canvasContextMenu.canvasGroup')}>
        <CanvasContextMenuItemsCropCanvasToBbox />
        <MenuItem {...saveSubMenu.parentMenuItemProps} icon={<PiFloppyDiskBold />}>
          <Menu {...saveSubMenu.menuProps}>
            <MenuButton {...saveSubMenu.menuButtonProps}>
              <SubMenuButtonContent label={t('controlLayers.canvasContextMenu.saveToGalleryGroup')} />
            </MenuButton>
            <MenuList {...saveSubMenu.menuListProps}>
              <MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
                {t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
              </MenuItem>
              <MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
                {t('controlLayers.canvasContextMenu.saveBboxToGallery')}
              </MenuItem>
            </MenuList>
          </Menu>
      </MenuGroup>
      <MenuGroup title={t('controlLayers.canvasContextMenu.saveToGalleryGroup')}>
        <MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveCanvasToGallery}>
          {t('controlLayers.canvasContextMenu.saveCanvasToGallery')}
        </MenuItem>
        <MenuItem {...newSubMenu.parentMenuItemProps} icon={<NewLayerIcon />}>
          <Menu {...newSubMenu.menuProps}>
            <MenuButton {...newSubMenu.menuButtonProps}>
              <SubMenuButtonContent label={t('controlLayers.canvasContextMenu.bboxGroup')} />
            </MenuButton>
            <MenuList {...newSubMenu.menuListProps}>
              <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
                {t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
              </MenuItem>
              <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
                {t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
              </MenuItem>
              <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
                {t('controlLayers.canvasContextMenu.newControlLayer')}
              </MenuItem>
              <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
                {t('controlLayers.canvasContextMenu.newRasterLayer')}
              </MenuItem>
            </MenuList>
          </Menu>
        <MenuItem icon={<PiFloppyDiskBold />} isDisabled={isBusy} onClick={saveBboxToGallery}>
          {t('controlLayers.canvasContextMenu.saveBboxToGallery')}
        </MenuItem>
      </MenuGroup>
      <MenuGroup title={t('controlLayers.canvasContextMenu.bboxGroup')}>
        <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newGlobalReferenceImageFromBbox}>
          {t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
        </MenuItem>
        <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRegionalReferenceImageFromBbox}>
          {t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
        </MenuItem>
        <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newControlLayerFromBbox}>
          {t('controlLayers.canvasContextMenu.newControlLayer')}
        </MenuItem>
        <MenuItem icon={<NewLayerIcon />} isDisabled={isBusy} onClick={newRasterLayerFromBbox}>
          {t('controlLayers.canvasContextMenu.newRasterLayer')}
        </MenuItem>
      </MenuGroup>
    </>

@@ -1,40 +1,42 @@
import { MenuGroup } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ControlLayerMenuItems } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItems';
import { InpaintMaskMenuItems } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItems';
import { IPAdapterMenuItems } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItems';
import { RasterLayerMenuItems } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItems';
import { RegionalGuidanceMenuItems } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItems';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import {
  EntityIdentifierContext,
  useEntityIdentifierContext,
} from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityTitle } from 'features/controlLayers/hooks/useEntityTitle';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import {
  isFilterableEntityIdentifier,
  isSaveableEntityIdentifier,
  isSegmentableEntityIdentifier,
  isTransformableEntityIdentifier,
} from 'features/controlLayers/store/types';
import { memo } from 'react';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

const CanvasContextMenuSelectedEntityMenuItemsContent = memo(() => {
  const entityIdentifier = useEntityIdentifierContext();
  const title = useEntityTitle(entityIdentifier);

  if (entityIdentifier.type === 'raster_layer') {
    return <RasterLayerMenuItems />;
  }
  if (entityIdentifier.type === 'control_layer') {
    return <ControlLayerMenuItems />;
  }
  if (entityIdentifier.type === 'inpaint_mask') {
    return <InpaintMaskMenuItems />;
  }
  if (entityIdentifier.type === 'regional_guidance') {
    return <RegionalGuidanceMenuItems />;
  }
  if (entityIdentifier.type === 'reference_image') {
    return <IPAdapterMenuItems />;
  }

  assert<Equals<typeof entityIdentifier.type, never>>(false);
  return (
    <MenuGroup title={title}>
      {isFilterableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsFilter />}
      {isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsTransform />}
      {isSegmentableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSegment />}
      {isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCopyToClipboard />}
      {isSaveableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsSave />}
      {isTransformableEntityIdentifier(entityIdentifier) && <CanvasEntityMenuItemsCropToBbox />}
      <CanvasEntityMenuItemsDelete />
    </MenuGroup>
  );
});

CanvasContextMenuSelectedEntityMenuItemsContent.displayName = 'CanvasContextMenuSelectedEntityMenuItemsContent';

export const CanvasContextMenuSelectedEntityMenuItems = memo(() => {

@@ -1,6 +1,5 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu';
import { EntityListSelectedEntityActionBarAutoMaskButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarAutoMaskButton';
import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarDuplicateButton';
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
@@ -17,7 +16,6 @@ export const EntityListSelectedEntityActionBar = memo(() => {
      <Spacer />
      <EntityListSelectedEntityActionBarFill />
      <Flex h="full">
        <EntityListSelectedEntityActionBarAutoMaskButton />
        <EntityListSelectedEntityActionBarFilterButton />
        <EntityListSelectedEntityActionBarTransformButton />
        <EntityListSelectedEntityActionBarSaveToAssetsButton />

@@ -1,37 +0,0 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntitySegmentAnything } from 'features/controlLayers/hooks/useEntitySegmentAnything';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMaskHappyBold } from 'react-icons/pi';

export const EntityListSelectedEntityActionBarAutoMaskButton = memo(() => {
  const { t } = useTranslation();
  const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
  const segment = useEntitySegmentAnything(selectedEntityIdentifier);

  if (!selectedEntityIdentifier) {
    return null;
  }

  if (!isSegmentableEntityIdentifier(selectedEntityIdentifier)) {
    return null;
  }

  return (
    <IconButton
      onClick={segment.start}
      isDisabled={segment.isDisabled}
      size="sm"
      variant="link"
      alignSelf="stretch"
      aria-label={t('controlLayers.segment.autoMask')}
      tooltip={t('controlLayers.segment.autoMask')}
      icon={<PiMaskHappyBold />}
    />
  );
});

EntityListSelectedEntityActionBarAutoMaskButton.displayName = 'EntityListSelectedEntityActionBarAutoMaskButton';
@@ -25,8 +25,8 @@ const MenuContent = () => {
  return (
    <CanvasManagerProviderGate>
      <MenuList>
        <CanvasContextMenuSelectedEntityMenuItems />
        <CanvasContextMenuGlobalMenuItems />
        <CanvasContextMenuSelectedEntityMenuItems />
      </MenuList>
    </CanvasManagerProviderGate>
  );

@@ -1,6 +1,7 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
@@ -8,8 +9,7 @@ import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { ControlLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertToSubMenu';
import { ControlLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsCopyToSubMenu';
import { ControlLayerMenuItemsConvertControlToRaster } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsConvertControlToRaster';
import { ControlLayerMenuItemsTransparencyEffect } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItemsTransparencyEffect';
import { memo } from 'react';

@@ -25,13 +25,12 @@ export const ControlLayerMenuItems = memo(() => {
      <CanvasEntityMenuItemsTransform />
      <CanvasEntityMenuItemsFilter />
      <CanvasEntityMenuItemsSegment />
      <ControlLayerMenuItemsConvertControlToRaster />
      <ControlLayerMenuItemsTransparencyEffect />
      <MenuDivider />
      <CanvasEntityMenuItemsCropToBbox />
      <CanvasEntityMenuItemsCopyToClipboard />
      <CanvasEntityMenuItemsSave />
      <MenuDivider />
      <ControlLayerMenuItemsConvertToSubMenu />
      <ControlLayerMenuItemsCopyToSubMenu />
    </>
  );
});

@@ -0,0 +1,27 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { controlLayerConvertedToRasterLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';

export const ControlLayerMenuItemsConvertControlToRaster = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('control_layer');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const convertControlLayerToRasterLayer = useCallback(() => {
    dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiLightningBold />} isDisabled={!isInteractable}>
      {t('controlLayers.convertToRasterLayer')}
    </MenuItem>
  );
});

ControlLayerMenuItemsConvertControlToRaster.displayName = 'ControlLayerMenuItemsConvertControlToRaster';
@@ -1,56 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
  controlLayerConvertedToInpaintMask,
  controlLayerConvertedToRasterLayer,
  controlLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';

export const ControlLayerMenuItemsConvertToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('control_layer');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const convertToInpaintMask = useCallback(() => {
    dispatch(controlLayerConvertedToInpaintMask({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  const convertToRegionalGuidance = useCallback(() => {
    dispatch(controlLayerConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  const convertToRasterLayer = useCallback(() => {
    dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.convertControlLayerTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.inpaintMask')}
          </MenuItem>
          <MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.regionalGuidance')}
          </MenuItem>
          <MenuItem onClick={convertToRasterLayer} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.rasterLayer')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

ControlLayerMenuItemsConvertToSubMenu.displayName = 'ControlLayerMenuItemsConvertToSubMenu';
@@ -1,58 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
  controlLayerConvertedToInpaintMask,
  controlLayerConvertedToRasterLayer,
  controlLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';

export const ControlLayerMenuItemsCopyToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('control_layer');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const copyToInpaintMask = useCallback(() => {
    dispatch(controlLayerConvertedToInpaintMask({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  const copyToRegionalGuidance = useCallback(() => {
    dispatch(controlLayerConvertedToRegionalGuidance({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  const copyToRasterLayer = useCallback(() => {
    dispatch(controlLayerConvertedToRasterLayer({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.copyControlLayerTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <CanvasEntityMenuItemsCopyToClipboard />
          <MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newInpaintMask')}
          </MenuItem>
          <MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newRegionalGuidance')}
          </MenuItem>
          <MenuItem onClick={copyToRasterLayer} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newRasterLayer')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

ControlLayerMenuItemsCopyToSubMenu.displayName = 'ControlLayerMenuItemsCopyToSubMenu';
@@ -1,22 +0,0 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';

export const IPAdapterMenuItemPullBbox = memo(() => {
  const { t } = useTranslation();
  const entityIdentifier = useEntityIdentifierContext('reference_image');
  const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
  const isBusy = useCanvasIsBusy();

  return (
    <MenuItem onClick={pullBboxIntoIPAdapter} icon={<PiBoundingBoxBold />} isDisabled={isBusy}>
      {t('controlLayers.pullBboxIntoReferenceImage')}
    </MenuItem>
  );
});

IPAdapterMenuItemPullBbox.displayName = 'IPAdapterMenuItemPullBbox';
@@ -1,22 +1,16 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { IPAdapterMenuItemPullBbox } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItemPullBbox';
import { memo } from 'react';

export const IPAdapterMenuItems = memo(() => {
  return (
    <>
      <IconMenuItemGroup>
        <CanvasEntityMenuItemsArrange />
        <CanvasEntityMenuItemsDuplicate />
        <CanvasEntityMenuItemsDelete asIcon />
      </IconMenuItemGroup>
      <MenuDivider />
      <IPAdapterMenuItemPullBbox />
    </>
    <IconMenuItemGroup>
      <CanvasEntityMenuItemsArrange />
      <CanvasEntityMenuItemsDuplicate />
      <CanvasEntityMenuItemsDelete asIcon />
    </IconMenuItemGroup>
  );
});

@@ -5,8 +5,6 @@ import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/componen
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { InpaintMaskMenuItemsConvertToSubMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItemsConvertToSubMenu';
import { InpaintMaskMenuItemsCopyToSubMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItemsCopyToSubMenu';
import { memo } from 'react';

export const InpaintMaskMenuItems = memo(() => {
@@ -21,9 +19,6 @@ export const InpaintMaskMenuItems = memo(() => {
      <CanvasEntityMenuItemsTransform />
      <MenuDivider />
      <CanvasEntityMenuItemsCropToBbox />
      <MenuDivider />
      <InpaintMaskMenuItemsConvertToSubMenu />
      <InpaintMaskMenuItemsCopyToSubMenu />
    </>
  );
});

@@ -1,38 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { inpaintMaskConvertedToRegionalGuidance } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';

export const InpaintMaskMenuItemsConvertToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const convertToRegionalGuidance = useCallback(() => {
    dispatch(inpaintMaskConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.convertInpaintMaskTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.regionalGuidance')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

InpaintMaskMenuItemsConvertToSubMenu.displayName = 'InpaintMaskMenuItemsConvertToSubMenu';
@@ -1,40 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { inpaintMaskConvertedToRegionalGuidance } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';

export const InpaintMaskMenuItemsCopyToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const copyToRegionalGuidance = useCallback(() => {
    dispatch(inpaintMaskConvertedToRegionalGuidance({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.copyInpaintMaskTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <CanvasEntityMenuItemsCopyToClipboard />
          <MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newRegionalGuidance')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

InpaintMaskMenuItemsCopyToSubMenu.displayName = 'InpaintMaskMenuItemsCopyToSubMenu';
@@ -1,6 +1,7 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
@@ -8,8 +9,7 @@ import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSegment } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSegment';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
import { RasterLayerMenuItemsConvertRasterToControl } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertRasterToControl';
import { memo } from 'react';

export const RasterLayerMenuItems = memo(() => {
@@ -24,12 +24,11 @@ export const RasterLayerMenuItems = memo(() => {
      <CanvasEntityMenuItemsTransform />
      <CanvasEntityMenuItemsFilter />
      <CanvasEntityMenuItemsSegment />
      <RasterLayerMenuItemsConvertRasterToControl />
      <MenuDivider />
      <CanvasEntityMenuItemsCropToBbox />
      <CanvasEntityMenuItemsCopyToClipboard />
      <CanvasEntityMenuItemsSave />
      <MenuDivider />
      <RasterLayerMenuItemsConvertToSubMenu />
      <RasterLayerMenuItemsCopyToSubMenu />
    </>
  );
});
@@ -0,0 +1,36 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rasterLayerConvertedToControlLayer } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLightningBold } from 'react-icons/pi';

export const RasterLayerMenuItemsConvertRasterToControl = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('raster_layer');
  const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const onClick = useCallback(() => {
    dispatch(
      rasterLayerConvertedToControlLayer({
        entityIdentifier,
        overrides: {
          controlAdapter: defaultControlAdapter,
        },
      })
    );
  }, [defaultControlAdapter, dispatch, entityIdentifier]);

  return (
    <MenuItem onClick={onClick} icon={<PiLightningBold />} isDisabled={!isInteractable}>
      {t('controlLayers.convertToControlLayer')}
    </MenuItem>
  );
});

RasterLayerMenuItemsConvertRasterToControl.displayName = 'RasterLayerMenuItemsConvertRasterToControl';
@@ -1,65 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
  rasterLayerConvertedToControlLayer,
  rasterLayerConvertedToInpaintMask,
  rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';

export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();

  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('raster_layer');
  const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const convertToInpaintMask = useCallback(() => {
    dispatch(rasterLayerConvertedToInpaintMask({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  const convertToRegionalGuidance = useCallback(() => {
    dispatch(rasterLayerConvertedToRegionalGuidance({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  const convertToControlLayer = useCallback(() => {
    dispatch(
      rasterLayerConvertedToControlLayer({
        entityIdentifier,
        replace: true,
        overrides: { controlAdapter: defaultControlAdapter },
      })
    );
  }, [defaultControlAdapter, dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.convertRasterLayerTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.inpaintMask')}
          </MenuItem>
          <MenuItem onClick={convertToRegionalGuidance} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.regionalGuidance')}
          </MenuItem>
          <MenuItem onClick={convertToControlLayer} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.controlLayer')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

RasterLayerMenuItemsConvertToSubMenu.displayName = 'RasterLayerMenuItemsConvertToSubMenu';
@@ -1,66 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import {
  rasterLayerConvertedToControlLayer,
  rasterLayerConvertedToInpaintMask,
  rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';

export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();

  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('raster_layer');
  const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const copyToInpaintMask = useCallback(() => {
    dispatch(rasterLayerConvertedToInpaintMask({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  const copyToRegionalGuidance = useCallback(() => {
    dispatch(rasterLayerConvertedToRegionalGuidance({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  const copyToControlLayer = useCallback(() => {
    dispatch(
      rasterLayerConvertedToControlLayer({
        entityIdentifier,
        overrides: { controlAdapter: defaultControlAdapter },
      })
    );
  }, [defaultControlAdapter, dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.copyRasterLayerTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <CanvasEntityMenuItemsCopyToClipboard />
          <MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newInpaintMask')}
          </MenuItem>
          <MenuItem onClick={copyToRegionalGuidance} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newRegionalGuidance')}
          </MenuItem>
          <MenuItem onClick={copyToControlLayer} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newControlLayer')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

RasterLayerMenuItemsCopyToSubMenu.displayName = 'RasterLayerMenuItemsCopyToSubMenu';
@@ -1,5 +1,4 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { Flex, MenuDivider } from '@invoke-ai/ui-library';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsCropToBbox } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCropToBbox';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
@@ -7,18 +6,16 @@ import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/component
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { RegionalGuidanceMenuItemsAddPromptsAndIPAdapter } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsAddPromptsAndIPAdapter';
import { RegionalGuidanceMenuItemsAutoNegative } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsAutoNegative';
import { RegionalGuidanceMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsConvertToSubMenu';
import { RegionalGuidanceMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItemsCopyToSubMenu';
import { memo } from 'react';

export const RegionalGuidanceMenuItems = memo(() => {
  return (
    <>
      <IconMenuItemGroup>
      <Flex gap={2}>
        <CanvasEntityMenuItemsArrange />
        <CanvasEntityMenuItemsDuplicate />
        <CanvasEntityMenuItemsDelete asIcon />
      </IconMenuItemGroup>
      </Flex>
      <MenuDivider />
      <RegionalGuidanceMenuItemsAddPromptsAndIPAdapter />
      <MenuDivider />
@@ -26,9 +23,6 @@ export const RegionalGuidanceMenuItems = memo(() => {
      <RegionalGuidanceMenuItemsAutoNegative />
      <MenuDivider />
      <CanvasEntityMenuItemsCropToBbox />
      <MenuDivider />
      <RegionalGuidanceMenuItemsConvertToSubMenu />
      <RegionalGuidanceMenuItemsCopyToSubMenu />
    </>
  );
});
@@ -1,38 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rgConvertedToInpaintMask } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';

export const RegionalGuidanceMenuItemsConvertToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('regional_guidance');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const convertToInpaintMask = useCallback(() => {
    dispatch(rgConvertedToInpaintMask({ entityIdentifier, replace: true }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.convertRegionalGuidanceTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <MenuItem onClick={convertToInpaintMask} icon={<PiSwapBold />} isDisabled={!isInteractable}>
            {t('controlLayers.inpaintMask')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

RegionalGuidanceMenuItemsConvertToSubMenu.displayName = 'RegionalGuidanceMenuItemsConvertToSubMenu';
@@ -1,40 +0,0 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useIsEntityInteractable } from 'features/controlLayers/hooks/useEntityIsInteractable';
import { rgConvertedToInpaintMask } from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';

export const RegionalGuidanceMenuItemsCopyToSubMenu = memo(() => {
  const { t } = useTranslation();
  const subMenu = useSubMenu();
  const dispatch = useAppDispatch();
  const entityIdentifier = useEntityIdentifierContext('regional_guidance');
  const isInteractable = useIsEntityInteractable(entityIdentifier);

  const copyToInpaintMask = useCallback(() => {
    dispatch(rgConvertedToInpaintMask({ entityIdentifier }));
  }, [dispatch, entityIdentifier]);

  return (
    <MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />}>
      <Menu {...subMenu.menuProps}>
        <MenuButton {...subMenu.menuButtonProps}>
          <SubMenuButtonContent label={t('controlLayers.copyRegionalGuidanceTo')} />
        </MenuButton>
        <MenuList {...subMenu.menuListProps}>
          <CanvasEntityMenuItemsCopyToClipboard />
          <MenuItem onClick={copyToInpaintMask} icon={<PiCopyBold />} isDisabled={!isInteractable}>
            {t('controlLayers.newInpaintMask')}
          </MenuItem>
        </MenuList>
      </Menu>
    </MenuItem>
  );
});

RegionalGuidanceMenuItemsCopyToSubMenu.displayName = 'RegionalGuidanceMenuItemsCopyToSubMenu';
@@ -1,14 +1,4 @@
import {
  Button,
  ButtonGroup,
  Flex,
  Heading,
  Menu,
  MenuButton,
  MenuItem,
  MenuList,
  Spacer,
} from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
@@ -20,9 +10,9 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useRef } from 'react';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiFloppyDiskBold, PiStarBold, PiXBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiStarBold, PiXBold } from 'react-icons/pi';

const SegmentAnythingContent = memo(
  ({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
@@ -32,25 +22,8 @@ const SegmentAnythingContent = memo(
    const isCanvasFocused = useIsRegionFocused('canvas');
    const isProcessing = useStore(adapter.segmentAnything.$isProcessing);
    const hasPoints = useStore(adapter.segmentAnything.$hasPoints);
    const hasImageState = useStore(adapter.segmentAnything.$hasImageState);
    const autoProcess = useAppSelector(selectAutoProcess);

    const saveAsInpaintMask = useCallback(() => {
      adapter.segmentAnything.saveAs('inpaint_mask');
    }, [adapter.segmentAnything]);

    const saveAsRegionalGuidance = useCallback(() => {
      adapter.segmentAnything.saveAs('regional_guidance');
    }, [adapter.segmentAnything]);

    const saveAsRasterLayer = useCallback(() => {
      adapter.segmentAnything.saveAs('raster_layer');
    }, [adapter.segmentAnything]);

    const saveAsControlLayer = useCallback(() => {
      adapter.segmentAnything.saveAs('control_layer');
    }, [adapter.segmentAnything]);

    useRegisteredHotkeys({
      id: 'applySegmentAnything',
      category: 'canvas',
@@ -113,32 +86,15 @@ const SegmentAnythingContent = memo(
          >
            {t('controlLayers.segment.reset')}
          </Button>
          <Menu>
            <MenuButton
              as={Button}
              leftIcon={<PiFloppyDiskBold />}
              isLoading={isProcessing}
              loadingText={t('controlLayers.segment.saveAs')}
              variant="ghost"
              isDisabled={!hasImageState}
            >
              {t('controlLayers.segment.saveAs')}
            </MenuButton>
            <MenuList>
              <MenuItem isDisabled={!hasImageState} onClick={saveAsInpaintMask}>
                {t('controlLayers.inpaintMask')}
              </MenuItem>
              <MenuItem isDisabled={!hasImageState} onClick={saveAsRegionalGuidance}>
                {t('controlLayers.regionalGuidance')}
              </MenuItem>
              <MenuItem isDisabled={!hasImageState} onClick={saveAsControlLayer}>
                {t('controlLayers.controlLayer')}
              </MenuItem>
              <MenuItem isDisabled={!hasImageState} onClick={saveAsRasterLayer}>
                {t('controlLayers.rasterLayer')}
              </MenuItem>
            </MenuList>
          </Menu>
          <Button
            leftIcon={<PiCheckBold />}
            onClick={adapter.segmentAnything.apply}
            isLoading={isProcessing}
            loadingText={t('controlLayers.segment.apply')}
            variant="ghost"
          >
            {t('controlLayers.segment.apply')}
          </Button>
          <Button
            leftIcon={<PiXBold />}
            onClick={adapter.segmentAnything.cancel}
@@ -26,10 +26,13 @@ export const SegmentAnythingPointType = memo(
      <RadioGroup value={pointType} onChange={onChange} w="full" size="md">
        <Flex alignItems="center" w="full" gap={4} fontWeight="semibold" color="base.300">
          <Radio value="foreground">
            <Text>{t('controlLayers.segment.include')}</Text>
            <Text>{t('controlLayers.segment.foreground')}</Text>
          </Radio>
          <Radio value="background">
            <Text>{t('controlLayers.segment.exclude')}</Text>
            <Text>{t('controlLayers.segment.background')}</Text>
          </Radio>
          <Radio value="neutral">
            <Text>{t('controlLayers.segment.neutral')}</Text>
          </Radio>
        </Flex>
      </RadioGroup>
@@ -1,11 +1,9 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
@@ -23,7 +21,6 @@ const _hover: SystemStyleObject = {

export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props) => {
  const title = useEntityTypeTitle(type);
  const informationalPopoverFeature = useEntityTypeInformationalPopover(type);
  const collapse = useBoolean(true);
  const canMergeVisible = useMemo(() => type === 'raster_layer' || type === 'inpaint_mask', [type]);
  const canHideAll = useMemo(() => type !== 'reference_image', [type]);
@@ -50,30 +47,15 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
          transitionProperty="common"
          transitionDuration="fast"
        />
        {informationalPopoverFeature ? (
          <InformationalPopover feature={informationalPopoverFeature}>
            <Text
              fontWeight="semibold"
              color={isSelected ? 'base.200' : 'base.500'}
              userSelect="none"
              transitionProperty="common"
              transitionDuration="fast"
            >
              {title}
            </Text>
          </InformationalPopover>
        ) : (
          <Text
            fontWeight="semibold"
            color={isSelected ? 'base.200' : 'base.500'}
            userSelect="none"
            transitionProperty="common"
            transitionDuration="fast"
          >
            {title}
          </Text>
        )}

        <Text
          fontWeight="semibold"
          color={isSelected ? 'base.200' : 'base.500'}
          userSelect="none"
          transitionProperty="common"
          transitionDuration="fast"
        >
          {title}
        </Text>
        <Spacer />
      </Flex>
      {canMergeVisible && <CanvasEntityMergeVisibleButton type={type} />}
@@ -20,7 +20,7 @@ export const CanvasEntityMenuItemsCopyToClipboard = memo(() => {

  return (
    <MenuItem onClick={onClick} icon={<PiCopyBold />} isDisabled={!isInteractable}>
      {t('common.clipboard')}
      {t('controlLayers.copyToClipboard')}
    </MenuItem>
  );
});
@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';

export const useEntityFilter = (entityIdentifier: CanvasEntityIdentifier | null) => {
  const canvasManager = useCanvasManager();
  const adapter = useEntityAdapterSafe(entityIdentifier);
  const imageViewer = useImageViewer();
  const isBusy = useCanvasIsBusy();
  const isInteractable = useStore(adapter?.$isInteractable ?? $false);
  const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -52,9 +50,8 @@ export const useEntityFilter = (entityIdentifier: CanvasEntityIdentifier | null)
    if (!adapter) {
      return;
    }
    imageViewer.close();
    adapter.filterer.start();
  }, [isDisabled, entityIdentifier, canvasManager, imageViewer]);
  }, [isDisabled, entityIdentifier, canvasManager]);

  return { isDisabled, start } as const;
};
@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isSegmentableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';

export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifier | null) => {
  const canvasManager = useCanvasManager();
  const adapter = useEntityAdapterSafe(entityIdentifier);
  const imageViewer = useImageViewer();
  const isBusy = useCanvasIsBusy();
  const isInteractable = useStore(adapter?.$isInteractable ?? $false);
  const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -52,9 +50,8 @@ export const useEntitySegmentAnything = (entityIdentifier: CanvasEntityIdentifie
    if (!adapter) {
      return;
    }
    imageViewer.close();
    adapter.segmentAnything.start();
  }, [isDisabled, entityIdentifier, canvasManager, imageViewer]);
  }, [isDisabled, entityIdentifier, canvasManager]);

  return { isDisabled, start } as const;
};
@@ -5,13 +5,11 @@ import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdap
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCallback, useMemo } from 'react';

export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | null) => {
  const canvasManager = useCanvasManager();
  const adapter = useEntityAdapterSafe(entityIdentifier);
  const imageViewer = useImageViewer();
  const isBusy = useCanvasIsBusy();
  const isInteractable = useStore(adapter?.$isInteractable ?? $false);
  const isEmpty = useStore(adapter?.$isEmpty ?? $false);
@@ -69,11 +67,10 @@ export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | nu
    if (!adapter) {
      return;
    }
    imageViewer.close();
    await adapter.transformer.startTransform({ silent: true });
    adapter.transformer.fitToBboxContain();
    await adapter.transformer.applyTransform();
  }, [canvasManager, entityIdentifier, imageViewer, isDisabled]);
  }, [canvasManager, entityIdentifier, isDisabled]);

  return { isDisabled, start, fitToBbox } as const;
};
@@ -1,25 +0,0 @@
import type { Feature } from 'common/components/InformationalPopover/constants';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';

export const useEntityTypeInformationalPopover = (type: CanvasEntityIdentifier['type']): Feature | undefined => {
  const feature = useMemo(() => {
    switch (type) {
      case 'control_layer':
        return 'controlNet';
      case 'inpaint_mask':
        return 'inpainting';
      case 'raster_layer':
        return 'rasterLayer';
      case 'regional_guidance':
        return 'regionalGuidanceAndReferenceImage';
      case 'reference_image':
        return 'globalReferenceImage';

      default:
        return undefined;
    }
  }, [type]);

  return feature;
};
@@ -6,22 +6,15 @@ import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konv
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import {
  addCoords,
  getKonvaNodeDebugAttrs,
  getPrefixedId,
  offsetCoord,
  roundCoord,
} from 'features/controlLayers/konva/util';
import { addCoords, getKonvaNodeDebugAttrs, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
import { selectAutoProcess } from 'features/controlLayers/store/canvasSettingsSlice';
import type {
  CanvasEntityType,
  CanvasImageState,
  Coordinate,
  RgbaColor,
  SAMPoint,
  SAMPointLabel,
  SAMPointLabelString,
  SAMPointWithId,
} from 'features/controlLayers/store/types';
import { SAM_POINT_LABEL_NUMBER_TO_STRING } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
@@ -34,9 +27,6 @@ import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

type CanvasSegmentAnythingModuleConfig = {
  /**
@@ -80,7 +70,7 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
  SAM_POINT_FOREGROUND_COLOR: { r: 50, g: 255, b: 0, a: 1 }, // light green
  SAM_POINT_BACKGROUND_COLOR: { r: 255, g: 0, b: 50, a: 1 }, // red-ish
  SAM_POINT_NEUTRAL_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
  MASK_COLOR: { r: 0, g: 225, b: 255, a: 1 }, // cyan
  MASK_COLOR: { r: 0, g: 200, b: 200, a: 0.5 }, // cyan with 50% opacity
  PROCESS_DEBOUNCE_MS: 1000,
};
@@ -95,7 +85,6 @@ const DEFAULT_CONFIG: CanvasSegmentAnythingModuleConfig = {
type SAMPointState = {
  id: string;
  label: SAMPointLabel;
  coord: Coordinate;
  konva: {
    circle: Konva.Circle;
  };
@@ -124,9 +113,9 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
  $isSegmenting = atom<boolean>(false);

  /**
   * The hash of the last processed points. This is used to prevent re-processing the same points.
   * Whether the current set of points has been processed.
   */
  $lastProcessedHash = atom<string>('');
  $hasProcessed = atom<boolean>(false);

  /**
   * Whether the module is currently processing the points.
@@ -155,15 +144,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
  /**
   * The ephemeral image state of the processed image. Only used while segmenting.
   */
  $imageState = atom<CanvasImageState | null>(null);
  imageState: CanvasImageState | null = null;

  /**
   * Whether the module has an image state. This is a computed value based on $imageState.
   */
  $hasImageState = computed(this.$imageState, (imageState) => imageState !== null);

  /**
   * The current input points. A listener is added to this atom to process the points when they change.
   * The current input points.
   */
  $points = atom<SAMPointState[]>([]);
@@ -203,10 +187,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
     * It's rendered with a globalCompositeOperation of 'source-atop' to preview the mask as a semi-transparent overlay.
     */
    compositingRect: Konva.Rect;
    /**
     * A tween for pulsing the mask group's opacity.
     */
    maskTween: Konva.Tween | null;
  };

  KONVA_CIRCLE_NAME = `${this.type}:circle`;
@@ -229,7 +209,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    this.konva = {
      group: new Konva.Group({ name: this.KONVA_GROUP_NAME }),
      pointGroup: new Konva.Group({ name: this.KONVA_POINT_GROUP_NAME }),
      maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME, opacity: 0.6 }),
      maskGroup: new Konva.Group({ name: this.KONVA_MASK_GROUP_NAME }),
      compositingRect: new Konva.Rect({
        name: this.KONVA_COMPOSITING_RECT_NAME,
        fill: rgbaColorToString(this.config.MASK_COLOR),
@@ -239,7 +219,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
        perfectDrawEnabled: false,
        visible: false,
      }),
      maskTween: null,
    };

    // Points should always be rendered above the mask group
@@ -271,12 +250,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
  createPoint(coord: Coordinate, label: SAMPointLabel): SAMPointState {
    const id = getPrefixedId('sam_point');

    const roundedCoord = roundCoord(coord);

    const circle = new Konva.Circle({
      name: this.KONVA_CIRCLE_NAME,
      x: roundedCoord.x,
      y: roundedCoord.y,
      x: Math.round(coord.x),
      y: Math.round(coord.y),
      radius: this.manager.stage.unscale(this.config.SAM_POINT_RADIUS), // We will scale this as the stage scale changes
      fill: rgbaColorToString(this.getSAMPointColor(label)),
      stroke: rgbaColorToString(this.config.SAM_POINT_BORDER_COLOR),
@@ -296,12 +273,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
      // This event should not bubble up to the parent, stage or any other nodes
      e.cancelBubble = true;
      circle.destroy();

      const newPoints = this.$points.get().filter((point) => point.id !== id);
      if (newPoints.length === 0) {
      this.$points.set(this.$points.get().filter((point) => point.id !== id));
      if (this.$points.get().length === 0) {
        this.resetEphemeralState();
      } else {
        this.$points.set(newPoints);
        this.$hasProcessed.set(false);
      }
    });
@@ -310,28 +286,25 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    });

    circle.on('dragend', () => {
      const roundedCoord = roundCoord(circle.position());

      this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Moved SAM point');
      this.$isDraggingPoint.set(false);

      const newPoints = this.$points.get().map((point) => {
        if (point.id === id) {
          return { ...point, coord: roundedCoord };
        }
        return point;
      });

      this.$points.set(newPoints);
      // Point has changed!
      this.$hasProcessed.set(false);
      this.$points.notify();
      this.log.trace(
        { x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
        'Moved SAM point'
      );
    });

    this.konva.pointGroup.add(circle);

    this.log.trace({ ...roundedCoord, label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] }, 'Created SAM point');
    this.log.trace(
      { x: Math.round(circle.x()), y: Math.round(circle.y()), label: SAM_POINT_LABEL_NUMBER_TO_STRING[label] },
      'Created SAM point'
    );

    return {
      id,
      coord: roundedCoord,
      label,
      konva: { circle },
    };
@@ -354,14 +327,14 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
  /**
   * Gets the SAM points in the format expected by the segment-anything API. The x and y values are rounded to integers.
   */
  getSAMPoints = (): SAMPointWithId[] => {
    const points: SAMPointWithId[] = [];
  getSAMPoints = (): SAMPoint[] => {
    const points: SAMPoint[] = [];

    for (const { id, coord, label } of this.$points.get()) {
    for (const { konva, label } of this.$points.get()) {
      points.push({
        id,
        x: coord.x,
        y: coord.y,
        // Pull out and round the x and y values from Konva
        x: Math.round(konva.circle.x()),
        y: Math.round(konva.circle.y()),
        label,
      });
    }
@@ -408,8 +381,10 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {

    // Create a SAM point at the normalized position
    const point = this.createPoint(normalizedPoint, this.$pointType.get());
    const newPoints = [...this.$points.get(), point];
    this.$points.set(newPoints);
    this.$points.set([...this.$points.get(), point]);

    // Mark the module as having _not_ processed the points now that they have changed
    this.$hasProcessed.set(false);
  };

  /**
@@ -446,7 +421,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    if (points.length === 0) {
      return;
    }

    if (this.manager.stateApi.getSettings().autoProcess) {
      this.process();
    }
@@ -459,7 +433,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
        if (this.$points.get().length === 0) {
          return;
        }
        if (autoProcess) {
        if (autoProcess && !this.$hasProcessed.get()) {
          this.process();
        }
      })
@@ -526,12 +500,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
      return;
    }

    const hash = stableHash(points);
    if (hash === this.$lastProcessedHash.get()) {
      this.log.trace('Already processed points');
      return;
    }

    this.$isProcessing.set(true);

    this.log.trace({ points }, 'Segmenting');
@@ -553,7 +521,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    this.abortController = controller;

    // Build the graph for segmenting the image, using the rasterized image DTO
    const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value, points);
    const { graph, outputNodeId } = this.buildGraph(rasterizeResult.value);

    // Run the graph and get the segmented image output
    const segmentResult = await withResultAsync(() =>
@@ -580,27 +548,21 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    this.log.trace({ imageDTO: segmentResult.value }, 'Segmented');

    // Prepare the ephemeral image state
    const imageState = imageDTOToImageObject(segmentResult.value);
    this.$imageState.set(imageState);
    this.imageState = imageDTOToImageObject(segmentResult.value);

    // Destroy any existing masked image and create a new one
    if (this.maskedImage) {
      this.maskedImage.destroy();
    }
    if (this.konva.maskTween) {
      this.konva.maskTween.destroy();
      this.konva.maskTween = null;
    }

    this.maskedImage = new CanvasObjectImage(imageState, this);
    this.maskedImage = new CanvasObjectImage(this.imageState, this);

    // Force update the masked image - after awaiting, the image will be rendered (in memory)
    await this.maskedImage.update(imageState, true);
    await this.maskedImage.update(this.imageState, true);

    // Update the compositing rect to match the image size
    this.konva.compositingRect.setAttrs({
      width: imageState.image.width,
      height: imageState.image.height,
      width: this.imageState.image.width,
      height: this.imageState.image.height,
      visible: true,
    });
@@ -612,24 +574,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    // Cache the group to ensure the mask is rendered correctly w/ opacity
    this.konva.maskGroup.cache();

    // Create a pulsing tween
    this.konva.maskTween = new Konva.Tween({
      node: this.konva.maskGroup,
      duration: 1,
      opacity: 0.4, // oscillate between this value and pre-tween opacity
      yoyo: true,
      repeat: Infinity,
      easing: Konva.Easings.EaseOut,
    });

    // Start the pulsing effect
    this.konva.maskTween.play();

    this.$lastProcessedHash.set(hash);

    // We are done processing (still segmenting though!)
    this.$isProcessing.set(false);

    // The current points have been processed
    this.$hasProcessed.set(true);

    // Clean up the abort controller as needed
    if (!this.abortController.signal.aborted) {
      this.abortController.abort();
@@ -646,7 +596,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
   * Applies the segmented image to the entity.
   */
  apply = () => {
    const imageState = this.$imageState.get();
    if (!this.$hasProcessed.get()) {
      this.log.error('Cannot apply unprocessed points');
      return;
    }
    const imageState = this.imageState;
    if (!imageState) {
      this.log.error('No image state to apply');
      return;
@@ -673,55 +627,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    this.teardown();
  };

  /**
   * Saves the segmented image as a new entity of the given type.
   */
  saveAs = (type: Exclude<CanvasEntityType, 'reference_image'>) => {
    const imageState = this.$imageState.get();
    if (!imageState) {
      this.log.error('No image state to save as');
      return;
    }
    this.log.trace(`Saving as ${type}`);

    // Clear the buffer - we are creating a new entity, so we don't want to keep the old one
    this.parent.bufferRenderer.clearBuffer();

    // Create the new entity with the masked image as its only object
    const rect = this.parent.transformer.getRelativeRect();
    const arg = {
      overrides: {
        objects: [imageState],
        position: {
          x: Math.round(rect.x),
          y: Math.round(rect.y),
        },
      },
      isSelected: true,
    };

    switch (type) {
      case 'raster_layer':
        this.manager.stateApi.addRasterLayer(arg);
        break;
      case 'control_layer':
        this.manager.stateApi.addControlLayer(arg);
        break;
      case 'inpaint_mask':
        this.manager.stateApi.addInpaintMask(arg);
        break;
      case 'regional_guidance':
        this.manager.stateApi.addRegionalGuidance(arg);
        break;
      default:
        assert<Equals<typeof type, never>>(false);
    }

    // Final cleanup and teardown, returning user to main canvas UI
    this.resetEphemeralState();
    this.teardown();
  };

  /**
   * Resets the module (e.g. remove all points and the mask image).
   *
@@ -781,16 +686,12 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
    if (this.maskedImage) {
      this.maskedImage.destroy();
    }
    if (this.konva.maskTween) {
      this.konva.maskTween.destroy();
      this.konva.maskTween = null;
    }

    // Empty internal module state
    this.$points.set([]);
    this.$imageState.set(null);
    this.imageState = null;
    this.$pointType.set(1);
    this.$lastProcessedHash.set('');
    this.$hasProcessed.set(false);
    this.$isProcessing.set(false);

    // Reset non-ephemeral konva nodes
@@ -805,7 +706,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
  /**
   * Builds a graph for segmenting an image with the given image DTO.
   */
  buildGraph = ({ image_name }: ImageDTO, points: SAMPointWithId[]): { graph: Graph; outputNodeId: string } => {
  buildGraph = ({ image_name }: ImageDTO): { graph: Graph; outputNodeId: string } => {
    const graph = new Graph(getPrefixedId('canvas_segment_anything'));

    // TODO(psyche): When SAM2 is available in transformers, use it here
@@ -815,7 +716,7 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
      type: 'segment_anything',
      model: 'segment-anything-huge',
      image: { image_name },
      point_lists: [{ points: points.map(({ x, y, label }) => ({ x, y, label })) }],
      point_lists: [{ points: this.getSAMPoints() }],
      mask_filter: 'largest',
    });
@@ -858,11 +759,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
        label,
        circle: getKonvaNodeDebugAttrs(konva.circle),
      })),
      imageState: deepClone(this.$imageState.get()),
      imageState: deepClone(this.imageState),
      maskedImage: this.maskedImage?.repr(),
      config: deepClone(this.config),
      $isSegmenting: this.$isSegmenting.get(),
      $lastProcessedHash: this.$lastProcessedHash.get(),
      $hasProcessed: this.$hasProcessed.get(),
      $isProcessing: this.$isProcessing.get(),
      $pointType: this.$pointType.get(),
      $pointTypeString: this.$pointTypeString.get(),
@@ -17,16 +17,12 @@ import {
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
  bboxChangedFromCanvas,
  controlLayerAdded,
  entityBrushLineAdded,
  entityEraserLineAdded,
  entityMoved,
  entityRasterized,
  entityRectAdded,
  entityReset,
  inpaintMaskAdded,
  rasterLayerAdded,
  rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasStagingAreaSlice } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
@@ -55,7 +51,6 @@ import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import type { Param0 } from 'tsafe';
import { assert } from 'tsafe';

import type { CanvasEntityAdapter } from './CanvasEntity/types';
@@ -165,34 +160,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
    this.store.dispatch(entityRectAdded(arg));
  };

  /**
   * Adds a raster layer to the canvas, pushing state to redux.
   */
  addRasterLayer = (arg: Param0<typeof rasterLayerAdded>) => {
    this.store.dispatch(rasterLayerAdded(arg));
  };

  /**
   * Adds a control layer to the canvas, pushing state to redux.
   */
  addControlLayer = (arg: Param0<typeof controlLayerAdded>) => {
    this.store.dispatch(controlLayerAdded(arg));
  };

  /**
   * Adds an inpaint mask to the canvas, pushing state to redux.
   */
  addInpaintMask = (arg: Param0<typeof inpaintMaskAdded>) => {
    this.store.dispatch(inpaintMaskAdded(arg));
  };

  /**
   * Adds regional guidance to the canvas, pushing state to redux.
   */
  addRegionalGuidance = (arg: Param0<typeof rgAdded>) => {
    this.store.dispatch(rgAdded(arg));
  };

  /**
   * Rasterizes an entity, pushing state to redux.
   */
@@ -126,13 +126,6 @@ export const floorCoord = (coord: Coordinate): Coordinate => {
  };
};

export const roundCoord = (coord: Coordinate): Coordinate => {
  return {
    x: Math.round(coord.x),
    y: Math.round(coord.y),
  };
};

/**
 * Snaps a position to the edge of the given rect if within a threshold of the edge
 * @param pos The position to snap
@@ -29,7 +29,7 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import { merge, omit } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -57,13 +57,13 @@ import type {
} from './types';
import { getEntityIdentifier, isRenderableEntity } from './types';
import {
  converters,
  getControlLayerState,
  getInpaintMaskState,
  getRasterLayerState,
  getReferenceImageState,
  getRegionalGuidanceState,
  imageDTOToImageWithDims,
  initialControlNet,
  initialIPAdapter,
} from './util';
@@ -157,25 +157,28 @@ export const canvasSlice = createSlice({
      reducer: (
        state,
        action: PayloadAction<
          EntityIdentifierPayload<
            { newId: string; overrides?: Partial<CanvasControlLayerState>; replace?: boolean },
            'raster_layer'
          >
          EntityIdentifierPayload<{ newId: string; overrides?: Partial<CanvasControlLayerState> }, 'raster_layer'>
        >
      ) => {
        const { entityIdentifier, newId, overrides, replace } = action.payload;
        const { entityIdentifier, newId, overrides } = action.payload;
        const layer = selectEntity(state, entityIdentifier);
        if (!layer) {
          return;
        }

        // Convert the raster layer to control layer
        const controlLayerState = converters.rasterLayer.toControlLayer(newId, layer, overrides);
        const controlLayerState: CanvasControlLayerState = {
          ...deepClone(layer),
          id: newId,
          type: 'control_layer',
          controlAdapter: deepClone(initialControlNet),
          withTransparencyEffect: true,
        };

        if (replace) {
          // Remove the raster layer
          state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
        }
        merge(controlLayerState, overrides);

        // Remove the raster layer
        state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);

        // Add the converted control layer
        state.controlLayers.entities.push(controlLayerState);
@@ -183,90 +186,11 @@ export const canvasSlice = createSlice({
|
||||
state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id };
|
||||
},
|
||||
prepare: (
|
||||
payload: EntityIdentifierPayload<
|
||||
{ overrides?: Partial<CanvasControlLayerState>; replace?: boolean } | undefined,
|
||||
'raster_layer'
|
||||
>
|
||||
payload: EntityIdentifierPayload<{ overrides?: Partial<CanvasControlLayerState> } | undefined, 'raster_layer'>
|
||||
) => ({
|
||||
payload: { ...payload, newId: getPrefixedId('control_layer') },
|
||||
}),
|
||||
},
|
||||
rasterLayerConvertedToInpaintMask: {
|
||||
reducer: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<
|
||||
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
|
||||
'raster_layer'
|
||||
>
|
||||
>
|
||||
) => {
|
||||
const { entityIdentifier, newId, overrides, replace } = action.payload;
|
||||
const layer = selectEntity(state, entityIdentifier);
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert the raster layer to inpaint mask
|
||||
const inpaintMaskState = converters.rasterLayer.toInpaintMask(newId, layer, overrides);
|
||||
|
||||
if (replace) {
|
||||
// Remove the raster layer
|
||||
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
|
||||
}
|
||||
|
||||
// Add the converted inpaint mask
|
||||
state.inpaintMasks.entities.push(inpaintMaskState);
|
||||
|
||||
state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
|
||||
},
|
||||
prepare: (
|
||||
payload: EntityIdentifierPayload<
|
||||
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
|
||||
'raster_layer'
|
||||
>
|
||||
) => ({
|
||||
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
|
||||
}),
|
||||
},
|
||||
rasterLayerConvertedToRegionalGuidance: {
|
||||
reducer: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<
|
||||
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
|
||||
'raster_layer'
|
||||
>
|
||||
>
|
||||
) => {
|
||||
const { entityIdentifier, newId, overrides, replace } = action.payload;
|
||||
const layer = selectEntity(state, entityIdentifier);
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert the raster layer to inpaint mask
|
||||
const regionalGuidanceState = converters.rasterLayer.toRegionalGuidance(newId, layer, overrides);
|
||||
|
||||
if (replace) {
|
||||
// Remove the raster layer
|
||||
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
|
||||
}
|
||||
|
||||
// Add the converted inpaint mask
|
||||
state.regionalGuidance.entities.push(regionalGuidanceState);
|
||||
|
||||
state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
|
||||
},
|
||||
prepare: (
|
||||
payload: EntityIdentifierPayload<
|
||||
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
|
||||
'raster_layer'
|
||||
>
|
||||
) => ({
|
||||
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
|
||||
}),
|
||||
},
|
||||
//#region Control layers
controlLayerAdded: {
reducer: (
@@ -293,125 +217,32 @@ export const canvasSlice = createSlice({
state.selectedEntityIdentifier = { type: 'control_layer', id: data.id };
},
controlLayerConvertedToRasterLayer: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRasterLayerState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
reducer: (state, action: PayloadAction<EntityIdentifierPayload<{ newId: string }, 'control_layer'>>) => {
const { entityIdentifier, newId } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}

// Convert the control layer to raster layer
const rasterLayerState = converters.controlLayer.toRasterLayer(newId, layer, overrides);
const rasterLayerState: CanvasRasterLayerState = {
...omit(deepClone(layer), ['type', 'controlAdapter', 'withTransparencyEffect']),
id: newId,
type: 'raster_layer',
};

if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);

// Add the new raster layer
state.rasterLayers.entities.push(rasterLayerState);

state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRasterLayerState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
prepare: (payload: EntityIdentifierPayload<void, 'control_layer'>) => ({
payload: { ...payload, newId: getPrefixedId('raster_layer') },
}),
},
controlLayerConvertedToInpaintMask: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}

// Convert the control layer to inpaint mask
const inpaintMaskState = converters.controlLayer.toInpaintMask(newId, layer, overrides);

if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}

// Add the new inpaint mask
state.inpaintMasks.entities.push(inpaintMaskState);

state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
}),
},
controlLayerConvertedToRegionalGuidance: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
'control_layer'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}

// Convert the control layer to regional guidance
const regionalGuidanceState = converters.controlLayer.toRegionalGuidance(newId, layer, overrides);

if (replace) {
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}

// Add the new regional guidance
state.regionalGuidance.entities.push(regionalGuidanceState);

state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
'control_layer'
>
) => ({
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
}),
},
controlLayerModelChanged: (
state,
action: PayloadAction<
@@ -616,46 +447,6 @@ export const canvasSlice = createSlice({
state.regionalGuidance.entities.push(data);
state.selectedEntityIdentifier = { type: 'regional_guidance', id: data.id };
},
rgConvertedToInpaintMask: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean },
'regional_guidance'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}

// Convert the regional guidance to inpaint mask
const inpaintMaskState = converters.regionalGuidance.toInpaintMask(newId, layer, overrides);

if (replace) {
// Remove the regional guidance
state.regionalGuidance.entities = state.regionalGuidance.entities.filter(
(layer) => layer.id !== entityIdentifier.id
);
}

// Add the new inpaint mask
state.inpaintMasks.entities.push(inpaintMaskState);

state.selectedEntityIdentifier = { type: inpaintMaskState.type, id: inpaintMaskState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasInpaintMaskState>; replace?: boolean } | undefined,
'regional_guidance'
>
) => ({
payload: { ...payload, newId: getPrefixedId('inpaint_mask') },
}),
},
rgPositivePromptChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ prompt: string | null }, 'regional_guidance'>>
@@ -853,44 +644,6 @@ export const canvasSlice = createSlice({
state.inpaintMasks.entities = [data];
state.selectedEntityIdentifier = { type: 'inpaint_mask', id: data.id };
},
inpaintMaskConvertedToRegionalGuidance: {
reducer: (
state,
action: PayloadAction<
EntityIdentifierPayload<
{ newId: string; overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean },
'inpaint_mask'
>
>
) => {
const { entityIdentifier, newId, overrides, replace } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer) {
return;
}

// Convert the inpaint mask to regional guidance
const regionalGuidanceState = converters.inpaintMask.toRegionalGuidance(newId, layer, overrides);

if (replace) {
// Remove the inpaint mask
state.inpaintMasks.entities = state.inpaintMasks.entities.filter((layer) => layer.id !== entityIdentifier.id);
}

// Add the new regional guidance
state.regionalGuidance.entities.push(regionalGuidanceState);

state.selectedEntityIdentifier = { type: regionalGuidanceState.type, id: regionalGuidanceState.id };
},
prepare: (
payload: EntityIdentifierPayload<
{ overrides?: Partial<CanvasRegionalGuidanceState>; replace?: boolean } | undefined,
'inpaint_mask'
>
) => ({
payload: { ...payload, newId: getPrefixedId('regional_guidance') },
}),
},
//#region BBox
bboxScaledWidthChanged: (state, action: PayloadAction<number>) => {
const gridSize = getGridSize(state.bbox.modelBase);
@@ -1457,14 +1210,10 @@ export const {
rasterLayerAdded,
// rasterLayerRecalled,
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
// Control layers
controlLayerAdded,
// controlLayerRecalled,
controlLayerConvertedToRasterLayer,
controlLayerConvertedToInpaintMask,
controlLayerConvertedToRegionalGuidance,
controlLayerModelChanged,
controlLayerControlModeChanged,
controlLayerWeightChanged,
@@ -1482,7 +1231,6 @@ export const {
// Regions
rgAdded,
// rgRecalled,
rgConvertedToInpaintMask,
rgPositivePromptChanged,
rgNegativePromptChanged,
rgAutoNegativeToggled,
@@ -1496,7 +1244,6 @@ export const {
rgIPAdapterCLIPVisionModelChanged,
// Inpaint mask
inpaintMaskAdded,
inpaintMaskConvertedToRegionalGuidance,
// inpaintMaskRecalled,
} = canvasSlice.actions;

@@ -131,8 +131,7 @@ const zSAMPoint = z.object({
y: z.number().int().gte(0),
label: zSAMPointLabel,
});
type SAMPoint = z.infer<typeof zSAMPoint>;
export type SAMPointWithId = SAMPoint & { id: string };
export type SAMPoint = z.infer<typeof zSAMPoint>;

const zRect = z.object({
x: z.number(),

@@ -184,153 +184,3 @@ export const getInpaintMaskState = (
merge(entityState, overrides);
return entityState;
};

const convertRasterLayerToControlLayer = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasControlLayerState>
): CanvasControlLayerState => {
const { name, objects, position } = rasterLayerState;
const controlLayerState = getControlLayerState(newId, {
name,
objects,
position,
});
merge(controlLayerState, overrides);
return controlLayerState;
};

const convertRasterLayerToInpaintMask = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = rasterLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};

const convertRasterLayerToRegionalGuidance = (
newId: string,
rasterLayerState: CanvasRasterLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = rasterLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};

const convertControlLayerToRasterLayer = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRasterLayerState>
): CanvasRasterLayerState => {
const { name, objects, position } = controlLayerState;
const rasterLayerState = getRasterLayerState(newId, {
name,
objects,
position,
});
merge(rasterLayerState, overrides);
return rasterLayerState;
};

const convertControlLayerToInpaintMask = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = controlLayerState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};

const convertControlLayerToRegionalGuidance = (
newId: string,
controlLayerState: CanvasControlLayerState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = controlLayerState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};

const convertInpaintMaskToRegionalGuidance = (
newId: string,
inpaintMaskState: CanvasInpaintMaskState,
overrides?: Partial<CanvasRegionalGuidanceState>
): CanvasRegionalGuidanceState => {
const { name, objects, position } = inpaintMaskState;
const regionalGuidanceState = getRegionalGuidanceState(newId, {
name,
objects,
position,
});
merge(regionalGuidanceState, overrides);
return regionalGuidanceState;
};

const convertRegionalGuidanceToInpaintMask = (
newId: string,
regionalGuidanceState: CanvasRegionalGuidanceState,
overrides?: Partial<CanvasInpaintMaskState>
): CanvasInpaintMaskState => {
const { name, objects, position } = regionalGuidanceState;
const inpaintMaskState = getInpaintMaskState(newId, {
name,
objects,
position,
});
merge(inpaintMaskState, overrides);
return inpaintMaskState;
};

/**
* Supported conversions:
* - Raster Layer -> Control Layer
* - Raster Layer -> Inpaint Mask
* - Raster Layer -> Regional Guidance
* - Control Layer -> Raster Layer
* - Control Layer -> Inpaint Mask
* - Control Layer -> Regional Guidance
* - Inpaint Mask -> Regional Guidance
* - Regional Guidance -> Inpaint Mask
*/
export const converters = {
rasterLayer: {
toControlLayer: convertRasterLayerToControlLayer,
toInpaintMask: convertRasterLayerToInpaintMask,
toRegionalGuidance: convertRasterLayerToRegionalGuidance,
},
controlLayer: {
toRasterLayer: convertControlLayerToRasterLayer,
toInpaintMask: convertControlLayerToInpaintMask,
toRegionalGuidance: convertControlLayerToRegionalGuidance,
},
inpaintMask: {
toRegionalGuidance: convertInpaintMaskToRegionalGuidance,
},
regionalGuidance: {
toInpaintMask: convertRegionalGuidanceToInpaintMask,
},
};

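The converters map above acts as a small registry: reducers resolve a source-type/target-type pair to one function instead of re-implementing the field mapping at each call site. A minimal usage sketch under that assumption (the reducer context, entityIdentifier, and the override value are illustrative):

// Inside a hypothetical reducer handling a 'raster_layer' entity:
const layer = selectEntity(state, entityIdentifier);
if (layer) {
  const mask = converters.rasterLayer.toInpaintMask(getPrefixedId('inpaint_mask'), layer, {
    name: 'Converted mask', // illustrative Partial<CanvasInpaintMaskState> override
  });
  state.inpaintMasks.entities.push(mask);
  state.selectedEntityIdentifier = { type: mask.type, id: mask.id };
}
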
@@ -1,4 +1,4 @@
import { Link } from '@invoke-ai/ui-library';
import { Flex, Link, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { memo } from 'react';
@@ -9,13 +9,15 @@ export const GalleryHeader = memo(() => {

if (projectName && projectUrl) {
return (
<Link fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" href={projectUrl}>
{projectName}
</Link>
<Flex gap={2} alignItems="center" justifyContent="space-evenly" pe={2} w="50%">
<Text fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" w="full" textAlign="center">
<Link href={projectUrl}>{projectName}</Link>
</Text>
</Flex>
);
}

return null;
return <Spacer />;
});

GalleryHeader.displayName = 'GalleryHeader';

@@ -51,8 +51,8 @@ const GalleryPanelContent = () => {

return (
<Flex ref={galleryPanelFocusRef} position="relative" flexDirection="column" h="full" w="full" tabIndex={-1}>
<Flex alignItems="center" justifyContent="space-between" w="full">
<Flex flexGrow={1} flexBasis={0}>
<Flex alignItems="center" w="full">
<Flex w="25%">
<Button
size="sm"
variant="ghost"
@@ -62,10 +62,8 @@ const GalleryPanelContent = () => {
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
</Button>
</Flex>
<Flex>
<GalleryHeader />
</Flex>
<Flex flexGrow={1} flexBasis={0} justifyContent="flex-end">
<GalleryHeader />
<Flex h="full" w="25%" justifyContent="flex-end">
<BoardsSettingsPopover />
<IconButton
size="sm"

@@ -1,11 +1,9 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { MenuItem } from '@invoke-ai/ui-library';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiArrowBendUpLeftBold,
PiArrowsCounterClockwiseBold,
PiAsteriskBold,
PiPaintBrushBold,
@@ -16,36 +14,28 @@ import {
export const ImageMenuItemMetadataRecallActions = memo(() => {
const { t } = useTranslation();
const imageDTO = useImageDTOContext();
const subMenu = useSubMenu();

const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, createAsPreset } =
useImageActions(imageDTO);

return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiArrowBendUpLeftBold />}>
<Menu {...subMenu.menuProps}>
<MenuButton {...subMenu.menuButtonProps}>
<SubMenuButtonContent label="Recall Metadata" />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={remix} isDisabled={!hasMetadata}>
{t('parameters.remixImage')}
</MenuItem>
<MenuItem icon={<PiQuotesBold />} onClick={recallPrompts} isDisabled={!hasPrompts}>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem icon={<PiPlantBold />} onClick={recallSeed} isDisabled={!hasSeed}>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem icon={<PiAsteriskBold />} onClick={recallAll} isDisabled={!hasMetadata}>
{t('parameters.useAll')}
</MenuItem>
<MenuItem icon={<PiPaintBrushBold />} onClick={createAsPreset} isDisabled={!hasPrompts}>
{t('stylePresets.useForTemplate')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
<>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClickCapture={remix} isDisabled={!hasMetadata}>
{t('parameters.remixImage')}
</MenuItem>
<MenuItem icon={<PiQuotesBold />} onClickCapture={recallPrompts} isDisabled={!hasPrompts}>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem icon={<PiPlantBold />} onClickCapture={recallSeed} isDisabled={!hasSeed}>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem icon={<PiAsteriskBold />} onClickCapture={recallAll} isDisabled={!hasMetadata}>
{t('parameters.useAll')}
</MenuItem>
<MenuItem icon={<PiPaintBrushBold />} onClickCapture={createAsPreset} isDisabled={!hasPrompts}>
{t('stylePresets.useForTemplate')}
</MenuItem>
</>
);
});

@@ -21,12 +21,9 @@ export const useBuildModelInstallArg = () => {
});

const getIsInstalled = useCallback(
({ source, name, base, type, is_installed, previous_names }: StarterModel): boolean =>
({ source, name, base, type, is_installed }: StarterModel): boolean =>
modelList.some(
(mc) =>
is_installed ||
source === mc.source ||
(base === mc.base && (name === mc.name || previous_names?.includes(name)) && type === mc.type)
(mc) => is_installed || source === mc.source || (base === mc.base && name === mc.name && type === mc.type)
),
[modelList]
);

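Spelled out, the simplified rule treats a starter model as installed when it is flagged is_installed, when an installed config shares its source, or when one matches on base, name, and type (the previous_names fallback is dropped). A self-contained sketch with illustrative types, not the app's exact imports; hoisting is_installed out of the some() callback, as done here, also reports true when the installed list is empty:

type ModelKey = { source: string; name: string; base: string; type: string };

const isStarterInstalled = (
  starter: ModelKey & { is_installed?: boolean },
  installed: ModelKey[]
): boolean =>
  Boolean(starter.is_installed) || // checked up front, so an empty list still reports true
  installed.some(
    (mc) =>
      starter.source === mc.source ||
      (starter.base === mc.base && starter.name === mc.name && starter.type === mc.type)
  );
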
@@ -1,4 +1,4 @@
import { Button, Flex, ListItem, Text, Tooltip, UnorderedList } from '@invoke-ai/ui-library';
import { Button, Flex, Text, Tooltip } from '@invoke-ai/ui-library';
import { flattenStarterModel, useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { isMainModelBase } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
@@ -44,15 +44,8 @@ export const StarterBundle = ({ bundleName, bundle }: { bundleName: string; bund
return (
<Tooltip
label={
<Flex flexDir="column" p={1}>
<Text>{t('modelManager.includesNModels', { n: bundle.length })}:</Text>
<UnorderedList>
{bundle.map((model, index) => (
<ListItem key={index} wordBreak="break-all">
{model.name}
</ListItem>
))}
</UnorderedList>
<Flex flexDir="column">
<Text>{t('modelManager.includesNModels', { n: bundle.length })}</Text>
</Flex>
}
>

@@ -1,4 +1,14 @@
import { Flex, Icon, IconButton, Input, InputGroup, InputRightElement, Text, Tooltip } from '@invoke-ai/ui-library';
import {
Box,
Flex,
Icon,
IconButton,
Input,
InputGroup,
InputRightElement,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { map, size } from 'lodash-es';
import type { ChangeEventHandler } from 'react';
@@ -49,14 +59,14 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
<Flex justifyContent="space-between" alignItems="center">
{size(results.starter_bundles) > 0 && (
<Flex gap={4} alignItems="center">
<Flex gap={2} alignItems="center">
<Flex gap={1} alignItems="center">
<Text color="base.200" fontWeight="semibold">
{t('modelManager.starterBundles')}
</Text>
<Tooltip label={t('modelManager.starterBundleHelpText')}>
<Flex alignItems="center">
<Box>
<Icon as={PiInfoBold} color="base.200" />
</Flex>
</Box>
</Tooltip>
</Flex>
<Flex gap={2}>

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',

@@ -34,6 +34,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -66,6 +68,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
@@ -168,10 +171,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

@@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

@@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
(value: CLIPEmbedModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

@@ -6,7 +6,7 @@ import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } f
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

@@ -19,7 +19,7 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
(value: VAEModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

@@ -0,0 +1,59 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;

const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});

return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className="nowheel nodrag"
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};

export default memo(SD3MainModelFieldInputComponent);
@@ -7,7 +7,11 @@ import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type {
MainModelConfig,
T5EncoderBnbQuantizedLlmInt8bModelConfig,
T5EncoderModelConfig,
} from 'services/api/types';

import type { FieldComponentProps } from './types';

@@ -20,7 +24,7 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}
@@ -40,14 +44,14 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
isLoading,
selectedModel: field.value,
});

const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!isModelsTabDisabled && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

@@ -5,7 +5,7 @@ import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

@@ -16,7 +16,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
(value: VAEModelConfig | MainModelConfig | null) => {
if (!value) {
return;
}

@@ -61,8 +61,8 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion

// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sdxl', 'flux']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner', 'flux']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -84,8 +84,10 @@ const zSubModelType = z.enum([
'transformer',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'vae',
'vae_decoder',
'vae_encoder',

@@ -32,6 +32,7 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@@ -65,6 +66,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',

@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -174,6 +178,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
@@ -467,6 +472,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SD3MainModelField

const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;

// #endregion

// #region FluxMainModelField

const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to Flux models only.
@@ -806,6 +834,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -837,6 +866,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -870,6 +900,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -904,6 +935,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

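For readers skimming the schema changes: each field type derives its static type and its runtime guard from a single zod schema, which is why adding SD3MainModelField is mostly mechanical repetition. A generic, self-contained sketch of the idiom (names are illustrative, not the app's):

import { z } from 'zod';

// One schema is the single source of truth for the field type...
const zExampleFieldType = z.object({ name: z.literal('SD3MainModelField') });
// ...yielding both a compile-time type and a runtime type guard.
type ExampleFieldType = z.infer<typeof zExampleFieldType>;
const isExampleFieldType = (val: unknown): val is ExampleFieldType =>
  zExampleFieldType.safeParse(val).success;
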
@@ -106,12 +106,10 @@ export const getInfill = (
}

if (infillMethod === 'color') {
const { a, ...rgb } = infillColorValue;
const color = { ...rgb, a: Math.round(a * 255) };
return g.addNode({
id: 'infill_rgba',
type: 'infill_rgba',
color,
color: infillColorValue,
});
}

@@ -16,6 +16,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

@@ -18,6 +18,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
@@ -198,6 +199,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};

const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -446,6 +461,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

@@ -6,7 +6,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { CLIPEmbedModelConfig, MainModelConfig } from 'services/api/types';

const ParamCLIPEmbedModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +15,7 @@ const ParamCLIPEmbedModelSelect = () => {
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

const _onChange = useCallback(
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
(clipEmbedModel: CLIPEmbedModelConfig | MainModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}

@@ -6,7 +6,11 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type {
MainModelConfig,
T5EncoderBnbQuantizedLlmInt8bModelConfig,
T5EncoderModelConfig,
} from 'services/api/types';

const ParamT5EncoderModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +19,7 @@ const ParamT5EncoderModelSelect = () => {
const [modelConfigs, { isLoading }] = useT5EncoderModels();

const _onChange = useCallback(
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | MainModelConfig | null) => {
if (t5EncoderModel) {
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
}

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';

const ParamFLUXVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamFLUXVAEModelSelect = () => {
const [modelConfigs, { isLoading }] = useFluxVAEModels();

const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
(vae: VAEModelConfig | MainModelConfig | null) => {
if (vae) {
dispatch(fluxVAESelected(zModelIdentifierField.parse(vae)));
}

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { MainModelConfig, VAEModelConfig } from 'services/api/types';

const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamVAEModelSelect = () => {
const vae = useAppSelector(selectVAE);
const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback(
(vae: VAEModelConfig): boolean => {
(vae: VAEModelConfig | MainModelConfig): boolean => {
const isCompatible = base === vae.base;
const hasMainModel = Boolean(base);
return !hasMainModel || !isCompatible;
@@ -24,7 +24,7 @@ const ParamVAEModelSelect = () => {
[base]
);
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
(vae: VAEModelConfig | MainModelConfig | null) => {
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
},
[dispatch]

@@ -19,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
@@ -40,6 +41,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sd-3': {
maxClip: 0,
markers: [],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],

@@ -20,6 +20,7 @@ import {
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
@@ -47,6 +48,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);

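Each hook here is the shared model-list query composed with one type guard, so useSD3Models only needs the new isSD3MainModelModelConfig guard. A simplified sketch of the filtering step (the real buildModelsHook also wires in the model-list request; this shows only the guard-driven narrowing):

const filterModels = <T extends AnyModelConfig>(
  configs: AnyModelConfig[],
  guard: (config: AnyModelConfig) => config is T
): T[] => configs.filter(guard);

// e.g. filterModels(allConfigs, isSD3MainModelModelConfig) returns MainModelConfig[]
// holding only 'main' configs whose base is 'sd-3'.
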
File diff suppressed because one or more lines are too long
@@ -75,20 +75,38 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig;

const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
for (const submodel in submodels) {
if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
return true;
}
}
return false;
};

const check_submodels = (identifier: string, config: AnyModelConfig): boolean => {
return (
(config.type === 'main' &&
config.submodels &&
(identifier in config.submodels || check_submodel_model_type(config.submodels, identifier))) ||
false
);
};

export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};

export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae';
export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return config.type === 'vae' || check_submodels('vae', config);
};

export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base !== 'flux';
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
};

export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base === 'flux';
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
};

export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -109,12 +127,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd

export const isT5EncoderModelConfig = (
config: AnyModelConfig
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
return config.type === 't5_encoder';
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
};

export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
return config.type === 'clip_embed';
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
return config.type === 'clip_embed' || check_submodels('clip_embed', config);
};

export const isSpandrelImageToImageModelConfig = (
@@ -145,6 +163,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};

export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};

export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};

|
||||
]
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==0.30.1",
|
||||
"accelerate==1.0.1",
|
||||
"bitsandbytes==0.43.3; sys_platform!='darwin'",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.27.2",
|
||||
"diffusers[torch]==0.31.0",
|
||||
"gguf==0.10.0",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
|
||||
@@ -61,7 +61,7 @@ dependencies = [
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.1",
|
||||
"fastapi==0.111.0",
|
||||
"huggingface-hub==0.23.1",
|
||||
"huggingface-hub==0.26.1",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.7.2",
|
||||
"python-socketio==5.11.1",
|
||||
@@ -89,7 +89,7 @@ dependencies = [
|
||||
"pypatchmatch",
|
||||
'pyperclip',
|
||||
"pyreadline3",
|
||||
"python-multipart==0.0.12",
|
||||
"python-multipart",
|
||||
"requests~=2.28.2",
|
||||
"rich~=13.3",
|
||||
"scikit-image~=0.21.0",
|
||||
|
||||