mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge remote-tracking branch 'upstream/main' into external-models
This commit is contained in:
27
invokeai/app/invocations/canvas.py
Normal file
27
invokeai/app/invocations/canvas.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_output",
|
||||
title="Canvas Output",
|
||||
tags=["canvas", "output", "image"],
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class CanvasOutputInvocation(BaseInvocation):
|
||||
"""Outputs an image to the canvas staging area.
|
||||
|
||||
Use this node in workflows intended for canvas workflow integration.
|
||||
Connect the final image of your workflow to this node to send it
|
||||
to the canvas staging area when run via 'Run Workflow on Canvas'."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -711,6 +711,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
||||
state_dict,
|
||||
{
|
||||
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
|
||||
"transformer.layers.", # OneTrainer/diffusers prefix variant
|
||||
"base_model.model.transformer.layers.", # PEFT-wrapped variant
|
||||
},
|
||||
)
|
||||
|
||||
@@ -747,6 +749,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
||||
state_dict,
|
||||
{
|
||||
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
|
||||
"transformer.layers.", # OneTrainer/diffusers prefix variant
|
||||
"base_model.model.transformer.layers.", # PEFT-wrapped variant
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _filename_suggests_base(name: str) -> bool:
|
||||
"""Check if a model name/filename suggests it is a Base (undistilled) variant.
|
||||
|
||||
Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
|
||||
from the state dict. We use the filename as a heuristic: filenames containing "base"
|
||||
(e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
|
||||
"""
|
||||
return "base" in name.lower()
|
||||
|
||||
|
||||
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
|
||||
"""Determine FLUX.2 variant from state dict.
|
||||
|
||||
@@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
|
||||
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
|
||||
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
|
||||
|
||||
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
|
||||
We default to Klein9B (distilled) for all 9B models since GGUF models may not
|
||||
include guidance embedding keys needed to distinguish them.
|
||||
Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
|
||||
and cannot be distinguished from the state dict alone. This function defaults to Klein9B
|
||||
for all 9B models. Callers should use filename heuristics to detect Klein9BBase.
|
||||
|
||||
Supports both BFL format (checkpoint) and diffusers format keys:
|
||||
- BFL format: txt_in.weight (context embedder)
|
||||
@@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
|
||||
context_in_dim = shape[1]
|
||||
# Determine variant based on context dimension
|
||||
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
|
||||
# Default to Klein9B (distilled) - the official/common 9B model
|
||||
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
|
||||
return Flux2VariantType.Klein9B
|
||||
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
|
||||
return Flux2VariantType.Klein4B
|
||||
@@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con
|
||||
if variant is None:
|
||||
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
|
||||
|
||||
# Klein 9B Base and Klein 9B have identical architectures.
|
||||
# Use filename heuristic to detect the Base (undistilled) variant.
|
||||
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
|
||||
return variant
|
||||
|
||||
@classmethod
|
||||
@@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba
|
||||
if variant is None:
|
||||
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
|
||||
|
||||
# Klein 9B Base and Klein 9B have identical architectures.
|
||||
# Use filename heuristic to detect the Base (undistilled) variant.
|
||||
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
|
||||
return variant
|
||||
|
||||
@classmethod
|
||||
@@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
|
||||
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
|
||||
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
|
||||
|
||||
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
|
||||
we check guidance_embeds:
|
||||
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
|
||||
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
|
||||
|
||||
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
|
||||
Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
|
||||
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
|
||||
"""
|
||||
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
|
||||
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
|
||||
@@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
|
||||
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
|
||||
|
||||
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
|
||||
guidance_embeds = transformer_config.get("guidance_embeds", False)
|
||||
|
||||
# Determine variant based on joint_attention_dim
|
||||
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
|
||||
# Check guidance_embeds to distinguish distilled from undistilled
|
||||
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
|
||||
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
|
||||
if guidance_embeds:
|
||||
if _filename_suggests_base(mod.name):
|
||||
return Flux2VariantType.Klein9BBase
|
||||
else:
|
||||
return Flux2VariantType.Klein9B
|
||||
return Flux2VariantType.Klein9B
|
||||
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
|
||||
return Flux2VariantType.Klein4B
|
||||
elif joint_attention_dim > 4096:
|
||||
|
||||
@@ -44,6 +44,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_bfl_format,
|
||||
lora_model_from_flux_onetrainer_bfl_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
lora_model_from_flux_onetrainer_state_dict,
|
||||
@@ -128,6 +132,8 @@ class LoRALoader(ModelLoader):
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
|
||||
@@ -81,7 +81,7 @@ t5_base_encoder = StarterModel(
|
||||
name="t5_base_encoder",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/t5-v1_1-xxl::bfloat16",
|
||||
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
|
||||
description="T5-XXL text encoder (used in FLUX pipelines). ~9.5GB",
|
||||
type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
@@ -166,7 +166,7 @@ flux_kontext_quantized = StarterModel(
|
||||
name="FLUX.1 Kontext dev (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
@@ -174,7 +174,7 @@ flux_krea = StarterModel(
|
||||
name="FLUX.1 Krea dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
|
||||
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
|
||||
description="FLUX.1 Krea dev. Total size with dependencies: ~29GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
@@ -182,7 +182,7 @@ flux_krea_quantized = StarterModel(
|
||||
name="FLUX.1 Krea dev (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
@@ -190,7 +190,7 @@ sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-medium",
|
||||
description="Medium SD3.5 Model: ~15GB",
|
||||
description="Medium SD3.5 Model: ~16GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
@@ -198,7 +198,7 @@ sd35_large = StarterModel(
|
||||
name="SD3.5 Large",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-large",
|
||||
description="Large SD3.5 Model: ~19G",
|
||||
description="Large SD3.5 Model: ~28GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
@@ -654,7 +654,7 @@ cogview4 = StarterModel(
|
||||
name="CogView4",
|
||||
base=BaseModelType.CogView4,
|
||||
source="THUDM/CogView4-6B",
|
||||
description="The base CogView4 model (~29GB).",
|
||||
description="The base CogView4 model (~31GB).",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
# endregion
|
||||
@@ -705,7 +705,7 @@ flux2_vae = StarterModel(
|
||||
name="FLUX.2 VAE",
|
||||
base=BaseModelType.Flux2,
|
||||
source="black-forest-labs/FLUX.2-klein-4B::vae",
|
||||
description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB",
|
||||
description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~168MB",
|
||||
type=ModelType.VAE,
|
||||
)
|
||||
|
||||
@@ -729,7 +729,7 @@ flux2_klein_4b = StarterModel(
|
||||
name="FLUX.2 Klein 4B (Diffusers)",
|
||||
base=BaseModelType.Flux2,
|
||||
source="black-forest-labs/FLUX.2-klein-4B",
|
||||
description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB",
|
||||
description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~16GB",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
|
||||
@@ -755,7 +755,7 @@ flux2_klein_9b = StarterModel(
|
||||
name="FLUX.2 Klein 9B (Diffusers)",
|
||||
base=BaseModelType.Flux2,
|
||||
source="black-forest-labs/FLUX.2-klein-9B",
|
||||
description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB",
|
||||
description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~35GB",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
|
||||
@@ -831,7 +831,7 @@ z_image_turbo = StarterModel(
|
||||
name="Z-Image Turbo",
|
||||
base=BaseModelType.ZImage,
|
||||
source="Tongyi-MAI/Z-Image-Turbo",
|
||||
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB",
|
||||
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~33GB",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
|
||||
|
||||
@@ -215,6 +215,7 @@ class FluxLoRAFormat(str, Enum):
|
||||
AIToolkit = "flux.aitoolkit"
|
||||
XLabs = "flux.xlabs"
|
||||
BflPeft = "flux.bfl_peft"
|
||||
OneTrainerBfl = "flux.onetrainer_bfl"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format.
|
||||
|
||||
This format is produced by newer versions of OneTrainer and uses BFL internal key names
|
||||
(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and
|
||||
InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha).
|
||||
|
||||
Unlike the standard BFL PEFT format (which uses 'diffusion_model.' prefix and lora_A/lora_B),
|
||||
this format also has split QKV projections:
|
||||
- double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate)
|
||||
- double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate)
|
||||
- single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate)
|
||||
|
||||
Example keys:
|
||||
transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight
|
||||
transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight
|
||||
transformer.double_blocks.0.img_attn.qkv.0.alpha
|
||||
transformer.single_blocks.0.linear1.3.lora_down.weight
|
||||
transformer.double_blocks.0.img_mlp.0.lora_down.weight
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
_TRANSFORMER_PREFIX = "transformer."
|
||||
|
||||
# Valid LoRA weight suffixes in this format.
|
||||
_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha")
|
||||
|
||||
# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1"
|
||||
_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$")
|
||||
|
||||
# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2"
|
||||
_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$")
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_onetrainer_bfl_format(
|
||||
state_dict: dict[str | int, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format.
|
||||
|
||||
This format uses BFL internal key names with 'transformer.' prefix and split QKV projections.
|
||||
"""
|
||||
str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
|
||||
if not str_keys:
|
||||
return False
|
||||
|
||||
# All keys must start with 'transformer.'
|
||||
if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys):
|
||||
return False
|
||||
|
||||
# All keys must end with recognized LoRA suffixes.
|
||||
if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys):
|
||||
return False
|
||||
|
||||
# Must have BFL block structure (double_blocks or single_blocks) under transformer prefix.
|
||||
has_bfl_blocks = any(
|
||||
k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys
|
||||
)
|
||||
if not has_bfl_blocks:
|
||||
return False
|
||||
|
||||
# Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats
|
||||
# that might use transformer. prefix in the future.
|
||||
has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys)
|
||||
if not has_split_qkv:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _split_key(key: str) -> tuple[str, str]:
|
||||
"""Split a key into (layer_name, weight_suffix).
|
||||
|
||||
Handles:
|
||||
- 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot
|
||||
- 1-component suffixes: e.g., 'alpha' → split at last dot
|
||||
"""
|
||||
if key.endswith(".weight"):
|
||||
parts = key.rsplit(".", maxsplit=2)
|
||||
return parts[0], f"{parts[1]}.{parts[2]}"
|
||||
else:
|
||||
parts = key.rsplit(".", maxsplit=1)
|
||||
return parts[0], parts[1]
|
||||
|
||||
|
||||
def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
"""Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw.
|
||||
|
||||
Strips the 'transformer.' prefix, groups by layer, and merges split QKV/linear1
|
||||
layers into MergedLayerPatch instances.
|
||||
"""
|
||||
# Step 1: Strip prefix and group by layer name.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
|
||||
# Strip 'transformer.' prefix.
|
||||
key = key[len(_TRANSFORMER_PREFIX) :]
|
||||
|
||||
layer_name, suffix = _split_key(key)
|
||||
|
||||
if layer_name not in grouped_state_dict:
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][suffix] = value
|
||||
|
||||
# Step 2: Build LoRA layers, merging split QKV and linear1.
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
|
||||
# Identify which layers need merging.
|
||||
merge_groups: dict[str, list[str]] = {}
|
||||
standalone_keys: list[str] = []
|
||||
|
||||
for layer_key in grouped_state_dict:
|
||||
qkv_match = _SPLIT_QKV_RE.match(layer_key)
|
||||
linear1_match = _SPLIT_LINEAR1_RE.match(layer_key)
|
||||
|
||||
if qkv_match:
|
||||
parent = qkv_match.group(1)
|
||||
if parent not in merge_groups:
|
||||
merge_groups[parent] = []
|
||||
merge_groups[parent].append(layer_key)
|
||||
elif linear1_match:
|
||||
parent = linear1_match.group(1)
|
||||
if parent not in merge_groups:
|
||||
merge_groups[parent] = []
|
||||
merge_groups[parent].append(layer_key)
|
||||
else:
|
||||
standalone_keys.append(layer_key)
|
||||
|
||||
# Process standalone layers.
|
||||
for layer_key in standalone_keys:
|
||||
layer_sd = grouped_state_dict[layer_key]
|
||||
layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"] = any_lora_layer_from_state_dict(layer_sd)
|
||||
|
||||
# Process merged layers.
|
||||
for parent_key, sub_keys in merge_groups.items():
|
||||
# Sort by the numeric index at the end (e.g., qkv.0, qkv.1, qkv.2).
|
||||
sub_keys.sort(key=lambda k: int(k.rsplit(".", maxsplit=1)[1]))
|
||||
|
||||
sub_layers: list[BaseLayerPatch] = []
|
||||
sub_ranges: list[Range] = []
|
||||
dim_0_offset = 0
|
||||
|
||||
for sub_key in sub_keys:
|
||||
layer_sd = grouped_state_dict[sub_key]
|
||||
sub_layer = any_lora_layer_from_state_dict(layer_sd)
|
||||
|
||||
# Determine the output dimension from the up weight shape.
|
||||
up_weight = layer_sd["lora_up.weight"]
|
||||
out_dim = up_weight.shape[0]
|
||||
|
||||
sub_layers.append(sub_layer)
|
||||
sub_ranges.append(Range(dim_0_offset, dim_0_offset + out_dim))
|
||||
dim_0_offset += out_dim
|
||||
|
||||
layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_key}"] = MergedLayerPatch(sub_layers, sub_ranges)
|
||||
|
||||
return ModelPatchRaw(layers=layers)
|
||||
@@ -14,6 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_bfl_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
)
|
||||
@@ -28,6 +31,8 @@ def flux_format_from_state_dict(
|
||||
) -> FluxLoRAFormat | None:
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict):
|
||||
return FluxLoRAFormat.Kohya
|
||||
elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata):
|
||||
return FluxLoRAFormat.OneTrainerBfl
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
|
||||
return FluxLoRAFormat.OneTrainer
|
||||
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
|
||||
|
||||
@@ -2415,6 +2415,27 @@
|
||||
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
|
||||
"addAdjustments": "Add Adjustments",
|
||||
"removeAdjustments": "Remove Adjustments",
|
||||
"workflowIntegration": {
|
||||
"title": "Run Workflow on Canvas",
|
||||
"description": "Select a workflow with a Canvas Output node and an image parameter to run on the current canvas layer. You can adjust parameters before executing. The result will be added back to the canvas.",
|
||||
"execute": "Execute Workflow",
|
||||
"executing": "Executing...",
|
||||
"runWorkflow": "Run Workflow",
|
||||
"filteringWorkflows": "Filtering workflows...",
|
||||
"loadingWorkflows": "Loading workflows...",
|
||||
"noWorkflowsFound": "No workflows found.",
|
||||
"noWorkflowsWithImageField": "No compatible workflows found. A workflow needs a Form Builder with an image input field and a Canvas Output node.",
|
||||
"selectWorkflow": "Select Workflow",
|
||||
"selectPlaceholder": "Choose a workflow...",
|
||||
"unnamedWorkflow": "Unnamed Workflow",
|
||||
"loadingParameters": "Loading workflow parameters...",
|
||||
"noFormBuilderError": "This workflow has no form builder and cannot be used. Please select a different workflow.",
|
||||
"imageFieldSelected": "This field will receive the canvas image",
|
||||
"imageFieldNotSelected": "Click to use this field for canvas image",
|
||||
"executionStarted": "Workflow execution started",
|
||||
"executionStartedDescription": "The result will appear in the staging area when complete.",
|
||||
"executionFailed": "Failed to execute workflow"
|
||||
},
|
||||
"compositeOperation": {
|
||||
"label": "Blend Mode",
|
||||
"add": "Add Blend Mode",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
|
||||
import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { CropImageModal } from 'features/cropper/components/CropImageModal';
|
||||
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
@@ -51,6 +52,7 @@ export const GlobalModalIsolator = memo(() => {
|
||||
<SaveWorkflowAsDialog />
|
||||
<CanvasManagerProviderGate>
|
||||
<CanvasPasteModal />
|
||||
<CanvasWorkflowIntegrationModal />
|
||||
</CanvasManagerProviderGate>
|
||||
<LoadWorkflowFromGraphModal />
|
||||
<CropImageModal />
|
||||
|
||||
@@ -16,6 +16,7 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
|
||||
|
||||
export const zLogNamespace = z.enum([
|
||||
'canvas',
|
||||
'canvas-workflow-integration',
|
||||
'config',
|
||||
'dnd',
|
||||
'events',
|
||||
|
||||
@@ -25,6 +25,7 @@ import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSe
|
||||
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { canvasTextSliceConfig } from 'features/controlLayers/store/canvasTextSlice';
|
||||
import { canvasWorkflowIntegrationSliceConfig } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -67,6 +68,7 @@ const SLICE_CONFIGS = {
|
||||
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
|
||||
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig,
|
||||
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
|
||||
[canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig,
|
||||
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
|
||||
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
|
||||
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
|
||||
@@ -98,6 +100,7 @@ const ALL_REDUCERS = {
|
||||
canvasSliceConfig.slice.reducer,
|
||||
canvasSliceConfig.undoableConfig?.reduxUndoOptions
|
||||
),
|
||||
[canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig.slice.reducer,
|
||||
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
|
||||
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
|
||||
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
Heading,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Spacer,
|
||||
Spinner,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
canvasWorkflowIntegrationClosed,
|
||||
selectCanvasWorkflowIntegrationIsOpen,
|
||||
selectCanvasWorkflowIntegrationIsProcessing,
|
||||
selectCanvasWorkflowIntegrationSelectedWorkflowId,
|
||||
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { CanvasWorkflowIntegrationParameterPanel } from './CanvasWorkflowIntegrationParameterPanel';
|
||||
import { CanvasWorkflowIntegrationWorkflowSelector } from './CanvasWorkflowIntegrationWorkflowSelector';
|
||||
import { useCanvasWorkflowIntegrationExecute } from './useCanvasWorkflowIntegrationExecute';
|
||||
|
||||
export const CanvasWorkflowIntegrationModal = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const isOpen = useAppSelector(selectCanvasWorkflowIntegrationIsOpen);
|
||||
const isProcessing = useAppSelector(selectCanvasWorkflowIntegrationIsProcessing);
|
||||
const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
|
||||
|
||||
const { execute, canExecute } = useCanvasWorkflowIntegrationExecute();
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
if (!isProcessing) {
|
||||
dispatch(canvasWorkflowIntegrationClosed());
|
||||
}
|
||||
}, [dispatch, isProcessing]);
|
||||
|
||||
const onExecute = useCallback(() => {
|
||||
execute();
|
||||
}, [execute]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} size="xl" isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>
|
||||
<Heading size="md">{t('controlLayers.workflowIntegration.title')}</Heading>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton isDisabled={isProcessing} />
|
||||
|
||||
<ModalBody>
|
||||
<Flex direction="column" gap={4}>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{t('controlLayers.workflowIntegration.description')}
|
||||
</Text>
|
||||
|
||||
<CanvasWorkflowIntegrationWorkflowSelector />
|
||||
|
||||
{selectedWorkflowId && <CanvasWorkflowIntegrationParameterPanel />}
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<ButtonGroup>
|
||||
<Button variant="ghost" onClick={onClose} isDisabled={isProcessing}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
onClick={onExecute}
|
||||
isDisabled={!canExecute || isProcessing}
|
||||
loadingText={t('controlLayers.workflowIntegration.executing')}
|
||||
>
|
||||
{isProcessing && <Spinner size="sm" mr={2} />}
|
||||
{t('controlLayers.workflowIntegration.execute')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasWorkflowIntegrationModal.displayName = 'CanvasWorkflowIntegrationModal';
|
||||
@@ -0,0 +1,13 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { WorkflowFormPreview } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const CanvasWorkflowIntegrationParameterPanel = memo(() => {
|
||||
return (
|
||||
<Box w="full">
|
||||
<WorkflowFormPreview />
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasWorkflowIntegrationParameterPanel.displayName = 'CanvasWorkflowIntegrationParameterPanel';
|
||||
@@ -0,0 +1,92 @@
|
||||
import { Flex, FormControl, FormLabel, Select, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
canvasWorkflowIntegrationWorkflowSelected,
|
||||
selectCanvasWorkflowIntegrationSelectedWorkflowId,
|
||||
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
import { useFilteredWorkflows } from './useFilteredWorkflows';
|
||||
|
||||
export const CanvasWorkflowIntegrationWorkflowSelector = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
|
||||
const { data: workflowsData, isLoading } = useListWorkflowsInfiniteInfiniteQuery(
|
||||
{
|
||||
per_page: 100, // Get a reasonable number of workflows
|
||||
page: 0,
|
||||
},
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
data,
|
||||
isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const workflows = useMemo(() => {
|
||||
if (!workflowsData) {
|
||||
return [];
|
||||
}
|
||||
// Flatten all pages into a single list
|
||||
return workflowsData.pages.flatMap((page) => page.items);
|
||||
}, [workflowsData]);
|
||||
|
||||
// Filter workflows to only show those with ImageFields
|
||||
const { filteredWorkflows, isFiltering } = useFilteredWorkflows(workflows);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const workflowId = e.target.value || null;
|
||||
dispatch(canvasWorkflowIntegrationWorkflowSelected({ workflowId }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
if (isLoading || isFiltering) {
|
||||
return (
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Spinner size="sm" />
|
||||
<Text>
|
||||
{isFiltering
|
||||
? t('controlLayers.workflowIntegration.filteringWorkflows')
|
||||
: t('controlLayers.workflowIntegration.loadingWorkflows')}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (filteredWorkflows.length === 0) {
|
||||
return (
|
||||
<Text color="warning.400" fontSize="sm">
|
||||
{workflows.length === 0
|
||||
? t('controlLayers.workflowIntegration.noWorkflowsFound')
|
||||
: t('controlLayers.workflowIntegration.noWorkflowsWithImageField')}
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('controlLayers.workflowIntegration.selectWorkflow')}</FormLabel>
|
||||
<Select
|
||||
placeholder={t('controlLayers.workflowIntegration.selectPlaceholder')}
|
||||
value={selectedWorkflowId || ''}
|
||||
onChange={onChange}
|
||||
>
|
||||
{filteredWorkflows.map((workflow) => (
|
||||
<option key={workflow.workflow_id} value={workflow.workflow_id}>
|
||||
{workflow.name || t('controlLayers.workflowIntegration.unnamedWorkflow')}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasWorkflowIntegrationWorkflowSelector.displayName = 'CanvasWorkflowIntegrationWorkflowSelector';
|
||||
@@ -0,0 +1,548 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Combobox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
IconButton,
|
||||
Input,
|
||||
Radio,
|
||||
Select,
|
||||
Switch,
|
||||
Text,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
|
||||
import {
|
||||
canvasWorkflowIntegrationFieldValueChanged,
|
||||
canvasWorkflowIntegrationImageFieldSelected,
|
||||
selectCanvasWorkflowIntegrationFieldValues,
|
||||
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
|
||||
selectCanvasWorkflowIntegrationSelectedWorkflowId,
|
||||
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
import type { AnyModelConfig, ImageDTO } from 'services/api/types';
|
||||
|
||||
const log = logger('canvas-workflow-integration');
|
||||
|
||||
interface WorkflowFieldRendererProps {
|
||||
el: NodeFieldElement;
|
||||
}
|
||||
|
||||
export const WorkflowFieldRenderer = memo(({ el }: WorkflowFieldRendererProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
|
||||
const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
|
||||
const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
|
||||
const templates = useStore($templates);
|
||||
|
||||
const { data: workflow } = useGetWorkflowQuery(selectedWorkflowId!, {
|
||||
skip: !selectedWorkflowId,
|
||||
});
|
||||
|
||||
// Load boards and models for BoardField and ModelIdentifierField
|
||||
const { data: boardsData } = useListAllBoardsQuery({ include_archived: true });
|
||||
const { data: modelsData, isLoading: isLoadingModels } = useGetModelConfigsQuery();
|
||||
|
||||
const { fieldIdentifier } = el.data;
|
||||
const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`;
|
||||
|
||||
log.debug({ fieldIdentifier, fieldKey }, 'Rendering workflow field');
|
||||
|
||||
// Get the node, field instance, and field template
|
||||
const { field, fieldTemplate } = useMemo(() => {
|
||||
if (!workflow?.workflow.nodes) {
|
||||
log.warn('No workflow nodes found');
|
||||
return { field: null, fieldTemplate: null };
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const foundNode = workflow.workflow.nodes.find((n: any) => n.data.id === fieldIdentifier.nodeId);
|
||||
if (!foundNode) {
|
||||
log.warn({ nodeId: fieldIdentifier.nodeId }, 'Node not found');
|
||||
return { field: null, fieldTemplate: null };
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const foundField = (foundNode.data as any).inputs[fieldIdentifier.fieldName];
|
||||
if (!foundField) {
|
||||
log.warn({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName }, 'Field not found in node');
|
||||
return { field: null, fieldTemplate: null };
|
||||
}
|
||||
|
||||
// Get the field template from the invocation templates
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const nodeType = (foundNode.data as any).type;
|
||||
const template = templates[nodeType];
|
||||
if (!template) {
|
||||
log.warn({ nodeType }, 'No template found for node type');
|
||||
return { field: foundField, fieldTemplate: null };
|
||||
}
|
||||
|
||||
const foundFieldTemplate = template.inputs[fieldIdentifier.fieldName];
|
||||
if (!foundFieldTemplate) {
|
||||
log.warn({ nodeType, fieldName: fieldIdentifier.fieldName }, 'Field template not found');
|
||||
return { field: foundField, fieldTemplate: null };
|
||||
}
|
||||
|
||||
return { field: foundField, fieldTemplate: foundFieldTemplate };
|
||||
}, [workflow, fieldIdentifier, templates]);
|
||||
|
||||
// Get the current value from Redux or fallback to field default
|
||||
const currentValue = useMemo(() => {
|
||||
if (fieldValues && fieldKey in fieldValues) {
|
||||
return fieldValues[fieldKey];
|
||||
}
|
||||
|
||||
return field?.value ?? fieldTemplate?.default ?? '';
|
||||
}, [fieldValues, fieldKey, field, fieldTemplate]);
|
||||
|
||||
// Get field type from the template
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const fieldType = fieldTemplate ? (fieldTemplate as any).type?.name : null;
|
||||
|
||||
const handleChange = useCallback(
|
||||
(value: unknown) => {
|
||||
dispatch(canvasWorkflowIntegrationFieldValueChanged({ fieldName: fieldKey, value }));
|
||||
},
|
||||
[dispatch, fieldKey]
|
||||
);
|
||||
|
||||
const handleStringChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
handleChange(e.target.value);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const handleNumberChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const val = fieldType === 'IntegerField' ? parseInt(e.target.value, 10) : parseFloat(e.target.value);
|
||||
handleChange(isNaN(val) ? 0 : val);
|
||||
},
|
||||
[handleChange, fieldType]
|
||||
);
|
||||
|
||||
const handleBooleanChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
handleChange(e.target.checked);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const handleSelectChange = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
handleChange(e.target.value);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
// SchedulerField handlers
|
||||
const handleSchedulerChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
handleChange(v.value);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const schedulerValue = useMemo(() => SCHEDULER_OPTIONS.find((o) => o.value === currentValue), [currentValue]);
|
||||
|
||||
// BoardField handlers
|
||||
const handleBoardChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
const value = v.value === 'auto' || v.value === 'none' ? v.value : { board_id: v.value };
|
||||
handleChange(value);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const boardOptions = useMemo<ComboboxOption[]>(() => {
|
||||
const _options: ComboboxOption[] = [
|
||||
{ label: t('common.auto'), value: 'auto' },
|
||||
{ label: `${t('common.none')} (${t('boards.uncategorized')})`, value: 'none' },
|
||||
];
|
||||
if (boardsData) {
|
||||
for (const board of boardsData) {
|
||||
_options.push({
|
||||
label: board.board_name,
|
||||
value: board.board_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
return _options;
|
||||
}, [boardsData, t]);
|
||||
|
||||
const boardValue = useMemo(() => {
|
||||
const _value = currentValue;
|
||||
const autoOption = boardOptions[0];
|
||||
const noneOption = boardOptions[1];
|
||||
if (!_value || _value === 'auto') {
|
||||
return autoOption;
|
||||
}
|
||||
if (_value === 'none') {
|
||||
return noneOption;
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const boardId = typeof _value === 'object' ? (_value as any).board_id : _value;
|
||||
const boardOption = boardOptions.find((o) => o.value === boardId);
|
||||
return boardOption ?? autoOption;
|
||||
}, [currentValue, boardOptions]);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);
|
||||
|
||||
// ModelIdentifierField handlers
|
||||
const handleModelChange = useCallback(
|
||||
(value: AnyModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
handleChange(value);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!modelsData) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const ui_model_base = fieldTemplate ? (fieldTemplate as any)?.ui_model_base : null;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const ui_model_type = fieldTemplate ? (fieldTemplate as any)?.ui_model_type : null;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const ui_model_variant = fieldTemplate ? (fieldTemplate as any)?.ui_model_variant : null;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const ui_model_format = fieldTemplate ? (fieldTemplate as any)?.ui_model_format : null;
|
||||
|
||||
if (!ui_model_base && !ui_model_type) {
|
||||
return modelConfigsAdapterSelectors.selectAll(modelsData);
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(modelsData).filter((config) => {
|
||||
if (ui_model_base && !ui_model_base.includes(config.base)) {
|
||||
return false;
|
||||
}
|
||||
if (ui_model_type && !ui_model_type.includes(config.type)) {
|
||||
return false;
|
||||
}
|
||||
if (ui_model_variant && 'variant' in config && config.variant && !ui_model_variant.includes(config.variant)) {
|
||||
return false;
|
||||
}
|
||||
if (ui_model_format && !ui_model_format.includes(config.format)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}, [modelsData, fieldTemplate]);
|
||||
|
||||
// ImageField handler
|
||||
const handleImageFieldSelect = useCallback(() => {
|
||||
dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey }));
|
||||
}, [dispatch, fieldKey]);
|
||||
|
||||
if (!field || !fieldTemplate) {
|
||||
log.warn({ fieldIdentifier }, 'Field or template is null - not rendering');
|
||||
return null;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const label = (field as any)?.label || (fieldTemplate as any)?.title || fieldIdentifier.fieldName;
|
||||
|
||||
// Log the entire field structure to understand its shape
|
||||
log.debug(
|
||||
{ fieldType, label, currentValue, fieldStructure: field, fieldTemplateStructure: fieldTemplate },
|
||||
'Field info'
|
||||
);
|
||||
|
||||
// ImageField - allow user to select which one receives the canvas image
|
||||
if (fieldType === 'ImageField') {
|
||||
return (
|
||||
<ImageFieldComponent
|
||||
label={label}
|
||||
fieldKey={fieldKey}
|
||||
currentValue={currentValue}
|
||||
selectedImageFieldKey={selectedImageFieldKey}
|
||||
fieldTemplate={fieldTemplate}
|
||||
handleImageFieldSelect={handleImageFieldSelect}
|
||||
handleChange={handleChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Render different input types based on field type
|
||||
if (fieldType === 'StringField') {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const isTextarea = (fieldTemplate as any)?.ui_component === 'textarea';
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const isRequired = (fieldTemplate as any)?.required ?? false;
|
||||
|
||||
if (isTextarea) {
|
||||
return (
|
||||
<FormControl isRequired={isRequired}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Textarea
|
||||
value={String(currentValue)}
|
||||
onChange={handleStringChange}
|
||||
placeholder={label}
|
||||
rows={3}
|
||||
isRequired={isRequired}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl isRequired={isRequired}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Input value={String(currentValue)} onChange={handleStringChange} placeholder={label} isRequired={isRequired} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'IntegerField' || fieldType === 'FloatField') {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const min = (fieldTemplate as any)?.minimum;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const max = (fieldTemplate as any)?.maximum;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const step = fieldType === 'IntegerField' ? 1 : ((fieldTemplate as any)?.multipleOf ?? 0.01);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Input
|
||||
type="number"
|
||||
value={Number(currentValue)}
|
||||
onChange={handleNumberChange}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'BooleanField') {
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Switch isChecked={Boolean(currentValue)} onChange={handleBooleanChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'EnumField') {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const options = (fieldTemplate as any)?.options ?? (fieldTemplate as any)?.ui_choice_labels ?? [];
|
||||
const optionsList = Array.isArray(options) ? options : Object.keys(options);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Select value={String(currentValue)} onChange={handleSelectChange}>
|
||||
{optionsList.map((option: string) => (
|
||||
<option key={option} value={option}>
|
||||
{option}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'SchedulerField') {
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Combobox value={schedulerValue} options={SCHEDULER_OPTIONS} onChange={handleSchedulerChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'BoardField') {
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Combobox
|
||||
value={boardValue}
|
||||
options={boardOptions}
|
||||
onChange={handleBoardChange}
|
||||
placeholder={t('boards.selectBoard')}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldType === 'ModelIdentifierField') {
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<ModelFieldCombobox
|
||||
value={currentValue}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoadingModels}
|
||||
onChange={handleModelChange}
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
required={(fieldTemplate as any).required}
|
||||
groupByType
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
|
||||
// For other field types, show a read-only message
|
||||
log.warn(`Unsupported field type "${fieldType}" for field "${label}" - showing as read-only`);
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
<Input value={`${fieldType} (read-only)`} isReadOnly />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
WorkflowFieldRenderer.displayName = 'WorkflowFieldRenderer';
|
||||
|
||||
// Separate component for ImageField to avoid conditional hooks
|
||||
interface ImageFieldComponentProps {
|
||||
label: string;
|
||||
fieldKey: string;
|
||||
currentValue: unknown;
|
||||
selectedImageFieldKey: string | null;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
fieldTemplate: any;
|
||||
handleImageFieldSelect: () => void;
|
||||
handleChange: (value: unknown) => void;
|
||||
}
|
||||
|
||||
const ImageFieldComponent = memo(
|
||||
({
|
||||
label,
|
||||
fieldKey,
|
||||
currentValue,
|
||||
selectedImageFieldKey,
|
||||
fieldTemplate,
|
||||
handleImageFieldSelect,
|
||||
handleChange,
|
||||
}: ImageFieldComponentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isSelected = selectedImageFieldKey === fieldKey;
|
||||
|
||||
// Get image from field values (uploaded image) or from workflow field (default/saved image)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const imageValue = currentValue as any;
|
||||
const imageName = imageValue?.image_name;
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
|
||||
|
||||
const handleImageUpload = useCallback(
|
||||
(uploadedImage: ImageDTO) => {
|
||||
handleChange(uploadedImage);
|
||||
},
|
||||
[handleChange]
|
||||
);
|
||||
|
||||
const handleImageClear = useCallback(() => {
|
||||
handleChange(undefined);
|
||||
}, [handleChange]);
|
||||
|
||||
return (
|
||||
<FormControl overflow="hidden">
|
||||
<Flex alignItems="center" gap={2} mb={2}>
|
||||
<Radio isChecked={isSelected} onChange={handleImageFieldSelect} />
|
||||
<FormLabel mb={0} cursor="pointer" onClick={handleImageFieldSelect}>
|
||||
{label}
|
||||
</FormLabel>
|
||||
</Flex>
|
||||
<Text fontSize="sm" color="base.400" ml={6} mb={2}>
|
||||
{isSelected
|
||||
? t('controlLayers.workflowIntegration.imageFieldSelected')
|
||||
: t('controlLayers.workflowIntegration.imageFieldNotSelected')}
|
||||
</Text>
|
||||
|
||||
{/* Show image upload/preview for non-selected fields */}
|
||||
{!isSelected && (
|
||||
<Flex ml={6} position="relative" h={32} alignItems="stretch" maxW="calc(100% - 1.5rem)">
|
||||
{!imageDTO && (
|
||||
<UploadImageIconButton
|
||||
w="full"
|
||||
h="auto"
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
isError={(fieldTemplate as any)?.required && !imageValue}
|
||||
onUpload={handleImageUpload}
|
||||
fontSize={24}
|
||||
/>
|
||||
)}
|
||||
{imageDTO && (
|
||||
<Flex gap={2} alignItems="center" maxW="full">
|
||||
<Flex
|
||||
borderRadius="base"
|
||||
borderWidth={1}
|
||||
borderStyle="solid"
|
||||
overflow="hidden"
|
||||
position="relative"
|
||||
h={32}
|
||||
maxH={32}
|
||||
>
|
||||
<DndImage imageDTO={imageDTO} asThumbnail />
|
||||
<Text
|
||||
position="absolute"
|
||||
background="base.900"
|
||||
color="base.50"
|
||||
fontSize="sm"
|
||||
fontWeight="semibold"
|
||||
insetInlineEnd={1}
|
||||
insetBlockEnd={1}
|
||||
opacity={0.7}
|
||||
px={2}
|
||||
borderRadius="base"
|
||||
pointerEvents="none"
|
||||
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
|
||||
</Flex>
|
||||
<IconButton
|
||||
aria-label={t('common.clearImage', 'Clear image')}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={handleImageClear}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
colorScheme="error"
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
ImageFieldComponent.displayName = 'ImageFieldComponent';
|
||||
@@ -0,0 +1,289 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowFieldRenderer } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer';
|
||||
import {
|
||||
canvasWorkflowIntegrationImageFieldSelected,
|
||||
selectCanvasWorkflowIntegrationFieldValues,
|
||||
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
|
||||
selectCanvasWorkflowIntegrationSelectedWorkflowId,
|
||||
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import {
|
||||
ContainerContextProvider,
|
||||
DepthContextProvider,
|
||||
useContainerContext,
|
||||
useDepthContext,
|
||||
} from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { DividerElement } from 'features/nodes/components/sidePanel/builder/DividerElement';
|
||||
import { HeadingElement } from 'features/nodes/components/sidePanel/builder/HeadingElement';
|
||||
import { TextElement } from 'features/nodes/components/sidePanel/builder/TextElement';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { FormElement } from 'features/nodes/types/workflow';
|
||||
import {
|
||||
CONTAINER_CLASS_NAME,
|
||||
isContainerElement,
|
||||
isDividerElement,
|
||||
isHeadingElement,
|
||||
isNodeFieldElement,
|
||||
isTextElement,
|
||||
ROOT_CONTAINER_CLASS_NAME,
|
||||
} from 'features/nodes/types/workflow';
|
||||
import { memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
const log = logger('canvas-workflow-integration');
|
||||
|
||||
const rootViewModeSx: SystemStyleObject = {
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
gap: 2,
|
||||
display: 'flex',
|
||||
flex: 1,
|
||||
maxW: '768px',
|
||||
'&[data-self-layout="column"]': {
|
||||
flexDir: 'column',
|
||||
alignItems: 'stretch',
|
||||
},
|
||||
'&[data-self-layout="row"]': {
|
||||
flexDir: 'row',
|
||||
alignItems: 'flex-start',
|
||||
},
|
||||
};
|
||||
|
||||
const containerViewModeSx: SystemStyleObject = {
|
||||
gap: 2,
|
||||
'&[data-self-layout="column"]': {
|
||||
flexDir: 'column',
|
||||
alignItems: 'stretch',
|
||||
},
|
||||
'&[data-self-layout="row"]': {
|
||||
flexDir: 'row',
|
||||
alignItems: 'flex-start',
|
||||
overflowX: 'auto',
|
||||
overflowY: 'visible',
|
||||
h: 'min-content',
|
||||
flexShrink: 0,
|
||||
},
|
||||
'&[data-parent-layout="column"]': {
|
||||
w: 'full',
|
||||
h: 'min-content',
|
||||
},
|
||||
'&[data-parent-layout="row"]': {
|
||||
flex: '1 1 0',
|
||||
minW: 32,
|
||||
},
|
||||
};
|
||||
|
||||
export const WorkflowFormPreview = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
|
||||
const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
|
||||
const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
|
||||
const templates = useStore($templates);
|
||||
|
||||
const { data: workflow, isLoading } = useGetWorkflowQuery(selectedWorkflowId!, {
|
||||
skip: !selectedWorkflowId,
|
||||
});
|
||||
|
||||
const elements = useMemo((): Record<string, FormElement> => {
|
||||
if (!workflow?.workflow.form) {
|
||||
return {};
|
||||
}
|
||||
const els = workflow.workflow.form.elements as Record<string, FormElement>;
|
||||
log.debug({ elementCount: Object.keys(els).length, elementIds: Object.keys(els) }, 'Form elements loaded');
|
||||
return els;
|
||||
}, [workflow]);
|
||||
|
||||
const rootElementId = useMemo((): string => {
|
||||
if (!workflow?.workflow.form) {
|
||||
return '';
|
||||
}
|
||||
const rootId = workflow.workflow.form.rootElementId as string;
|
||||
log.debug({ rootElementId: rootId }, 'Root element ID');
|
||||
return rootId;
|
||||
}, [workflow]);
|
||||
|
||||
// Auto-select the image field if there's only one unfilled ImageField
|
||||
useEffect(() => {
|
||||
// Don't auto-select if user already selected one
|
||||
if (selectedImageFieldKey) {
|
||||
return;
|
||||
}
|
||||
if (!workflow?.workflow.nodes || Object.keys(elements).length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const unfilledImageFieldKeys: string[] = [];
|
||||
|
||||
for (const element of Object.values(elements)) {
|
||||
if (!isNodeFieldElement(element)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const { fieldIdentifier } = element.data;
|
||||
const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const node = workflow.workflow.nodes.find((n: any) => n.data?.id === fieldIdentifier.nodeId);
|
||||
if (!node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const nodeType = (node.data as any)?.type;
|
||||
const template = templates[nodeType];
|
||||
if (!template?.inputs) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fieldTemplate = template.inputs[fieldIdentifier.fieldName];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const fieldType = (fieldTemplate as any)?.type?.name;
|
||||
if (fieldType !== 'ImageField') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if the field already has a value
|
||||
const hasReduxValue = fieldValues && fieldKey in fieldValues && fieldValues[fieldKey]?.image_name;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const fieldInstance = (node.data as any)?.inputs?.[fieldIdentifier.fieldName];
|
||||
const hasWorkflowValue = fieldInstance?.value?.image_name;
|
||||
|
||||
if (!hasReduxValue && !hasWorkflowValue) {
|
||||
unfilledImageFieldKeys.push(fieldKey);
|
||||
}
|
||||
}
|
||||
|
||||
if (unfilledImageFieldKeys.length === 1) {
|
||||
log.debug({ fieldKey: unfilledImageFieldKeys[0] }, 'Auto-selecting the only unfilled ImageField');
|
||||
dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey: unfilledImageFieldKeys[0]! }));
|
||||
}
|
||||
}, [workflow, elements, templates, selectedImageFieldKey, fieldValues, dispatch]);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Spinner size="sm" />
|
||||
<Text>{t('controlLayers.workflowIntegration.loadingParameters')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (!workflow) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// If workflow has no form builder, it should have been filtered out
|
||||
// This is a fallback in case something went wrong
|
||||
if (Object.keys(elements).length === 0 || !rootElementId) {
|
||||
return (
|
||||
<Text fontSize="sm" color="error.400">
|
||||
{t('controlLayers.workflowIntegration.noFormBuilderError')}
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
const rootElement = elements[rootElementId];
|
||||
|
||||
if (!rootElement || !isContainerElement(rootElement)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { id, data } = rootElement;
|
||||
const { children, layout } = data;
|
||||
|
||||
return (
|
||||
<DepthContextProvider depth={0}>
|
||||
<ContainerContextProvider id={id} layout={layout}>
|
||||
<Box id={id} className={ROOT_CONTAINER_CLASS_NAME} sx={rootViewModeSx} data-self-layout={layout} data-depth={0}>
|
||||
{children.map((childId) => (
|
||||
<FormElementComponentPreview key={childId} id={childId} elements={elements} />
|
||||
))}
|
||||
</Box>
|
||||
</ContainerContextProvider>
|
||||
</DepthContextProvider>
|
||||
);
|
||||
});
|
||||
WorkflowFormPreview.displayName = 'WorkflowFormPreview';
|
||||
|
||||
const FormElementComponentPreview = memo(({ id, elements }: { id: string; elements: Record<string, FormElement> }) => {
|
||||
const el = elements[id];
|
||||
|
||||
if (!el) {
|
||||
log.warn({ id }, 'Element not found in elements map');
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug({ id, type: el.type }, 'Rendering form element');
|
||||
|
||||
if (isContainerElement(el)) {
|
||||
return <ContainerElementPreview el={el} elements={elements} />;
|
||||
}
|
||||
|
||||
if (isDividerElement(el)) {
|
||||
return <DividerElement id={id} />;
|
||||
}
|
||||
|
||||
if (isHeadingElement(el)) {
|
||||
return <HeadingElement id={id} />;
|
||||
}
|
||||
|
||||
if (isTextElement(el)) {
|
||||
return <TextElement id={id} />;
|
||||
}
|
||||
|
||||
if (isNodeFieldElement(el)) {
|
||||
return <WorkflowFieldRenderer el={el} />;
|
||||
}
|
||||
|
||||
// If we get here, it's an unknown element type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
log.warn({ id, type: (el as any).type }, 'Unknown element type - not rendering');
|
||||
return null;
|
||||
});
|
||||
FormElementComponentPreview.displayName = 'FormElementComponentPreview';
|
||||
|
||||
const ContainerElementPreview = memo(({ el, elements }: { el: FormElement; elements: Record<string, FormElement> }) => {
|
||||
const { t } = useTranslation();
|
||||
const depth = useDepthContext();
|
||||
const containerCtx = useContainerContext();
|
||||
|
||||
if (!isContainerElement(el)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { id, data } = el;
|
||||
const { children, layout } = data;
|
||||
|
||||
return (
|
||||
<DepthContextProvider depth={depth + 1}>
|
||||
<ContainerContextProvider id={id} layout={layout}>
|
||||
<Flex
|
||||
id={id}
|
||||
className={CONTAINER_CLASS_NAME}
|
||||
sx={containerViewModeSx}
|
||||
data-self-layout={layout}
|
||||
data-depth={depth}
|
||||
data-parent-layout={containerCtx.layout}
|
||||
>
|
||||
{children.map((childId) => (
|
||||
<FormElementComponentPreview key={childId} id={childId} elements={elements} />
|
||||
))}
|
||||
{children.length === 0 && (
|
||||
<Flex p={8} w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Text color="base.500" fontSize="sm" fontStyle="oblique 10deg">
|
||||
{t('workflows.builder.emptyContainer')}
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</ContainerContextProvider>
|
||||
</DepthContextProvider>
|
||||
);
|
||||
});
|
||||
ContainerElementPreview.displayName = 'ContainerElementPreview';
|
||||
@@ -0,0 +1,302 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
canvasWorkflowIntegrationClosed,
|
||||
canvasWorkflowIntegrationProcessingCompleted,
|
||||
canvasWorkflowIntegrationProcessingStarted,
|
||||
selectCanvasWorkflowIntegrationFieldValues,
|
||||
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
|
||||
selectCanvasWorkflowIntegrationSelectedWorkflowId,
|
||||
selectCanvasWorkflowIntegrationSourceEntityIdentifier,
|
||||
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { CANVAS_OUTPUT_PREFIX, getPrefixedId } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { useLazyGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
const log = logger('canvas-workflow-integration');
|
||||
|
||||
export const useCanvasWorkflowIntegrationExecute = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const canvasManager = useCanvasManager();
|
||||
|
||||
const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
|
||||
const sourceEntityIdentifier = useAppSelector(selectCanvasWorkflowIntegrationSourceEntityIdentifier);
|
||||
const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
|
||||
const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
|
||||
const canvasSessionId = useAppSelector(selectCanvasSessionId);
|
||||
|
||||
const [getWorkflow] = useLazyGetWorkflowQuery();
|
||||
|
||||
const canExecute = useMemo(() => {
|
||||
return Boolean(selectedWorkflowId && sourceEntityIdentifier);
|
||||
}, [selectedWorkflowId, sourceEntityIdentifier]);
|
||||
|
||||
const execute = useCallback(async () => {
|
||||
if (!selectedWorkflowId || !sourceEntityIdentifier || !canvasManager) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
dispatch(canvasWorkflowIntegrationProcessingStarted());
|
||||
|
||||
// 1. Extract the canvas layer as an image
|
||||
const adapter = canvasManager.getAdapter(sourceEntityIdentifier);
|
||||
if (!adapter) {
|
||||
throw new Error('Could not find canvas entity adapter');
|
||||
}
|
||||
|
||||
const rect = adapter.transformer.getRelativeRect();
|
||||
const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } });
|
||||
|
||||
// 2. Fetch the workflow
|
||||
const { data: workflow } = await getWorkflow(selectedWorkflowId);
|
||||
if (!workflow) {
|
||||
throw new Error('Failed to load workflow');
|
||||
}
|
||||
|
||||
// 3. Build the workflow graph with the canvas image
|
||||
// Use the user-selected ImageField, or find one automatically
|
||||
let imageFieldIdentifier: { nodeId: string; fieldName: string } | undefined;
|
||||
|
||||
// Method 1: Use user-selected ImageField (preferred)
|
||||
if (selectedImageFieldKey) {
|
||||
const [nodeId, fieldName] = selectedImageFieldKey.split('.');
|
||||
if (nodeId && fieldName) {
|
||||
imageFieldIdentifier = { nodeId, fieldName };
|
||||
}
|
||||
}
|
||||
|
||||
// Method 2: Search through form elements for an ImageField (fallback)
|
||||
if (!imageFieldIdentifier && workflow.workflow.form && workflow.workflow.form.elements) {
|
||||
for (const element of Object.values(workflow.workflow.form.elements)) {
|
||||
if (element.type !== 'node-field') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fieldIdentifier = element.data?.fieldIdentifier;
|
||||
if (!fieldIdentifier) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// @ts-expect-error - node data type is complex
|
||||
const node = workflow.workflow.nodes.find((n) => n.data.id === fieldIdentifier.nodeId);
|
||||
if (!node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// @ts-expect-error - node.data type is complex
|
||||
if (node.data.type === 'image') {
|
||||
imageFieldIdentifier = fieldIdentifier;
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if field type is ImageField
|
||||
// @ts-expect-error - field type is complex
|
||||
const field = node.data.inputs[fieldIdentifier.fieldName];
|
||||
if (field?.type?.name === 'ImageField') {
|
||||
imageFieldIdentifier = fieldIdentifier;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Method 3: Fallback to exposedFields
|
||||
if (!imageFieldIdentifier && workflow.workflow.exposedFields) {
|
||||
imageFieldIdentifier = workflow.workflow.exposedFields.find((fieldIdentifier) => {
|
||||
// @ts-expect-error - node data type is complex
|
||||
const node = workflow.workflow.nodes.find((n) => n.data.id === fieldIdentifier.nodeId);
|
||||
if (!node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// @ts-expect-error - node.data type is complex
|
||||
if (node.data.type === 'image') {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if field type is ImageField
|
||||
// @ts-expect-error - field type is complex
|
||||
const field = node.data.inputs[fieldIdentifier.fieldName];
|
||||
return field?.type?.name === 'ImageField';
|
||||
});
|
||||
}
|
||||
|
||||
if (!imageFieldIdentifier) {
|
||||
throw new Error('Workflow does not have an image input field in the Form Builder');
|
||||
}
|
||||
|
||||
// Update the workflow nodes with canvas image and user field values
|
||||
const updatedWorkflow = {
|
||||
...workflow.workflow,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
nodes: workflow.workflow.nodes.map((node: any) => {
|
||||
const nodeId = node.data.id;
|
||||
let updatedInputs = { ...node.data.inputs };
|
||||
const updatedData = { ...node.data };
|
||||
let hasChanges = false;
|
||||
|
||||
// Apply image field if this is the image node
|
||||
if (nodeId === imageFieldIdentifier.nodeId) {
|
||||
updatedInputs[imageFieldIdentifier.fieldName] = {
|
||||
...updatedInputs[imageFieldIdentifier.fieldName],
|
||||
value: imageDTO,
|
||||
};
|
||||
hasChanges = true;
|
||||
}
|
||||
|
||||
// Apply any field values from Redux state
|
||||
if (fieldValues) {
|
||||
Object.entries(fieldValues).forEach(([fieldKey, value]) => {
|
||||
const [fieldNodeId, fieldName] = fieldKey.split('.');
|
||||
if (fieldNodeId && fieldName && fieldNodeId === nodeId && updatedInputs[fieldName]) {
|
||||
updatedInputs[fieldName] = {
|
||||
...updatedInputs[fieldName],
|
||||
value: value,
|
||||
};
|
||||
hasChanges = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// If anything was modified, return updated node
|
||||
if (hasChanges) {
|
||||
updatedData.inputs = updatedInputs;
|
||||
return {
|
||||
...node,
|
||||
data: updatedData,
|
||||
};
|
||||
}
|
||||
|
||||
return node;
|
||||
}),
|
||||
};
|
||||
|
||||
// Validate that the workflow has a canvas_output node
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const hasCanvasOutputNode = updatedWorkflow.nodes.some((node: any) => node.data?.type === 'canvas_output');
|
||||
if (!hasCanvasOutputNode) {
|
||||
throw new Error('Workflow does not have a Canvas Output node');
|
||||
}
|
||||
|
||||
// 4. Convert workflow to graph format
|
||||
const graphNodes: Record<string, unknown> = {};
|
||||
const nodeIdMapping: Record<string, string> = {}; // Map original IDs to new IDs
|
||||
|
||||
for (const node of updatedWorkflow.nodes) {
|
||||
const nodeData = node.data;
|
||||
const isCanvasOutputNode = nodeData.type === 'canvas_output';
|
||||
|
||||
// Prefix canvas_output node IDs so the staging area can find them
|
||||
const nodeId = isCanvasOutputNode ? getPrefixedId(CANVAS_OUTPUT_PREFIX) : nodeData.id;
|
||||
nodeIdMapping[nodeData.id] = nodeId;
|
||||
|
||||
const invocation: Record<string, unknown> = {
|
||||
id: nodeId,
|
||||
type: nodeData.type,
|
||||
// Canvas output nodes are always intermediate (they go to the staging area, not gallery)
|
||||
is_intermediate: isCanvasOutputNode ? true : (nodeData.isIntermediate ?? false),
|
||||
use_cache: nodeData.useCache ?? true,
|
||||
};
|
||||
|
||||
// Add input values to the invocation
|
||||
for (const [fieldName, fieldData] of Object.entries(nodeData.inputs)) {
|
||||
const fieldValue = (fieldData as { value?: unknown }).value;
|
||||
if (fieldValue === undefined) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The frontend stores board fields as 'auto', 'none', or { board_id: string }.
|
||||
// The backend expects null or { board_id: string }. Translate accordingly.
|
||||
if (fieldName === 'board') {
|
||||
if (fieldValue === 'auto' || fieldValue === 'none' || fieldValue === null) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
invocation[fieldName] = fieldValue;
|
||||
}
|
||||
|
||||
graphNodes[nodeId] = invocation;
|
||||
}
|
||||
|
||||
// Convert edges to graph format, using the node ID mapping
|
||||
const edgesArray = updatedWorkflow.edges as Array<{
|
||||
source: string;
|
||||
target: string;
|
||||
sourceHandle: string;
|
||||
targetHandle: string;
|
||||
}>;
|
||||
const graphEdges = edgesArray.map((edge) => ({
|
||||
source: {
|
||||
node_id: nodeIdMapping[edge.source] || edge.source,
|
||||
field: edge.sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: nodeIdMapping[edge.target] || edge.target,
|
||||
field: edge.targetHandle,
|
||||
},
|
||||
}));
|
||||
|
||||
const graph = {
|
||||
id: workflow.workflow.id || workflow.workflow_id || 'temp',
|
||||
nodes: graphNodes,
|
||||
edges: graphEdges,
|
||||
};
|
||||
|
||||
log.debug({ workflowName: workflow.name, destination: canvasSessionId }, 'Enqueueing workflow on canvas');
|
||||
|
||||
await dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate({
|
||||
batch: {
|
||||
workflow: updatedWorkflow,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
graph: graph as any,
|
||||
runs: 1,
|
||||
origin: 'canvas_workflow_integration',
|
||||
destination: canvasSessionId,
|
||||
},
|
||||
prepend: false,
|
||||
})
|
||||
).unwrap();
|
||||
|
||||
// 5. Close the modal and show success message
|
||||
// Results will appear in the staging area where user can accept/discard them
|
||||
toast({
|
||||
status: 'success',
|
||||
title: t('controlLayers.workflowIntegration.executionStarted'),
|
||||
description: t('controlLayers.workflowIntegration.executionStartedDescription'),
|
||||
});
|
||||
|
||||
dispatch(canvasWorkflowIntegrationClosed());
|
||||
} catch (error) {
|
||||
log.error('Error executing workflow');
|
||||
dispatch(canvasWorkflowIntegrationProcessingCompleted());
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('controlLayers.workflowIntegration.executionFailed'),
|
||||
description: error instanceof Error ? error.message : 'Unknown error',
|
||||
});
|
||||
}
|
||||
}, [
|
||||
selectedWorkflowId,
|
||||
sourceEntityIdentifier,
|
||||
canvasManager,
|
||||
dispatch,
|
||||
getWorkflow,
|
||||
t,
|
||||
fieldValues,
|
||||
selectedImageFieldKey,
|
||||
canvasSessionId,
|
||||
]);
|
||||
|
||||
return {
|
||||
execute,
|
||||
canExecute,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,107 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { workflowsApi } from 'services/api/endpoints/workflows';
|
||||
import type { paths } from 'services/api/schema';
|
||||
|
||||
import { workflowHasImageField } from './workflowHasImageField';
|
||||
|
||||
const log = logger('canvas-workflow-integration');
|
||||
|
||||
type WorkflowListItem =
|
||||
paths['/api/v1/workflows/']['get']['responses']['200']['content']['application/json']['items'][number];
|
||||
|
||||
interface UseFilteredWorkflowsResult {
|
||||
filteredWorkflows: WorkflowListItem[];
|
||||
isFiltering: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook that filters workflows to only include those with at least one ImageField
|
||||
* @param workflows The list of workflows to filter
|
||||
* @returns Filtered list of workflows that have ImageFields
|
||||
*/
|
||||
export function useFilteredWorkflows(workflows: WorkflowListItem[]): UseFilteredWorkflowsResult {
|
||||
const dispatch = useAppDispatch();
|
||||
const [filteredWorkflows, setFilteredWorkflows] = useState<WorkflowListItem[]>([]);
|
||||
const [isFiltering, setIsFiltering] = useState(false);
|
||||
|
||||
const filterWorkflows = useCallback(async () => {
|
||||
if (workflows.length === 0) {
|
||||
setFilteredWorkflows([]);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsFiltering(true);
|
||||
|
||||
try {
|
||||
// Load all workflows in parallel and check for ImageFields
|
||||
const workflowChecks = await Promise.all(
|
||||
workflows.map(async (workflow) => {
|
||||
try {
|
||||
// Fetch the full workflow data using dispatch
|
||||
const result = await dispatch(
|
||||
workflowsApi.endpoints.getWorkflow.initiate(workflow.workflow_id, {
|
||||
subscribe: false,
|
||||
forceRefetch: false,
|
||||
})
|
||||
);
|
||||
|
||||
// Get the data from the result
|
||||
const data = 'data' in result ? result.data : undefined;
|
||||
|
||||
const hasImageField = workflowHasImageField(data);
|
||||
|
||||
log.debug(
|
||||
{ workflowId: workflow.workflow_id, name: workflow.name, hasImageField },
|
||||
'Checked workflow for ImageField'
|
||||
);
|
||||
|
||||
// Clean up the subscription
|
||||
if ('unsubscribe' in result && typeof result.unsubscribe === 'function') {
|
||||
result.unsubscribe();
|
||||
}
|
||||
|
||||
return {
|
||||
workflow,
|
||||
hasImageField,
|
||||
};
|
||||
} catch (error) {
|
||||
log.error(
|
||||
{
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
workflowId: workflow.workflow_id,
|
||||
},
|
||||
'Error checking workflow for ImageField'
|
||||
);
|
||||
return {
|
||||
workflow,
|
||||
hasImageField: false,
|
||||
};
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// Filter to only include workflows with ImageFields
|
||||
const filtered = workflowChecks.filter((check) => check.hasImageField).map((check) => check.workflow);
|
||||
|
||||
log.debug({ totalWorkflows: workflows.length, filteredCount: filtered.length }, 'Filtered workflows');
|
||||
|
||||
setFilteredWorkflows(filtered);
|
||||
} catch (error) {
|
||||
log.error({ error: error instanceof Error ? error.message : String(error) }, 'Error filtering workflows');
|
||||
setFilteredWorkflows([]);
|
||||
} finally {
|
||||
setIsFiltering(false);
|
||||
}
|
||||
}, [workflows, dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
filterWorkflows();
|
||||
}, [filterWorkflows]);
|
||||
|
||||
return {
|
||||
filteredWorkflows,
|
||||
isFiltering,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { isNodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import type { paths } from 'services/api/schema';
|
||||
|
||||
const log = logger('canvas-workflow-integration');
|
||||
|
||||
type WorkflowResponse =
|
||||
paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
/**
|
||||
* Checks if a workflow is compatible with canvas workflow integration.
|
||||
* Requirements:
|
||||
* 1. Has a Form Builder (allows users to modify parameters)
|
||||
* 2. Has a canvas_output node (explicit canvas output target)
|
||||
* 3. Has at least one ImageField in the Form Builder (receives the canvas image)
|
||||
* @param workflow The workflow to check
|
||||
* @returns true if the workflow meets all requirements, false otherwise
|
||||
*/
|
||||
export function workflowHasImageField(workflow: WorkflowResponse | undefined): boolean {
|
||||
if (!workflow?.workflow) {
|
||||
log.debug('No workflow data provided');
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only workflows with Form Builder are supported
|
||||
// Workflows without Form Builder don't allow changing models and other parameters
|
||||
if (!workflow.workflow.form?.elements) {
|
||||
log.debug('Workflow has no form builder - excluding from list');
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must have a canvas_output node to define where the output image goes
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const hasCanvasOutputNode = workflow.workflow.nodes?.some((n: any) => (n.data as any)?.type === 'canvas_output');
|
||||
if (!hasCanvasOutputNode) {
|
||||
log.debug('Workflow has no canvas_output node - excluding from list');
|
||||
return false;
|
||||
}
|
||||
|
||||
const templates = $templates.get();
|
||||
const elements = workflow.workflow.form.elements;
|
||||
|
||||
log.debug('Workflow has form builder and canvas_output node, checking form elements for ImageField');
|
||||
|
||||
for (const [elementId, element] of Object.entries(elements)) {
|
||||
if (isNodeFieldElement(element)) {
|
||||
const { fieldIdentifier } = element.data;
|
||||
|
||||
// Find the node that contains this field
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const node = workflow.workflow.nodes?.find((n: any) => n.data?.id === fieldIdentifier.nodeId);
|
||||
if (!node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const nodeType = (node.data as any)?.type;
|
||||
if (!nodeType) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const template = templates[nodeType];
|
||||
if (!template?.inputs) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const fieldTemplate = template.inputs[fieldIdentifier.fieldName];
|
||||
if (!fieldTemplate) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this is an ImageField
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const fieldType = (fieldTemplate as any).type?.name;
|
||||
if (fieldType === 'ImageField') {
|
||||
log.debug({ elementId, fieldName: fieldIdentifier.fieldName }, 'Found ImageField in workflow form');
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have a form but no ImageFields were found in it, return false
|
||||
log.debug('Workflow has form builder but no ImageField found in form elements');
|
||||
return false;
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
|
||||
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
|
||||
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
|
||||
import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/components/common/CanvasEntityMenuItemsMergeDown';
|
||||
import { CanvasEntityMenuItemsRunWorkflow } from 'features/controlLayers/components/common/CanvasEntityMenuItemsRunWorkflow';
|
||||
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
|
||||
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
|
||||
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
|
||||
@@ -25,6 +26,7 @@ export const RasterLayerMenuItems = memo(() => {
|
||||
</IconMenuItemGroup>
|
||||
<CanvasEntityMenuItemsTransform />
|
||||
<CanvasEntityMenuItemsFilter />
|
||||
<CanvasEntityMenuItemsRunWorkflow />
|
||||
<CanvasEntityMenuItemsSelectObject />
|
||||
<RasterLayerMenuItemsAdjustments />
|
||||
<MenuDivider />
|
||||
|
||||
@@ -13,10 +13,10 @@ import {
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useOutputImageDTO, useStagingAreaContext } from './context';
|
||||
import { useStagingAreaContext } from './context';
|
||||
import { QueueItemNumber } from './QueueItemNumber';
|
||||
import type { StagingEntry } from './state';
|
||||
|
||||
const sx = {
|
||||
cursor: 'pointer',
|
||||
@@ -37,21 +37,23 @@ const sx = {
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
type Props = {
|
||||
item: S['SessionQueueItem'];
|
||||
entry: StagingEntry;
|
||||
index: number;
|
||||
};
|
||||
|
||||
export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
export const QueueItemPreviewMini = memo(({ entry, index }: Props) => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const $isSelected = useMemo(() => ctx.buildIsSelectedComputed(item.item_id), [ctx, item.item_id]);
|
||||
const $isSelected = useMemo(
|
||||
() => ctx.buildIsSelectedComputed(entry.item.item_id, entry.imageIndex),
|
||||
[ctx, entry.item.item_id, entry.imageIndex]
|
||||
);
|
||||
const isSelected = useStore($isSelected);
|
||||
const imageDTO = useOutputImageDTO(item.item_id);
|
||||
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
ctx.select(item.item_id);
|
||||
}, [ctx, item.item_id]);
|
||||
ctx.select(entry.item.item_id, entry.imageIndex);
|
||||
}, [ctx, entry.item.item_id, entry.imageIndex]);
|
||||
|
||||
const onDoubleClick = useCallback(() => {
|
||||
if (autoSwitch !== 'off') {
|
||||
@@ -63,8 +65,8 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
}, [autoSwitch, dispatch]);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
ctx.onImageLoaded(item.item_id);
|
||||
}, [ctx, item.item_id]);
|
||||
ctx.onImageLoaded(entry.item.item_id);
|
||||
}, [ctx, entry.item.item_id]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -74,11 +76,17 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
onClick={onClick}
|
||||
onDoubleClick={onDoubleClick}
|
||||
>
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} position="absolute" onLoad={onLoad} />}
|
||||
<QueueItemProgressImage itemId={item.item_id} position="absolute" />
|
||||
<QueueItemStatusLabel item={entry.item} position="absolute" margin="auto" />
|
||||
{entry.imageDTO && <DndImage imageDTO={entry.imageDTO} position="absolute" onLoad={onLoad} />}
|
||||
<QueueItemProgressImage itemId={entry.item.item_id} position="absolute" />
|
||||
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
|
||||
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
|
||||
<QueueItemCircularProgress
|
||||
itemId={entry.item.item_id}
|
||||
status={entry.item.status}
|
||||
position="absolute"
|
||||
top={1}
|
||||
right={2}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -8,10 +8,10 @@ import type { CSSProperties, RefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import type { Components, ComputeItemKey, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
|
||||
import { Virtuoso } from 'react-virtuoso';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useStagingAreaContext } from './context';
|
||||
import { getQueueItemElementId, STAGING_AREA_THUMBNAIL_STRIP_HEIGHT } from './shared';
|
||||
import type { StagingEntry } from './state';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
@@ -141,7 +141,7 @@ export const StagingAreaItemsList = memo(() => {
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const items = useStore(ctx.$items);
|
||||
const entries = useStore(ctx.$entries);
|
||||
|
||||
const scrollerRef = useScrollableStagingArea(rootRef);
|
||||
|
||||
@@ -172,9 +172,9 @@ export const StagingAreaItemsList = memo(() => {
|
||||
|
||||
return (
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
|
||||
<Virtuoso<S['SessionQueueItem']>
|
||||
<Virtuoso<StagingEntry>
|
||||
ref={virtuosoRef}
|
||||
data={items}
|
||||
data={entries}
|
||||
horizontalDirection
|
||||
style={virtuosoStyles}
|
||||
computeItemKey={computeItemKey}
|
||||
@@ -183,19 +183,19 @@ export const StagingAreaItemsList = memo(() => {
|
||||
components={components}
|
||||
rangeChanged={onRangeChanged}
|
||||
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
|
||||
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], void>['scrollerRef']}
|
||||
scrollerRef={scrollerRef as VirtuosoProps<StagingEntry, void>['scrollerRef']}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
|
||||
|
||||
const computeItemKey: ComputeItemKey<S['SessionQueueItem'], void> = (_, item: S['SessionQueueItem']) => {
|
||||
return item.item_id;
|
||||
const computeItemKey: ComputeItemKey<StagingEntry, void> = (_, entry: StagingEntry) => {
|
||||
return `${entry.item.item_id}-${entry.imageIndex}`;
|
||||
};
|
||||
|
||||
const itemContent: ItemContent<S['SessionQueueItem'], void> = (index, item) => (
|
||||
<QueueItemPreviewMini key={`${item.item_id}-mini`} item={item} index={index} />
|
||||
const itemContent: ItemContent<StagingEntry, void> = (index, entry) => (
|
||||
<QueueItemPreviewMini key={`${entry.item.item_id}-${entry.imageIndex}-mini`} entry={entry} index={index} />
|
||||
);
|
||||
|
||||
const listSx = {
|
||||
@@ -207,7 +207,7 @@ const listSx = {
|
||||
},
|
||||
};
|
||||
|
||||
const components: Components<S['SessionQueueItem']> = {
|
||||
const components: Components<StagingEntry> = {
|
||||
List: forwardRef(({ context: _, ...rest }, ref) => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
@@ -84,12 +84,13 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
|
||||
y: y + Math.round((bboxRect.height - scaledHeight) / 2),
|
||||
};
|
||||
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
|
||||
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
position,
|
||||
objects: [imageObject],
|
||||
};
|
||||
|
||||
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
|
||||
|
||||
store.dispatch(canvasSessionReset());
|
||||
store.dispatch(
|
||||
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
|
||||
@@ -129,12 +130,6 @@ export const useStagingAreaContext = () => {
|
||||
return ctx;
|
||||
};
|
||||
|
||||
export const useOutputImageDTO = (itemId: number) => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
|
||||
return allProgressData[itemId]?.imageDTO ?? null;
|
||||
};
|
||||
|
||||
export const useProgressDatum = (itemId: number): ProgressData => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { getOutputImageName, getProgressMessage, getQueueItemElementId } from './shared';
|
||||
import { getOutputImageNames, getProgressMessage, getQueueItemElementId } from './shared';
|
||||
|
||||
describe('StagingAreaApi Utility Functions', () => {
|
||||
describe('getProgressMessage', () => {
|
||||
@@ -34,7 +34,7 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('getOutputImageName', () => {
|
||||
describe('getOutputImageNames', () => {
|
||||
it('should extract image name from completed queue item', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
@@ -61,10 +61,10 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe('test-output.png');
|
||||
expect(getOutputImageNames(queueItem)).toEqual(['test-output.png']);
|
||||
});
|
||||
|
||||
it('should return null when no canvas output node found', () => {
|
||||
it('should return empty array when no canvas output node found', () => {
|
||||
const queueItem = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
@@ -93,10 +93,10 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
expect(getOutputImageNames(queueItem)).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return null when output node has no results', () => {
|
||||
it('should return empty array when output node has no results', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
@@ -116,10 +116,10 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
expect(getOutputImageNames(queueItem)).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return null when results contain no image fields', () => {
|
||||
it('should return empty array when results contain no image fields', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
@@ -144,10 +144,10 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
expect(getOutputImageNames(queueItem)).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle multiple outputs and return first image', () => {
|
||||
it('should collect images from multiple canvas_output nodes', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
@@ -161,15 +161,17 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
'canvas_output:abc123': ['output-node-1'],
|
||||
'canvas_output:def456': ['output-node-2'],
|
||||
},
|
||||
results: {
|
||||
'output-node-id': {
|
||||
text: 'some text',
|
||||
first_image: {
|
||||
'output-node-1': {
|
||||
image: {
|
||||
image_name: 'first-image.png',
|
||||
},
|
||||
second_image: {
|
||||
},
|
||||
'output-node-2': {
|
||||
image: {
|
||||
image_name: 'second-image.png',
|
||||
},
|
||||
},
|
||||
@@ -177,8 +179,8 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
const result = getOutputImageName(queueItem);
|
||||
expect(result).toBe('first-image.png');
|
||||
const result = getOutputImageNames(queueItem);
|
||||
expect(result).toEqual(['first-image.png', 'second-image.png']);
|
||||
});
|
||||
|
||||
it('should return first image from image collections', () => {
|
||||
@@ -226,7 +228,7 @@ describe('StagingAreaApi Utility Functions', () => {
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
expect(getOutputImageNames(queueItem)).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,24 +16,27 @@ export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0p
|
||||
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
|
||||
export const STAGING_AREA_THUMBNAIL_STRIP_HEIGHT = '72px';
|
||||
|
||||
export const getOutputImageName = (item: S['SessionQueueItem']) => {
|
||||
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
|
||||
isCanvasOutputNodeId(nodeId)
|
||||
)?.[1][0];
|
||||
const output = nodeId ? item.session.results[nodeId] : undefined;
|
||||
export const getOutputImageNames = (item: S['SessionQueueItem']): string[] => {
|
||||
const imageNames: string[] = [];
|
||||
|
||||
if (!output) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (const [_name, value] of objectEntries(output)) {
|
||||
if (isImageField(value)) {
|
||||
return value.image_name;
|
||||
for (const [sourceNodeId, preparedNodeIds] of Object.entries(item.session.source_prepared_mapping)) {
|
||||
if (!isCanvasOutputNodeId(sourceNodeId)) {
|
||||
continue;
|
||||
}
|
||||
const nodeId = preparedNodeIds[0];
|
||||
const output = nodeId ? item.session.results[nodeId] : undefined;
|
||||
if (!output) {
|
||||
continue;
|
||||
}
|
||||
for (const [_name, value] of objectEntries(output)) {
|
||||
if (isImageField(value)) {
|
||||
imageNames.push(value.image_name);
|
||||
}
|
||||
}
|
||||
if (isImageFieldCollection(value)) {
|
||||
return value[0]?.image_name ?? null;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return imageNames;
|
||||
};
|
||||
|
||||
@@ -41,18 +41,34 @@ describe('StagingAreaApi', () => {
|
||||
expect(api.$items.get()).toEqual([]);
|
||||
expect(api.$progressData.get()).toEqual({});
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Computed Values', () => {
|
||||
it('should compute item count correctly', () => {
|
||||
it('should compute item count as entry count for single-output items', () => {
|
||||
expect(api.$itemCount.get()).toBe(0);
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
// Item with no imageDTOs produces 1 entry
|
||||
expect(api.$itemCount.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should compute item count as entry count for multi-output items', () => {
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [createMockImageDTO({ image_name: 'a.png' }), createMockImageDTO({ image_name: 'b.png' })],
|
||||
imageLoaded: false,
|
||||
});
|
||||
// 2 imageDTOs = 2 entries
|
||||
expect(api.$itemCount.get()).toBe(2);
|
||||
});
|
||||
|
||||
it('should compute hasItems correctly', () => {
|
||||
expect(api.$hasItems.get()).toBe(false);
|
||||
|
||||
@@ -95,7 +111,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
imageDTOs: [imageDTO],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -111,6 +127,95 @@ describe('StagingAreaApi', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Entries Computed', () => {
|
||||
it('should produce one entry per item when each has 0 or 1 imageDTOs', () => {
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
|
||||
const entries = api.$entries.get();
|
||||
expect(entries).toHaveLength(2);
|
||||
expect(entries[0]?.item.item_id).toBe(1);
|
||||
expect(entries[0]?.imageIndex).toBe(0);
|
||||
expect(entries[0]?.imageDTO).toBe(null);
|
||||
expect(entries[1]?.item.item_id).toBe(2);
|
||||
expect(entries[1]?.imageIndex).toBe(0);
|
||||
});
|
||||
|
||||
it('should produce one entry per item when each has exactly 1 imageDTO', () => {
|
||||
const imageDTO = createMockImageDTO({ image_name: 'img.png' });
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [imageDTO],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
const entries = api.$entries.get();
|
||||
expect(entries).toHaveLength(1);
|
||||
expect(entries[0]?.imageDTO).toBe(imageDTO);
|
||||
expect(entries[0]?.imageIndex).toBe(0);
|
||||
});
|
||||
|
||||
it('should produce multiple entries for items with multiple imageDTOs', () => {
|
||||
const img1 = createMockImageDTO({ image_name: 'a.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b.png' });
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1, img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
const entries = api.$entries.get();
|
||||
expect(entries).toHaveLength(2);
|
||||
expect(entries[0]?.item.item_id).toBe(1);
|
||||
expect(entries[0]?.imageDTO).toBe(img1);
|
||||
expect(entries[0]?.imageIndex).toBe(0);
|
||||
expect(entries[1]?.item.item_id).toBe(1);
|
||||
expect(entries[1]?.imageDTO).toBe(img2);
|
||||
expect(entries[1]?.imageIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('should interleave entries from multiple items', () => {
|
||||
const img1a = createMockImageDTO({ image_name: 'a1.png' });
|
||||
const img1b = createMockImageDTO({ image_name: 'a2.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b1.png' });
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1a, img1b],
|
||||
imageLoaded: false,
|
||||
});
|
||||
api.$progressData.setKey(2, {
|
||||
itemId: 2,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
const entries = api.$entries.get();
|
||||
expect(entries).toHaveLength(3);
|
||||
// Item 1 entries first
|
||||
expect(entries[0]?.item.item_id).toBe(1);
|
||||
expect(entries[0]?.imageDTO).toBe(img1a);
|
||||
expect(entries[1]?.item.item_id).toBe(1);
|
||||
expect(entries[1]?.imageDTO).toBe(img1b);
|
||||
// Then item 2
|
||||
expect(entries[2]?.item.item_id).toBe(2);
|
||||
expect(entries[2]?.imageDTO).toBe(img2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Selection Methods', () => {
|
||||
beforeEach(() => {
|
||||
const items = [
|
||||
@@ -124,9 +229,16 @@ describe('StagingAreaApi', () => {
|
||||
it('should select item by ID', () => {
|
||||
api.select(2);
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(mockApp.onSelect).toHaveBeenCalledWith(2);
|
||||
});
|
||||
|
||||
it('should select item by ID and imageIndex', () => {
|
||||
api.select(2, 1);
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(api.$selectedImageIndex.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should select next item', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
api.selectNext();
|
||||
@@ -184,6 +296,124 @@ describe('StagingAreaApi', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Entry-Based Navigation', () => {
|
||||
it('should navigate through entries of multi-output items', () => {
|
||||
const img1 = createMockImageDTO({ image_name: 'a.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b.png' });
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1, img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(0);
|
||||
|
||||
// Should be on first entry (imageIndex 0)
|
||||
expect(api.$selectedItem.get()?.imageDTO).toBe(img1);
|
||||
expect(api.$selectedItem.get()?.index).toBe(0);
|
||||
|
||||
// Navigate to next entry (imageIndex 1, same item)
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(api.$selectedImageIndex.get()).toBe(1);
|
||||
expect(api.$selectedItem.get()?.imageDTO).toBe(img2);
|
||||
expect(api.$selectedItem.get()?.index).toBe(1);
|
||||
|
||||
// Navigate past last entry - should wrap to first
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(api.$selectedItem.get()?.imageDTO).toBe(img1);
|
||||
});
|
||||
|
||||
it('should navigate across items and their entries', () => {
|
||||
const img1a = createMockImageDTO({ image_name: 'a1.png' });
|
||||
const img1b = createMockImageDTO({ image_name: 'a2.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b1.png' });
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1a, img1b],
|
||||
imageLoaded: false,
|
||||
});
|
||||
api.$progressData.setKey(2, {
|
||||
itemId: 2,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
// Start on first entry
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(0);
|
||||
|
||||
// entries: [item1/img0, item1/img1, item2/img0]
|
||||
expect(api.$entries.get()).toHaveLength(3);
|
||||
|
||||
// Next -> item1, imageIndex 1
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(api.$selectedImageIndex.get()).toBe(1);
|
||||
|
||||
// Next -> item2, imageIndex 0
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
|
||||
// Next -> wraps to item1, imageIndex 0
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
|
||||
// Prev -> wraps to item2, imageIndex 0
|
||||
api.selectPrev();
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
});
|
||||
|
||||
it('should select correct entry with selectFirst and selectLast', () => {
|
||||
const img1a = createMockImageDTO({ image_name: 'a1.png' });
|
||||
const img1b = createMockImageDTO({ image_name: 'a2.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b1.png' });
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1a, img1b],
|
||||
imageLoaded: false,
|
||||
});
|
||||
api.$progressData.setKey(2, {
|
||||
itemId: 2,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
api.selectLast();
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(api.$selectedItem.get()?.imageDTO).toBe(img2);
|
||||
|
||||
api.selectFirst();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(api.$selectedItem.get()?.imageDTO).toBe(img1a);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discard Methods', () => {
|
||||
beforeEach(() => {
|
||||
const items = [
|
||||
@@ -201,6 +431,7 @@ describe('StagingAreaApi', () => {
|
||||
api.discardSelected();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(3);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(mockApp.onDiscard).toHaveBeenCalledWith(selectedItem?.item);
|
||||
});
|
||||
|
||||
@@ -238,6 +469,7 @@ describe('StagingAreaApi', () => {
|
||||
api.discardAll();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(mockApp.onDiscardAll).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -247,6 +479,14 @@ describe('StagingAreaApi', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
expect(api.$discardSelectedIsEnabled.get()).toBe(true);
|
||||
});
|
||||
|
||||
it('should reset selectedImageIndex when discarding', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(2);
|
||||
api.discardSelected();
|
||||
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accept Methods', () => {
|
||||
@@ -256,13 +496,13 @@ describe('StagingAreaApi', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
});
|
||||
|
||||
it('should accept selected item when image is available', () => {
|
||||
it('should accept selected item with single imageDTO', () => {
|
||||
const imageDTO = createMockImageDTO();
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
imageDTOs: [imageDTO],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -272,12 +512,34 @@ describe('StagingAreaApi', () => {
|
||||
expect(mockApp.onAccept).toHaveBeenCalledWith(selectedItem?.item, imageDTO);
|
||||
});
|
||||
|
||||
it('should accept the correct imageDTO from a multi-output entry', () => {
|
||||
const img1 = createMockImageDTO({ image_name: 'a.png' });
|
||||
const img2 = createMockImageDTO({ image_name: 'b.png' });
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTOs: [img1, img2],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
// Select the second image
|
||||
api.$selectedImageIndex.set(1);
|
||||
|
||||
const selectedItem = api.$selectedItem.get();
|
||||
expect(selectedItem?.imageDTO).toBe(img2);
|
||||
|
||||
api.acceptSelected();
|
||||
|
||||
expect(mockApp.onAccept).toHaveBeenCalledWith(selectedItem?.item, img2);
|
||||
});
|
||||
|
||||
it('should do nothing when no image is available', () => {
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -301,7 +563,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
imageDTOs: [imageDTO],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -339,7 +601,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: createMockImageDTO(),
|
||||
imageDTOs: [createMockImageDTO()],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -352,7 +614,7 @@ describe('StagingAreaApi', () => {
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.progressEvent).toBe(progressEvent);
|
||||
expect(progressData[1]?.imageDTO).toBeTruthy();
|
||||
expect(progressData[1]?.imageDTOs.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -438,7 +700,7 @@ describe('StagingAreaApi', () => {
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.imageDTO).toBe(imageDTO);
|
||||
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO);
|
||||
});
|
||||
|
||||
it('should handle auto-switch on completion', async () => {
|
||||
@@ -473,7 +735,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 999,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -490,7 +752,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 1,
|
||||
progressEvent: createMockProgressEvent({ item_id: 1 }),
|
||||
progressImage: null,
|
||||
imageDTO: createMockImageDTO(),
|
||||
imageDTOs: [createMockImageDTO()],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -501,7 +763,7 @@ describe('StagingAreaApi', () => {
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.progressEvent).toBe(null);
|
||||
expect(progressData[1]?.progressImage).toBe(null);
|
||||
expect(progressData[1]?.imageDTO).toBe(null);
|
||||
expect(progressData[1]?.imageDTOs).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -547,18 +809,34 @@ describe('StagingAreaApi', () => {
|
||||
});
|
||||
|
||||
describe('Utility Methods', () => {
|
||||
it('should build isSelected computed correctly', () => {
|
||||
it('should build isSelected computed correctly for default imageIndex', () => {
|
||||
const isSelected = api.buildIsSelectedComputed(1);
|
||||
expect(isSelected.get()).toBe(false);
|
||||
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(0);
|
||||
expect(isSelected.get()).toBe(true);
|
||||
});
|
||||
|
||||
it('should build isSelected computed correctly with specific imageIndex', () => {
|
||||
const isSelected0 = api.buildIsSelectedComputed(1, 0);
|
||||
const isSelected1 = api.buildIsSelectedComputed(1, 1);
|
||||
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(0);
|
||||
expect(isSelected0.get()).toBe(true);
|
||||
expect(isSelected1.get()).toBe(false);
|
||||
|
||||
api.$selectedImageIndex.set(1);
|
||||
expect(isSelected0.get()).toBe(false);
|
||||
expect(isSelected1.get()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Cleanup', () => {
|
||||
it('should reset all state on cleanup', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
api.$selectedImageIndex.set(2);
|
||||
api.$items.set([createMockQueueItem({ item_id: 1 })]);
|
||||
api.$lastStartedItemId.set(1);
|
||||
api.$lastCompletedItemId.set(1);
|
||||
@@ -566,13 +844,14 @@ describe('StagingAreaApi', () => {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
api.cleanup();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
expect(api.$selectedImageIndex.get()).toBe(0);
|
||||
expect(api.$items.get()).toEqual([]);
|
||||
expect(api.$lastStartedItemId.get()).toBe(null);
|
||||
expect(api.$lastCompletedItemId.get()).toBe(null);
|
||||
@@ -622,13 +901,13 @@ describe('StagingAreaApi', () => {
|
||||
expect(progressData[1]?.progressEvent).toBe(progressEvent);
|
||||
});
|
||||
|
||||
it('should preserve imageDTO when updating progress', () => {
|
||||
it('should preserve imageDTOs when updating progress', () => {
|
||||
const imageDTO = createMockImageDTO();
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
imageDTOs: [imageDTO],
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
@@ -640,7 +919,7 @@ describe('StagingAreaApi', () => {
|
||||
api.onInvocationProgressEvent(progressEvent);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.imageDTO).toBe(imageDTO);
|
||||
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO);
|
||||
expect(progressData[1]?.progressEvent).toBe(progressEvent);
|
||||
});
|
||||
});
|
||||
@@ -712,6 +991,118 @@ describe('StagingAreaApi', () => {
|
||||
expect(api.$selectedItem.get()?.item.item_id).toBe(2);
|
||||
});
|
||||
|
||||
it('should not let stale async call overwrite newer data with fewer images', async () => {
|
||||
const imageDTO1 = createMockImageDTO({ image_name: 'img1.png' });
|
||||
const imageDTO2 = createMockImageDTO({ image_name: 'img2.png' });
|
||||
mockApp._setImageDTO('img1.png', imageDTO1);
|
||||
mockApp._setImageDTO('img2.png', imageDTO2);
|
||||
|
||||
// Simulates optimistic update: status=completed but only 1 result in session.results
|
||||
const itemsStale = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
session: {
|
||||
id: sessionId,
|
||||
source_prepared_mapping: { 'canvas_output:a': ['node-1'] },
|
||||
results: { 'node-1': { image: { image_name: 'img1.png' } } },
|
||||
},
|
||||
}),
|
||||
];
|
||||
|
||||
// Simulates full refetch: both results available
|
||||
const itemsFull = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
session: {
|
||||
id: sessionId,
|
||||
source_prepared_mapping: { 'canvas_output:a': ['node-1'], 'canvas_output:b': ['node-2'] },
|
||||
results: {
|
||||
'node-1': { image: { image_name: 'img1.png' } },
|
||||
'node-2': { image: { image_name: 'img2.png' } },
|
||||
},
|
||||
},
|
||||
}),
|
||||
];
|
||||
|
||||
// Fire both concurrently (stale optimistic update then full refetch)
|
||||
const promise1 = api.onItemsChangedEvent(itemsStale);
|
||||
const promise2 = api.onItemsChangedEvent(itemsFull);
|
||||
await Promise.all([promise1, promise2]);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.imageDTOs).toHaveLength(2);
|
||||
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO1);
|
||||
expect(progressData[1]?.imageDTOs[1]).toBe(imageDTO2);
|
||||
});
|
||||
|
||||
it('should load all images from multiple canvas_output nodes', async () => {
|
||||
const imageDTO1 = createMockImageDTO({ image_name: 'output1.png' });
|
||||
const imageDTO2 = createMockImageDTO({ image_name: 'output2.png' });
|
||||
mockApp._setImageDTO('output1.png', imageDTO1);
|
||||
mockApp._setImageDTO('output2.png', imageDTO2);
|
||||
|
||||
const items = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
session: {
|
||||
id: sessionId,
|
||||
source_prepared_mapping: {
|
||||
'canvas_output:abc': ['prepared-1'],
|
||||
'canvas_output:def': ['prepared-2'],
|
||||
},
|
||||
results: {
|
||||
'prepared-1': { image: { image_name: 'output1.png' } },
|
||||
'prepared-2': { image: { image_name: 'output2.png' } },
|
||||
},
|
||||
},
|
||||
}),
|
||||
];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.imageDTOs).toHaveLength(2);
|
||||
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO1);
|
||||
expect(progressData[1]?.imageDTOs[1]).toBe(imageDTO2);
|
||||
});
|
||||
|
||||
it('should create separate entries for multiple canvas_output images', async () => {
|
||||
const imageDTO1 = createMockImageDTO({ image_name: 'output1.png' });
|
||||
const imageDTO2 = createMockImageDTO({ image_name: 'output2.png' });
|
||||
mockApp._setImageDTO('output1.png', imageDTO1);
|
||||
mockApp._setImageDTO('output2.png', imageDTO2);
|
||||
|
||||
const items = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
session: {
|
||||
id: sessionId,
|
||||
source_prepared_mapping: {
|
||||
'canvas_output:abc': ['prepared-1'],
|
||||
'canvas_output:def': ['prepared-2'],
|
||||
},
|
||||
results: {
|
||||
'prepared-1': { image: { image_name: 'output1.png' } },
|
||||
'prepared-2': { image: { image_name: 'output2.png' } },
|
||||
},
|
||||
},
|
||||
}),
|
||||
];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
const entries = api.$entries.get();
|
||||
expect(entries).toHaveLength(2);
|
||||
expect(entries[0]?.imageDTO).toBe(imageDTO1);
|
||||
expect(entries[0]?.imageIndex).toBe(0);
|
||||
expect(entries[1]?.imageDTO).toBe(imageDTO2);
|
||||
expect(entries[1]?.imageIndex).toBe(1);
|
||||
});
|
||||
|
||||
it('should handle multiple progress events for same item', () => {
|
||||
const event1 = createMockProgressEvent({
|
||||
item_id: 1,
|
||||
@@ -740,7 +1131,7 @@ describe('StagingAreaApi', () => {
|
||||
itemId: i,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
imageLoaded: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { atom, computed, map } from 'nanostores';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
import { getOutputImageName } from './shared';
|
||||
import { getOutputImageNames } from './shared';
|
||||
|
||||
/**
|
||||
* Interface for the app-level API that the StagingAreaApi depends on.
|
||||
@@ -34,14 +34,23 @@ export type ProgressData = {
|
||||
itemId: number;
|
||||
progressEvent: S['InvocationProgressEvent'] | null;
|
||||
progressImage: ProgressImage | null;
|
||||
imageDTO: ImageDTO | null;
|
||||
imageDTOs: ImageDTO[];
|
||||
imageLoaded: boolean;
|
||||
};
|
||||
|
||||
/** Combined data for the currently selected item */
|
||||
/** A single entry in the staging area. Each canvas_output image is a separate entry. */
|
||||
export type StagingEntry = {
|
||||
item: S['SessionQueueItem'];
|
||||
imageDTO: ImageDTO | null;
|
||||
imageIndex: number;
|
||||
progressData: ProgressData;
|
||||
};
|
||||
|
||||
/** Combined data for the currently selected entry */
|
||||
export type SelectedItemData = {
|
||||
item: S['SessionQueueItem'];
|
||||
index: number;
|
||||
imageDTO: ImageDTO | null;
|
||||
progressData: ProgressData;
|
||||
};
|
||||
|
||||
@@ -50,7 +59,7 @@ export const getInitialProgressData = (itemId: number): ProgressData => ({
|
||||
itemId,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
imageLoaded: false,
|
||||
});
|
||||
type ProgressDataMap = Record<number, ProgressData | undefined>;
|
||||
@@ -58,8 +67,7 @@ type ProgressDataMap = Record<number, ProgressData | undefined>;
|
||||
/**
|
||||
* API for managing the Canvas Staging Area - a view of the image generation queue.
|
||||
* Provides reactive state management for pending, in-progress, and completed images.
|
||||
* Users can accept images to place on canvas, discard them, navigate between items,
|
||||
* and configure auto-switching behavior.
|
||||
* Each canvas_output node produces a separate entry that can be individually navigated and accepted.
|
||||
*/
|
||||
export class StagingAreaApi {
|
||||
/** The current session ID. */
|
||||
@@ -71,6 +79,9 @@ export class StagingAreaApi {
|
||||
/** A set of subscriptions to be cleaned up when we are finished with a session */
|
||||
_subscriptions = new Set<() => void>();
|
||||
|
||||
/** Generation counter to prevent stale async writes in onItemsChangedEvent */
|
||||
_itemsEventGeneration = 0;
|
||||
|
||||
/** Item ID of the last started item. Used for auto-switch on start. */
|
||||
$lastStartedItemId = atom<number | null>(null);
|
||||
|
||||
@@ -86,8 +97,38 @@ export class StagingAreaApi {
|
||||
/** ID of the currently selected queue item, or null if none selected. */
|
||||
$selectedItemId = atom<number | null>(null);
|
||||
|
||||
/** Total number of items in the queue. */
|
||||
$itemCount = computed([this.$items], (items) => items.length);
|
||||
/** Index of the selected image within the selected queue item (for multi-output items). */
|
||||
$selectedImageIndex = atom<number>(0);
|
||||
|
||||
/**
|
||||
* Flat list of staging entries. Each canvas_output image from a queue item becomes
|
||||
* a separate entry. Items with 0 or 1 output images produce a single entry.
|
||||
*/
|
||||
$entries = computed([this.$items, this.$progressData], (items, progressData) => {
|
||||
const entries: StagingEntry[] = [];
|
||||
for (const item of items) {
|
||||
const datum = progressData[item.item_id] ?? getInitialProgressData(item.item_id);
|
||||
if (datum.imageDTOs.length <= 1) {
|
||||
entries.push({
|
||||
item,
|
||||
imageDTO: datum.imageDTOs[0] ?? null,
|
||||
imageIndex: 0,
|
||||
progressData: datum,
|
||||
});
|
||||
} else {
|
||||
for (let i = 0; i < datum.imageDTOs.length; i++) {
|
||||
const imageDTO = datum.imageDTOs[i];
|
||||
if (imageDTO) {
|
||||
entries.push({ item, imageDTO, imageIndex: i, progressData: datum });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
});
|
||||
|
||||
/** Total number of entries (each canvas_output image counts separately). */
|
||||
$itemCount = computed([this.$entries], (entries) => entries.length);
|
||||
|
||||
/** Whether there are any items in the queue. */
|
||||
$hasItems = computed([this.$items], (items) => items.length > 0);
|
||||
@@ -97,113 +138,153 @@ export class StagingAreaApi {
|
||||
items.some((item) => item.status === 'pending' || item.status === 'in_progress')
|
||||
);
|
||||
|
||||
/** The currently selected queue item with its index and progress data, or null if none selected. */
|
||||
/** The currently selected entry with its global index, or null if none selected. */
|
||||
$selectedItem = computed(
|
||||
[this.$items, this.$selectedItemId, this.$progressData],
|
||||
(items, selectedItemId, progressData) => {
|
||||
if (items.length === 0) {
|
||||
[this.$entries, this.$selectedItemId, this.$selectedImageIndex],
|
||||
(entries, selectedItemId, selectedImageIndex) => {
|
||||
if (entries.length === 0 || selectedItemId === null) {
|
||||
return null;
|
||||
}
|
||||
if (selectedItemId === null) {
|
||||
return null;
|
||||
|
||||
// Find the entry matching (selectedItemId, selectedImageIndex)
|
||||
let targetEntry: StagingEntry | null = null;
|
||||
let globalIndex = -1;
|
||||
let imageIdxWithinItem = 0;
|
||||
|
||||
for (let i = 0; i < entries.length; i++) {
|
||||
const entry = entries[i]!;
|
||||
if (entry.item.item_id === selectedItemId) {
|
||||
if (imageIdxWithinItem === selectedImageIndex) {
|
||||
targetEntry = entry;
|
||||
globalIndex = i;
|
||||
break;
|
||||
}
|
||||
imageIdxWithinItem++;
|
||||
}
|
||||
}
|
||||
const item = items.find(({ item_id }) => item_id === selectedItemId);
|
||||
if (!item) {
|
||||
|
||||
// Fallback: select first entry for this item
|
||||
if (!targetEntry) {
|
||||
for (let i = 0; i < entries.length; i++) {
|
||||
const entry = entries[i]!;
|
||||
if (entry.item.item_id === selectedItemId) {
|
||||
targetEntry = entry;
|
||||
globalIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!targetEntry || globalIndex === -1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
item,
|
||||
index: items.findIndex(({ item_id }) => item_id === selectedItemId),
|
||||
progressData: progressData[selectedItemId] || getInitialProgressData(selectedItemId),
|
||||
item: targetEntry.item,
|
||||
index: globalIndex,
|
||||
imageDTO: targetEntry.imageDTO,
|
||||
progressData: targetEntry.progressData,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
/** The ImageDTO of the currently selected item, or null if none available. */
|
||||
/** The ImageDTO of the currently selected entry, or null if none available. */
|
||||
$selectedItemImageDTO = computed([this.$selectedItem], (selectedItem) => {
|
||||
return selectedItem?.progressData.imageDTO ?? null;
|
||||
return selectedItem?.imageDTO ?? null;
|
||||
});
|
||||
|
||||
/** The index of the currently selected item, or null if none selected. */
|
||||
/** The global entry index of the currently selected entry, or null if none selected. */
|
||||
$selectedItemIndex = computed([this.$selectedItem], (selectedItem) => {
|
||||
return selectedItem?.index ?? null;
|
||||
});
|
||||
|
||||
/** Selects a queue item by ID. */
|
||||
select = (itemId: number) => {
|
||||
/** Selects a queue item by ID, optionally at a specific image index. */
|
||||
select = (itemId: number, imageIndex: number = 0) => {
|
||||
this.$selectedItemId.set(itemId);
|
||||
this.$selectedImageIndex.set(imageIndex);
|
||||
this._app?.onSelect?.(itemId);
|
||||
};
|
||||
|
||||
/** Selects the next item in the queue, wrapping to the first item if at the end. */
|
||||
/** Selects the next entry, cycling through all entries across all items. */
|
||||
selectNext = () => {
|
||||
const selectedItem = this.$selectedItem.get();
|
||||
if (selectedItem === null) {
|
||||
return;
|
||||
}
|
||||
const items = this.$items.get();
|
||||
const nextIndex = (selectedItem.index + 1) % items.length;
|
||||
const nextItem = items[nextIndex];
|
||||
if (!nextItem) {
|
||||
const entries = this.$entries.get();
|
||||
if (entries.length <= 1) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(nextItem.item_id);
|
||||
const nextIndex = (selectedItem.index + 1) % entries.length;
|
||||
const nextEntry = entries[nextIndex];
|
||||
if (!nextEntry) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(nextEntry.item.item_id);
|
||||
this.$selectedImageIndex.set(nextEntry.imageIndex);
|
||||
this._app?.onSelectNext?.();
|
||||
};
|
||||
|
||||
/** Selects the previous item in the queue, wrapping to the last item if at the beginning. */
|
||||
/** Selects the previous entry, cycling through all entries across all items. */
|
||||
selectPrev = () => {
|
||||
const selectedItem = this.$selectedItem.get();
|
||||
if (selectedItem === null) {
|
||||
return;
|
||||
}
|
||||
const items = this.$items.get();
|
||||
const prevIndex = (selectedItem.index - 1 + items.length) % items.length;
|
||||
const prevItem = items[prevIndex];
|
||||
if (!prevItem) {
|
||||
const entries = this.$entries.get();
|
||||
if (entries.length <= 1) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(prevItem.item_id);
|
||||
const prevIndex = (selectedItem.index - 1 + entries.length) % entries.length;
|
||||
const prevEntry = entries[prevIndex];
|
||||
if (!prevEntry) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(prevEntry.item.item_id);
|
||||
this.$selectedImageIndex.set(prevEntry.imageIndex);
|
||||
this._app?.onSelectPrev?.();
|
||||
};
|
||||
|
||||
/** Selects the first item in the queue. */
|
||||
/** Selects the first entry. */
|
||||
selectFirst = () => {
|
||||
const items = this.$items.get();
|
||||
const first = items.at(0);
|
||||
const entries = this.$entries.get();
|
||||
const first = entries[0];
|
||||
if (!first) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(first.item_id);
|
||||
this.$selectedItemId.set(first.item.item_id);
|
||||
this.$selectedImageIndex.set(first.imageIndex);
|
||||
this._app?.onSelectFirst?.();
|
||||
};
|
||||
|
||||
/** Selects the last item in the queue. */
|
||||
/** Selects the last entry. */
|
||||
selectLast = () => {
|
||||
const items = this.$items.get();
|
||||
const last = items.at(-1);
|
||||
const entries = this.$entries.get();
|
||||
const last = entries.at(-1);
|
||||
if (!last) {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(last.item_id);
|
||||
this.$selectedItemId.set(last.item.item_id);
|
||||
this.$selectedImageIndex.set(last.imageIndex);
|
||||
this._app?.onSelectLast?.();
|
||||
};
|
||||
|
||||
/** Discards the currently selected item and selects the next available item. */
|
||||
/** Discards the queue item of the currently selected entry and selects the next available entry. */
|
||||
discardSelected = () => {
|
||||
const selectedItem = this.$selectedItem.get();
|
||||
if (selectedItem === null) {
|
||||
return;
|
||||
}
|
||||
const items = this.$items.get();
|
||||
const nextIndex = clamp(selectedItem.index + 1, 0, items.length - 1);
|
||||
const nextItem = items[nextIndex];
|
||||
const itemIndex = items.findIndex((i) => i.item_id === selectedItem.item.item_id);
|
||||
const nextItemIndex = clamp(itemIndex + 1, 0, items.length - 1);
|
||||
const nextItem = items[nextItemIndex];
|
||||
if (nextItem) {
|
||||
this.$selectedItemId.set(nextItem.item_id);
|
||||
} else {
|
||||
this.$selectedItemId.set(null);
|
||||
}
|
||||
this.$selectedImageIndex.set(0);
|
||||
this._app?.onDiscard?.(selectedItem.item);
|
||||
};
|
||||
|
||||
@@ -231,30 +312,22 @@ export class StagingAreaApi {
|
||||
/** Discards all items in the queue. */
|
||||
discardAll = () => {
|
||||
this.$selectedItemId.set(null);
|
||||
this.$selectedImageIndex.set(0);
|
||||
this._app?.onDiscardAll?.();
|
||||
};
|
||||
|
||||
/** Accepts the currently selected item if an image is available. */
|
||||
/** Accepts the currently selected entry if an image is available. */
|
||||
acceptSelected = () => {
|
||||
const selectedItem = this.$selectedItem.get();
|
||||
if (selectedItem === null) {
|
||||
if (selectedItem === null || !selectedItem.imageDTO) {
|
||||
return;
|
||||
}
|
||||
const progressData = this.$progressData.get();
|
||||
const datum = progressData[selectedItem.item.item_id];
|
||||
if (!datum || !datum.imageDTO) {
|
||||
return;
|
||||
}
|
||||
this._app?.onAccept?.(selectedItem.item, datum.imageDTO);
|
||||
this._app?.onAccept?.(selectedItem.item, selectedItem.imageDTO);
|
||||
};
|
||||
|
||||
/** Whether the accept selected action is enabled. */
|
||||
$acceptSelectedIsEnabled = computed([this.$selectedItem, this.$progressData], (selectedItem, progressData) => {
|
||||
if (selectedItem === null) {
|
||||
return false;
|
||||
}
|
||||
const datum = progressData[selectedItem.item.item_id];
|
||||
return !!datum && !!datum.imageDTO;
|
||||
$acceptSelectedIsEnabled = computed([this.$selectedItem], (selectedItem) => {
|
||||
return selectedItem !== null && selectedItem.imageDTO !== null;
|
||||
});
|
||||
|
||||
/** Sets the auto-switch mode. */
|
||||
@@ -297,6 +370,10 @@ export class StagingAreaApi {
|
||||
* handles auto-selection, and implements auto-switch behavior.
|
||||
*/
|
||||
onItemsChangedEvent = async (items: S['SessionQueueItem'][]) => {
|
||||
// Increment generation counter. If a newer call starts while we're awaiting,
|
||||
// we'll detect it and avoid overwriting with stale data.
|
||||
const generation = ++this._itemsEventGeneration;
|
||||
|
||||
const oldItems = this.$items.get();
|
||||
|
||||
if (items === oldItems) {
|
||||
@@ -306,9 +383,11 @@ export class StagingAreaApi {
|
||||
if (items.length === 0) {
|
||||
// If there are no items, cannot have a selected item.
|
||||
this.$selectedItemId.set(null);
|
||||
this.$selectedImageIndex.set(0);
|
||||
} else if (this.$selectedItemId.get() === null && items.length > 0) {
|
||||
// If there is no selected item but there are items, select the first one.
|
||||
this.$selectedItemId.set(items[0]?.item_id ?? null);
|
||||
this.$selectedImageIndex.set(0);
|
||||
}
|
||||
|
||||
const progressData = this.$progressData.get();
|
||||
@@ -328,7 +407,7 @@ export class StagingAreaApi {
|
||||
...(datum ?? getInitialProgressData(item.item_id)),
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageDTOs: [],
|
||||
});
|
||||
continue;
|
||||
}
|
||||
@@ -336,31 +415,57 @@ export class StagingAreaApi {
|
||||
if (item.status === 'in_progress') {
|
||||
if (this.$lastStartedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_start') {
|
||||
this.$selectedItemId.set(item.item_id);
|
||||
this.$selectedImageIndex.set(0);
|
||||
this.$lastStartedItemId.set(null);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.status === 'completed') {
|
||||
if (datum?.imageDTO) {
|
||||
const outputImageNames = getOutputImageNames(item);
|
||||
if (outputImageNames.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const outputImageName = getOutputImageName(item);
|
||||
if (!outputImageName) {
|
||||
// Check current progress data (not the snapshot) to account for concurrent updates
|
||||
const currentDatum = this.$progressData.get()[item.item_id];
|
||||
if (currentDatum && currentDatum.imageDTOs.length === outputImageNames.length) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await this._app?.getImageDTO(outputImageName);
|
||||
if (!imageDTO) {
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
for (const imageName of outputImageNames) {
|
||||
const imageDTO = await this._app?.getImageDTO(imageName);
|
||||
if (imageDTO) {
|
||||
imageDTOs.push(imageDTO);
|
||||
}
|
||||
}
|
||||
if (imageDTOs.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// After async work, check if a newer event has started processing.
|
||||
// If so, abort to let the newer call handle the update with fresher data.
|
||||
if (generation !== this._itemsEventGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-read progress data to avoid overwriting a better result from a concurrent call
|
||||
const latestDatum = this.$progressData.get()[item.item_id];
|
||||
if (latestDatum && latestDatum.imageDTOs.length >= imageDTOs.length) {
|
||||
continue;
|
||||
}
|
||||
|
||||
this.$progressData.setKey(item.item_id, {
|
||||
...(datum ?? getInitialProgressData(item.item_id)),
|
||||
imageDTO,
|
||||
...(latestDatum ?? getInitialProgressData(item.item_id)),
|
||||
imageDTOs,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// After async work, check if a newer event has started processing
|
||||
if (generation !== this._itemsEventGeneration) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedItemId = this.$selectedItemId.get();
|
||||
if (selectedItemId !== null && !items.find(({ item_id }) => item_id === selectedItemId)) {
|
||||
// If the selected item no longer exists, select the next best item.
|
||||
@@ -370,15 +475,18 @@ export class StagingAreaApi {
|
||||
const nextItem = items[nextItemIndex] ?? items[nextItemIndex - 1];
|
||||
if (nextItem) {
|
||||
this.$selectedItemId.set(nextItem.item_id);
|
||||
this.$selectedImageIndex.set(0);
|
||||
}
|
||||
} else {
|
||||
// Next, if there is an in-progress item, select that.
|
||||
const inProgressItem = items.find(({ status }) => status === 'in_progress');
|
||||
if (inProgressItem) {
|
||||
this.$selectedItemId.set(inProgressItem.item_id);
|
||||
this.$selectedImageIndex.set(0);
|
||||
}
|
||||
// Finally just select the first item.
|
||||
this.$selectedItemId.set(items[0]?.item_id ?? null);
|
||||
this.$selectedImageIndex.set(0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,6 +501,7 @@ export class StagingAreaApi {
|
||||
// This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
|
||||
if (this.$lastCompletedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_finish') {
|
||||
this.$selectedItemId.set(item.item_id);
|
||||
this.$selectedImageIndex.set(0);
|
||||
this.$lastCompletedItemId.set(null);
|
||||
}
|
||||
const datum = this.$progressData.get()[item.item_id];
|
||||
@@ -402,20 +511,22 @@ export class StagingAreaApi {
|
||||
});
|
||||
};
|
||||
|
||||
/** Creates a computed value that returns true if the given item ID is selected. */
|
||||
buildIsSelectedComputed = (itemId: number) => {
|
||||
return computed([this.$selectedItemId], (selectedItemId) => {
|
||||
return selectedItemId === itemId;
|
||||
/** Creates a computed value that returns true if the given item ID and image index is selected. */
|
||||
buildIsSelectedComputed = (itemId: number, imageIndex: number = 0) => {
|
||||
return computed([this.$selectedItemId, this.$selectedImageIndex], (selectedItemId, selectedImageIndex) => {
|
||||
return selectedItemId === itemId && selectedImageIndex === imageIndex;
|
||||
});
|
||||
};
|
||||
|
||||
/** Cleans up all state and unsubscribes from all events. */
|
||||
cleanup = () => {
|
||||
this._itemsEventGeneration++;
|
||||
this.$lastStartedItemId.set(null);
|
||||
this.$lastCompletedItemId.set(null);
|
||||
this.$items.set([]);
|
||||
this.$progressData.set({});
|
||||
this.$selectedItemId.set(null);
|
||||
this.$selectedImageIndex.set(0);
|
||||
this._subscriptions.forEach((unsubscribe) => unsubscribe());
|
||||
this._subscriptions.clear();
|
||||
};
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { canvasWorkflowIntegrationOpened } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFlowArrowBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityMenuItemsRunWorkflow = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(canvasWorkflowIntegrationOpened({ sourceEntityIdentifier: entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={onClick} icon={<PiFlowArrowBold />}>
|
||||
{t('controlLayers.workflowIntegration.runWorkflow')}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityMenuItemsRunWorkflow.displayName = 'CanvasEntityMenuItemsRunWorkflow';
|
||||
@@ -156,8 +156,8 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
|
||||
return;
|
||||
}
|
||||
|
||||
if (selectedItem.progressData.imageDTO) {
|
||||
this.$imageSrc.set({ type: 'imageName', data: selectedItem.progressData.imageDTO.image_name });
|
||||
if (selectedItem.imageDTO) {
|
||||
this.$imageSrc.set({ type: 'imageName', data: selectedItem.imageDTO.image_name });
|
||||
return;
|
||||
} else if (selectedItem.progressData?.progressImage) {
|
||||
this.$imageSrc.set({ type: 'dataURL', data: selectedItem.progressData.progressImage.dataURL });
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
const zCanvasWorkflowIntegrationState = z.object({
|
||||
_version: z.literal(1),
|
||||
isOpen: z.boolean(),
|
||||
selectedWorkflowId: z.string().nullable(),
|
||||
sourceEntityIdentifier: z
|
||||
.object({
|
||||
type: z.enum(['raster_layer', 'control_layer', 'regional_guidance', 'inpaint_mask']),
|
||||
id: z.string(),
|
||||
})
|
||||
.nullable(),
|
||||
fieldValues: z.record(z.string(), z.any()).nullable(),
|
||||
// Which ImageField to use for canvas image (format: "nodeId.fieldName")
|
||||
selectedImageFieldKey: z.string().nullable(),
|
||||
isProcessing: z.boolean(),
|
||||
});
|
||||
|
||||
type CanvasWorkflowIntegrationState = z.infer<typeof zCanvasWorkflowIntegrationState>;
|
||||
|
||||
const getInitialState = (): CanvasWorkflowIntegrationState => ({
|
||||
_version: 1,
|
||||
isOpen: false,
|
||||
selectedWorkflowId: null,
|
||||
sourceEntityIdentifier: null,
|
||||
fieldValues: null,
|
||||
selectedImageFieldKey: null,
|
||||
isProcessing: false,
|
||||
});
|
||||
|
||||
const slice = createSlice({
|
||||
name: 'canvasWorkflowIntegration',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
canvasWorkflowIntegrationOpened: (
|
||||
state,
|
||||
action: PayloadAction<{ sourceEntityIdentifier: CanvasEntityIdentifier }>
|
||||
) => {
|
||||
state.isOpen = true;
|
||||
state.sourceEntityIdentifier = action.payload.sourceEntityIdentifier;
|
||||
state.selectedWorkflowId = null;
|
||||
state.fieldValues = null;
|
||||
},
|
||||
canvasWorkflowIntegrationClosed: (state) => {
|
||||
state.isOpen = false;
|
||||
state.selectedWorkflowId = null;
|
||||
state.sourceEntityIdentifier = null;
|
||||
state.fieldValues = null;
|
||||
state.selectedImageFieldKey = null;
|
||||
state.isProcessing = false;
|
||||
},
|
||||
canvasWorkflowIntegrationWorkflowSelected: (state, action: PayloadAction<{ workflowId: string | null }>) => {
|
||||
state.selectedWorkflowId = action.payload.workflowId;
|
||||
// Reset field values when switching workflows
|
||||
state.fieldValues = null;
|
||||
state.selectedImageFieldKey = null;
|
||||
},
|
||||
canvasWorkflowIntegrationImageFieldSelected: (state, action: PayloadAction<{ fieldKey: string | null }>) => {
|
||||
state.selectedImageFieldKey = action.payload.fieldKey;
|
||||
},
|
||||
canvasWorkflowIntegrationFieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ fieldName: string; value: unknown }>
|
||||
) => {
|
||||
if (!state.fieldValues) {
|
||||
state.fieldValues = {};
|
||||
}
|
||||
state.fieldValues[action.payload.fieldName] = action.payload.value;
|
||||
},
|
||||
canvasWorkflowIntegrationFieldValuesReset: (state) => {
|
||||
state.fieldValues = null;
|
||||
},
|
||||
canvasWorkflowIntegrationProcessingStarted: (state) => {
|
||||
state.isProcessing = true;
|
||||
},
|
||||
canvasWorkflowIntegrationProcessingCompleted: (state) => {
|
||||
state.isProcessing = false;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
canvasWorkflowIntegrationOpened,
|
||||
canvasWorkflowIntegrationClosed,
|
||||
canvasWorkflowIntegrationWorkflowSelected,
|
||||
canvasWorkflowIntegrationImageFieldSelected,
|
||||
canvasWorkflowIntegrationFieldValueChanged,
|
||||
canvasWorkflowIntegrationProcessingStarted,
|
||||
canvasWorkflowIntegrationProcessingCompleted,
|
||||
} = slice.actions;
|
||||
|
||||
export const canvasWorkflowIntegrationSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
schema: zCanvasWorkflowIntegrationState,
|
||||
getInitialState,
|
||||
persistConfig: {
|
||||
migrate: (state) => {
|
||||
assert(isPlainObject(state));
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return zCanvasWorkflowIntegrationState.parse(state);
|
||||
},
|
||||
persistDenylist: ['isOpen', 'isProcessing', 'sourceEntityIdentifier'],
|
||||
},
|
||||
};
|
||||
|
||||
const selectCanvasWorkflowIntegrationSlice = (state: RootState) => state.canvasWorkflowIntegration;
|
||||
const createCanvasWorkflowIntegrationSelector = <T>(selector: Selector<CanvasWorkflowIntegrationState, T>) =>
|
||||
createSelector(selectCanvasWorkflowIntegrationSlice, selector);
|
||||
|
||||
export const selectCanvasWorkflowIntegrationIsOpen = createCanvasWorkflowIntegrationSelector((state) => state.isOpen);
|
||||
export const selectCanvasWorkflowIntegrationSelectedWorkflowId = createCanvasWorkflowIntegrationSelector(
|
||||
(state) => state.selectedWorkflowId
|
||||
);
|
||||
export const selectCanvasWorkflowIntegrationSourceEntityIdentifier = createCanvasWorkflowIntegrationSelector(
|
||||
(state) => state.sourceEntityIdentifier
|
||||
);
|
||||
export const selectCanvasWorkflowIntegrationFieldValues = createCanvasWorkflowIntegrationSelector(
|
||||
(state) => state.fieldValues
|
||||
);
|
||||
export const selectCanvasWorkflowIntegrationSelectedImageFieldKey = createCanvasWorkflowIntegrationSelector(
|
||||
(state) => state.selectedImageFieldKey
|
||||
);
|
||||
export const selectCanvasWorkflowIntegrationIsProcessing = createCanvasWorkflowIntegrationSelector(
|
||||
(state) => state.isProcessing
|
||||
);
|
||||
@@ -1,6 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
|
||||
export { getPrefixedId };
|
||||
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
@@ -205,7 +207,7 @@ export const getInfill = (
|
||||
assert(false, 'Unknown infill method');
|
||||
};
|
||||
|
||||
const CANVAS_OUTPUT_PREFIX = 'canvas_output';
|
||||
export const CANVAS_OUTPUT_PREFIX = 'canvas_output';
|
||||
|
||||
export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNodes>) => {
|
||||
return (
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import {
|
||||
selectAutoSwitch,
|
||||
selectGalleryView,
|
||||
@@ -13,8 +14,10 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import type { LRUCache } from 'lru-cache';
|
||||
import { LIST_ALL_TAG } from 'services/api';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories } from 'services/api/util';
|
||||
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
|
||||
@@ -217,6 +220,27 @@ export const buildOnInvocationComplete = (
|
||||
return imageDTOs;
|
||||
};
|
||||
|
||||
const clearCanvasWorkflowIntegrationProcessing = (data: S['InvocationCompleteEvent']) => {
|
||||
// Check if this is a canvas workflow integration result
|
||||
// Results go to staging area automatically via destination = canvasSessionId
|
||||
if (data.origin !== 'canvas_workflow_integration') {
|
||||
return;
|
||||
}
|
||||
// Clear processing state so the modal loading spinner stops
|
||||
dispatch(canvasWorkflowIntegrationProcessingCompleted());
|
||||
|
||||
// Check if this invocation produced an image output
|
||||
const hasImageOutput = objectEntries(data.result).some(([_name, value]) => {
|
||||
return isImageField(value) || isImageFieldCollection(value);
|
||||
});
|
||||
|
||||
// Only invalidate if this invocation produced an image - this ensures the staging area
|
||||
// gets updated immediately when output images are available, without invalidating on every invocation
|
||||
if (hasImageOutput) {
|
||||
dispatch(queueApi.util.invalidateTags([{ type: 'SessionQueueItem', id: LIST_ALL_TAG }]));
|
||||
}
|
||||
};
|
||||
|
||||
return async (data: S['InvocationCompleteEvent']) => {
|
||||
if (finishedQueueItemIds.has(data.item_id)) {
|
||||
log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`);
|
||||
@@ -236,6 +260,10 @@ export const buildOnInvocationComplete = (
|
||||
upsertExecutionState(_nodeExecutionState.nodeId, _nodeExecutionState);
|
||||
}
|
||||
|
||||
// Clear canvas workflow integration processing state if needed
|
||||
clearCanvasWorkflowIntegrationProcessing(data);
|
||||
|
||||
// Add images to gallery (canvas workflow integration results go to staging area automatically)
|
||||
await addImagesToGallery(data);
|
||||
|
||||
$lastProgressEvent.set(null);
|
||||
|
||||
@@ -5,6 +5,7 @@ import type { AppStore } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { forEach, isNil, round } from 'es-toolkit/compat';
|
||||
import { allEntitiesDeleted, controlLayerRecalled } from 'features/controlLayers/store/canvasSlice';
|
||||
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
|
||||
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
@@ -160,6 +161,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
};
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
// Clear canvas workflow integration processing state on error
|
||||
if (data.origin === 'canvas_workflow_integration') {
|
||||
dispatch(canvasWorkflowIntegrationProcessingCompleted());
|
||||
}
|
||||
});
|
||||
|
||||
const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, finishedQueueItemIds);
|
||||
@@ -407,6 +412,25 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
})
|
||||
);
|
||||
|
||||
// Optimistically update the listAllQueueItems cache for this destination so the canvas
|
||||
// staging area immediately reflects status changes without waiting for a tag-based refetch
|
||||
if (destination) {
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listAllQueueItems', { destination }, (draft) => {
|
||||
const item = draft.find((i) => i.item_id === item_id);
|
||||
if (item) {
|
||||
item.status = status;
|
||||
item.started_at = started_at;
|
||||
item.updated_at = updated_at;
|
||||
item.completed_at = completed_at;
|
||||
item.error_type = error_type;
|
||||
item.error_message = error_message;
|
||||
item.error_traceback = error_traceback;
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Invalidate caches for things we cannot easily update
|
||||
// Invalidate SessionQueueStatus to refetch with user-specific counts
|
||||
const tagsToInvalidate: ApiTagDescription[] = [
|
||||
|
||||
Reference in New Issue
Block a user