Merge remote-tracking branch 'upstream/main' into external-models

This commit is contained in:
Alexander Eichhorn
2026-04-07 23:54:26 +02:00
35 changed files with 2734 additions and 191 deletions

View File

@@ -0,0 +1,27 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation(
    "canvas_output",
    title="Canvas Output",
    tags=["canvas", "output", "image"],
    category="canvas",
    version="1.0.0",
    use_cache=False,  # never cache: this node must run on every execution so the canvas always gets a result
)
class CanvasOutputInvocation(BaseInvocation):
    """Outputs an image to the canvas staging area.

    Use this node in workflows intended for canvas workflow integration.
    Connect the final image of your workflow to this node to send it
    to the canvas staging area when run via 'Run Workflow on Canvas'.
    """

    # Input image handle; only the image_name is used to fetch the PIL image.
    image: ImageField = InputField(description=FieldDescriptions.image)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the referenced image and re-save it so a fresh ImageDTO is
        # created for this invocation's output.
        image = context.images.get_pil(self.image.image_name)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

View File

@@ -711,6 +711,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"transformer.layers.", # OneTrainer/diffusers prefix variant
"base_model.model.transformer.layers.", # PEFT-wrapped variant
},
)
@@ -747,6 +749,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"transformer.layers.", # OneTrainer/diffusers prefix variant
"base_model.model.transformer.layers.", # PEFT-wrapped variant
},
)

View File

@@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
return False
def _filename_suggests_base(name: str) -> bool:
"""Check if a model name/filename suggests it is a Base (undistilled) variant.
Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
from the state dict. We use the filename as a heuristic: filenames containing "base"
(e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
"""
return "base" in name.lower()
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
"""Determine FLUX.2 variant from state dict.
@@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
We default to Klein9B (distilled) for all 9B models since GGUF models may not
include guidance embedding keys needed to distinguish them.
Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
and cannot be distinguished from the state dict alone. This function defaults to Klein9B
for all 9B models. Callers should use filename heuristics to detect Klein9BBase.
Supports both BFL format (checkpoint) and diffusers format keys:
- BFL format: txt_in.weight (context embedder)
@@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
context_in_dim = shape[1]
# Determine variant based on context dimension
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
# Default to Klein9B (distilled) - the official/common 9B model
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
return Flux2VariantType.Klein9B
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
@@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
return variant
@classmethod
@@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
return variant
@classmethod
@@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
we check guidance_embeds:
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
"""
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
@@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
guidance_embeds = transformer_config.get("guidance_embeds", False)
# Determine variant based on joint_attention_dim
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
# Check guidance_embeds to distinguish distilled from undistilled
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
if guidance_embeds:
if _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
else:
return Flux2VariantType.Klein9B
return Flux2VariantType.Klein9B
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
elif joint_attention_dim > 4096:

View File

@@ -44,6 +44,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_bfl_format,
lora_model_from_flux_onetrainer_bfl_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
@@ -128,6 +132,8 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):

View File

@@ -81,7 +81,7 @@ t5_base_encoder = StarterModel(
name="t5_base_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
description="T5-XXL text encoder (used in FLUX pipelines). ~9.5GB",
type=ModelType.T5Encoder,
)
@@ -166,7 +166,7 @@ flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
@@ -174,7 +174,7 @@ flux_krea = StarterModel(
name="FLUX.1 Krea dev",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
description="FLUX.1 Krea dev. Total size with dependencies: ~29GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
@@ -182,7 +182,7 @@ flux_krea_quantized = StarterModel(
name="FLUX.1 Krea dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
@@ -190,7 +190,7 @@ sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-medium",
description="Medium SD3.5 Model: ~15GB",
description="Medium SD3.5 Model: ~16GB",
type=ModelType.Main,
dependencies=[],
)
@@ -198,7 +198,7 @@ sd35_large = StarterModel(
name="SD3.5 Large",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-large",
description="Large SD3.5 Model: ~19G",
description="Large SD3.5 Model: ~28GB",
type=ModelType.Main,
dependencies=[],
)
@@ -654,7 +654,7 @@ cogview4 = StarterModel(
name="CogView4",
base=BaseModelType.CogView4,
source="THUDM/CogView4-6B",
description="The base CogView4 model (~29GB).",
description="The base CogView4 model (~31GB).",
type=ModelType.Main,
)
# endregion
@@ -705,7 +705,7 @@ flux2_vae = StarterModel(
name="FLUX.2 VAE",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-4B::vae",
description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB",
description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~168MB",
type=ModelType.VAE,
)
@@ -729,7 +729,7 @@ flux2_klein_4b = StarterModel(
name="FLUX.2 Klein 4B (Diffusers)",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-4B",
description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB",
description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~16GB",
type=ModelType.Main,
)
@@ -755,7 +755,7 @@ flux2_klein_9b = StarterModel(
name="FLUX.2 Klein 9B (Diffusers)",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-9B",
description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB",
description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~35GB",
type=ModelType.Main,
)
@@ -831,7 +831,7 @@ z_image_turbo = StarterModel(
name="Z-Image Turbo",
base=BaseModelType.ZImage,
source="Tongyi-MAI/Z-Image-Turbo",
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB",
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~33GB",
type=ModelType.Main,
)

View File

@@ -215,6 +215,7 @@ class FluxLoRAFormat(str, Enum):
AIToolkit = "flux.aitoolkit"
XLabs = "flux.xlabs"
BflPeft = "flux.bfl_peft"
OneTrainerBfl = "flux.onetrainer_bfl"
AnyVariant: TypeAlias = Union[

View File

@@ -0,0 +1,168 @@
"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format.
This format is produced by newer versions of OneTrainer and uses BFL internal key names
(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and
InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha).
Unlike the standard BFL PEFT format (which uses 'diffusion_model.' prefix and lora_A/lora_B),
this format also has split QKV projections:
- double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate)
- double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate)
- single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate)
Example keys:
transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight
transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight
transformer.double_blocks.0.img_attn.qkv.0.alpha
transformer.single_blocks.0.linear1.3.lora_down.weight
transformer.double_blocks.0.img_mlp.0.lora_down.weight
"""
import re
from typing import Any, Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
_TRANSFORMER_PREFIX = "transformer."
# Valid LoRA weight suffixes in this format.
_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha")
# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1"
_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$")
# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2"
_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$")
def is_state_dict_likely_in_flux_onetrainer_bfl_format(
state_dict: dict[str | int, Any],
metadata: dict[str, Any] | None = None,
) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format.
This format uses BFL internal key names with 'transformer.' prefix and split QKV projections.
"""
str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
if not str_keys:
return False
# All keys must start with 'transformer.'
if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys):
return False
# All keys must end with recognized LoRA suffixes.
if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys):
return False
# Must have BFL block structure (double_blocks or single_blocks) under transformer prefix.
has_bfl_blocks = any(
k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys
)
if not has_bfl_blocks:
return False
# Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats
# that might use transformer. prefix in the future.
has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys)
if not has_split_qkv:
return False
return True
def _split_key(key: str) -> tuple[str, str]:
"""Split a key into (layer_name, weight_suffix).
Handles:
- 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot
- 1-component suffixes: e.g., 'alpha' → split at last dot
"""
if key.endswith(".weight"):
parts = key.rsplit(".", maxsplit=2)
return parts[0], f"{parts[1]}.{parts[2]}"
else:
parts = key.rsplit(".", maxsplit=1)
return parts[0], parts[1]
def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw.

    Strips the 'transformer.' prefix, groups tensors by layer, and merges the
    split QKV / linear1 projections into MergedLayerPatch instances.
    """
    # Group tensors by layer name, with the 'transformer.' prefix removed.
    per_layer: dict[str, dict[str, torch.Tensor]] = {}
    for full_key, tensor in state_dict.items():
        if not isinstance(full_key, str):
            continue
        stripped = full_key[len(_TRANSFORMER_PREFIX) :]
        layer_name, suffix = _split_key(stripped)
        per_layer.setdefault(layer_name, {})[suffix] = tensor

    # Partition layers into split groups (keyed by their parent layer) and
    # standalone layers that need no merging.
    split_groups: dict[str, list[str]] = {}
    plain_layers: list[str] = []
    for layer_name in per_layer:
        match = _SPLIT_QKV_RE.match(layer_name) or _SPLIT_LINEAR1_RE.match(layer_name)
        if match is None:
            plain_layers.append(layer_name)
        else:
            split_groups.setdefault(match.group(1), []).append(layer_name)

    layers: dict[str, BaseLayerPatch] = {}

    # Standalone layers convert directly.
    for layer_name in plain_layers:
        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_name}"] = any_lora_layer_from_state_dict(per_layer[layer_name])

    # Split layers are merged, concatenated along dim 0 in index order.
    for parent, members in split_groups.items():
        # Order by the trailing numeric index (e.g. qkv.0, qkv.1, qkv.2).
        members.sort(key=lambda name: int(name.rsplit(".", maxsplit=1)[1]))
        patches: list[BaseLayerPatch] = []
        ranges: list[Range] = []
        offset = 0
        for member in members:
            member_sd = per_layer[member]
            patches.append(any_lora_layer_from_state_dict(member_sd))
            # The sub-layer's output width is the up-projection's first dimension.
            width = member_sd["lora_up.weight"].shape[0]
            ranges.append(Range(offset, offset + width))
            offset += width
        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent}"] = MergedLayerPatch(patches, ranges)

    return ModelPatchRaw(layers=layers)

View File

@@ -14,6 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_bfl_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
@@ -28,6 +31,8 @@ def flux_format_from_state_dict(
) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata):
return FluxLoRAFormat.OneTrainerBfl
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
return FluxLoRAFormat.OneTrainer
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):

View File

@@ -2415,6 +2415,27 @@
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
"addAdjustments": "Add Adjustments",
"removeAdjustments": "Remove Adjustments",
"workflowIntegration": {
"title": "Run Workflow on Canvas",
"description": "Select a workflow with a Canvas Output node and an image parameter to run on the current canvas layer. You can adjust parameters before executing. The result will be added back to the canvas.",
"execute": "Execute Workflow",
"executing": "Executing...",
"runWorkflow": "Run Workflow",
"filteringWorkflows": "Filtering workflows...",
"loadingWorkflows": "Loading workflows...",
"noWorkflowsFound": "No workflows found.",
"noWorkflowsWithImageField": "No compatible workflows found. A workflow needs a Form Builder with an image input field and a Canvas Output node.",
"selectWorkflow": "Select Workflow",
"selectPlaceholder": "Choose a workflow...",
"unnamedWorkflow": "Unnamed Workflow",
"loadingParameters": "Loading workflow parameters...",
"noFormBuilderError": "This workflow has no form builder and cannot be used. Please select a different workflow.",
"imageFieldSelected": "This field will receive the canvas image",
"imageFieldNotSelected": "Click to use this field for canvas image",
"executionStarted": "Workflow execution started",
"executionStartedDescription": "The result will appear in the staging area when complete.",
"executionFailed": "Failed to execute workflow"
},
"compositeOperation": {
"label": "Blend Mode",
"add": "Add Blend Mode",

View File

@@ -1,6 +1,7 @@
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { CropImageModal } from 'features/cropper/components/CropImageModal';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
@@ -51,6 +52,7 @@ export const GlobalModalIsolator = memo(() => {
<SaveWorkflowAsDialog />
<CanvasManagerProviderGate>
<CanvasPasteModal />
<CanvasWorkflowIntegrationModal />
</CanvasManagerProviderGate>
<LoadWorkflowFromGraphModal />
<CropImageModal />

View File

@@ -16,6 +16,7 @@ const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
'canvas-workflow-integration',
'config',
'dnd',
'events',

View File

@@ -25,6 +25,7 @@ import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSe
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { canvasTextSliceConfig } from 'features/controlLayers/store/canvasTextSlice';
import { canvasWorkflowIntegrationSliceConfig } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
@@ -67,6 +68,7 @@ const SLICE_CONFIGS = {
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig,
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
[canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig,
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
@@ -98,6 +100,7 @@ const ALL_REDUCERS = {
canvasSliceConfig.slice.reducer,
canvasSliceConfig.undoableConfig?.reduxUndoOptions
),
[canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig.slice.reducer,
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,

View File

@@ -0,0 +1,93 @@
import {
Button,
ButtonGroup,
Flex,
Heading,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Spacer,
Spinner,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
canvasWorkflowIntegrationClosed,
selectCanvasWorkflowIntegrationIsOpen,
selectCanvasWorkflowIntegrationIsProcessing,
selectCanvasWorkflowIntegrationSelectedWorkflowId,
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { CanvasWorkflowIntegrationParameterPanel } from './CanvasWorkflowIntegrationParameterPanel';
import { CanvasWorkflowIntegrationWorkflowSelector } from './CanvasWorkflowIntegrationWorkflowSelector';
import { useCanvasWorkflowIntegrationExecute } from './useCanvasWorkflowIntegrationExecute';
/**
 * Modal for running a saved workflow against the current canvas.
 * Lets the user pick a compatible workflow, adjust its parameters, and execute it.
 * Closing is blocked while an execution is in flight.
 */
export const CanvasWorkflowIntegrationModal = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const isOpen = useAppSelector(selectCanvasWorkflowIntegrationIsOpen);
  const isProcessing = useAppSelector(selectCanvasWorkflowIntegrationIsProcessing);
  const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
  const { execute, canExecute } = useCanvasWorkflowIntegrationExecute();

  // Ignore close requests while a workflow is executing.
  const onClose = useCallback(() => {
    if (!isProcessing) {
      dispatch(canvasWorkflowIntegrationClosed());
    }
  }, [dispatch, isProcessing]);

  const onExecute = useCallback(() => {
    execute();
  }, [execute]);

  return (
    <Modal isOpen={isOpen} onClose={onClose} size="xl" isCentered>
      <ModalOverlay />
      <ModalContent>
        <ModalHeader>
          <Heading size="md">{t('controlLayers.workflowIntegration.title')}</Heading>
        </ModalHeader>
        <ModalCloseButton isDisabled={isProcessing} />
        <ModalBody>
          <Flex direction="column" gap={4}>
            <Text fontSize="sm" color="base.400">
              {t('controlLayers.workflowIntegration.description')}
            </Text>
            <CanvasWorkflowIntegrationWorkflowSelector />
            {/* Parameters are only shown once a workflow has been chosen. */}
            {selectedWorkflowId && <CanvasWorkflowIntegrationParameterPanel />}
          </Flex>
        </ModalBody>
        <ModalFooter>
          <ButtonGroup>
            <Button variant="ghost" onClick={onClose} isDisabled={isProcessing}>
              {t('common.cancel')}
            </Button>
            <Spacer />
            {/* NOTE(review): loadingText is set but isLoading is not passed; the
                spinner is rendered manually instead — confirm whether
                isLoading={isProcessing} was intended here. */}
            <Button
              onClick={onExecute}
              isDisabled={!canExecute || isProcessing}
              loadingText={t('controlLayers.workflowIntegration.executing')}
            >
              {isProcessing && <Spinner size="sm" mr={2} />}
              {t('controlLayers.workflowIntegration.execute')}
            </Button>
          </ButtonGroup>
        </ModalFooter>
      </ModalContent>
    </Modal>
  );
});
CanvasWorkflowIntegrationModal.displayName = 'CanvasWorkflowIntegrationModal';

View File

@@ -0,0 +1,13 @@
import { Box } from '@invoke-ai/ui-library';
import { WorkflowFormPreview } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview';
import { memo } from 'react';
export const CanvasWorkflowIntegrationParameterPanel = memo(() => {
return (
<Box w="full">
<WorkflowFormPreview />
</Box>
);
});
CanvasWorkflowIntegrationParameterPanel.displayName = 'CanvasWorkflowIntegrationParameterPanel';

View File

@@ -0,0 +1,92 @@
import { Flex, FormControl, FormLabel, Select, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
canvasWorkflowIntegrationWorkflowSelected,
selectCanvasWorkflowIntegrationSelectedWorkflowId,
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows';
import { useFilteredWorkflows } from './useFilteredWorkflows';
/**
 * Dropdown for choosing which workflow to run on the canvas.
 * Loads the first page of workflows, filters to those usable for canvas
 * integration, and dispatches the selection to the store.
 */
export const CanvasWorkflowIntegrationWorkflowSelector = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);

  // Fetch a single page of up to 100 workflows; pagination is not surfaced here.
  const { data: workflowsData, isLoading } = useListWorkflowsInfiniteInfiniteQuery(
    {
      per_page: 100, // Get a reasonable number of workflows
      page: 0,
    },
    {
      selectFromResult: ({ data, isLoading }) => ({
        data,
        isLoading,
      }),
    }
  );

  const workflows = useMemo(() => {
    if (!workflowsData) {
      return [];
    }
    // Flatten all pages into a single list
    return workflowsData.pages.flatMap((page) => page.items);
  }, [workflowsData]);

  // Filter workflows to only show those with ImageFields
  const { filteredWorkflows, isFiltering } = useFilteredWorkflows(workflows);

  // An empty selection maps to null in the store.
  const onChange = useCallback(
    (e: ChangeEvent<HTMLSelectElement>) => {
      const workflowId = e.target.value || null;
      dispatch(canvasWorkflowIntegrationWorkflowSelected({ workflowId }));
    },
    [dispatch]
  );

  // Loading / filtering state: show a spinner with the relevant message.
  if (isLoading || isFiltering) {
    return (
      <Flex alignItems="center" gap={2}>
        <Spinner size="sm" />
        <Text>
          {isFiltering
            ? t('controlLayers.workflowIntegration.filteringWorkflows')
            : t('controlLayers.workflowIntegration.loadingWorkflows')}
        </Text>
      </Flex>
    );
  }

  // Distinguish "no workflows at all" from "none compatible with the canvas".
  if (filteredWorkflows.length === 0) {
    return (
      <Text color="warning.400" fontSize="sm">
        {workflows.length === 0
          ? t('controlLayers.workflowIntegration.noWorkflowsFound')
          : t('controlLayers.workflowIntegration.noWorkflowsWithImageField')}
      </Text>
    );
  }

  return (
    <FormControl>
      <FormLabel>{t('controlLayers.workflowIntegration.selectWorkflow')}</FormLabel>
      <Select
        placeholder={t('controlLayers.workflowIntegration.selectPlaceholder')}
        value={selectedWorkflowId || ''}
        onChange={onChange}
      >
        {filteredWorkflows.map((workflow) => (
          <option key={workflow.workflow_id} value={workflow.workflow_id}>
            {workflow.name || t('controlLayers.workflowIntegration.unnamedWorkflow')}
          </option>
        ))}
      </Select>
    </FormControl>
  );
});
CanvasWorkflowIntegrationWorkflowSelector.displayName = 'CanvasWorkflowIntegrationWorkflowSelector';

View File

@@ -0,0 +1,548 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import {
Combobox,
Flex,
FormControl,
FormLabel,
IconButton,
Input,
Radio,
Select,
Switch,
Text,
Textarea,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { logger } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import {
canvasWorkflowIntegrationFieldValueChanged,
canvasWorkflowIntegrationImageFieldSelected,
selectCanvasWorkflowIntegrationFieldValues,
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
selectCanvasWorkflowIntegrationSelectedWorkflowId,
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { DndImage } from 'features/dnd/DndImage';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
import type { AnyModelConfig, ImageDTO } from 'services/api/types';
const log = logger('canvas-workflow-integration');
interface WorkflowFieldRendererProps {
el: NodeFieldElement;
}
/**
 * Renders a single exposed workflow field in the canvas workflow integration
 * panel. The field is located in the selected workflow by node ID + field
 * name, its template is looked up from the invocation templates, and an input
 * widget is chosen from the template's type name (string, integer/float,
 * boolean, enum, scheduler, board, model identifier, image). Unsupported
 * field types fall through to a read-only placeholder input.
 *
 * Edited values are written to the canvas-workflow-integration Redux slice
 * keyed by `${nodeId}.${fieldName}`.
 */
export const WorkflowFieldRenderer = memo(({ el }: WorkflowFieldRendererProps) => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
  const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
  const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
  const templates = useStore($templates);
  // Non-null assertion is safe because the query is skipped when no workflow is selected.
  const { data: workflow } = useGetWorkflowQuery(selectedWorkflowId!, {
    skip: !selectedWorkflowId,
  });
  // Load boards and models for BoardField and ModelIdentifierField
  const { data: boardsData } = useListAllBoardsQuery({ include_archived: true });
  const { data: modelsData, isLoading: isLoadingModels } = useGetModelConfigsQuery();
  const { fieldIdentifier } = el.data;
  // Redux key for this field's value: `${nodeId}.${fieldName}`.
  const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`;
  log.debug({ fieldIdentifier, fieldKey }, 'Rendering workflow field');
  // Get the node, field instance, and field template
  const { field, fieldTemplate } = useMemo(() => {
    if (!workflow?.workflow.nodes) {
      log.warn('No workflow nodes found');
      return { field: null, fieldTemplate: null };
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const foundNode = workflow.workflow.nodes.find((n: any) => n.data.id === fieldIdentifier.nodeId);
    if (!foundNode) {
      log.warn({ nodeId: fieldIdentifier.nodeId }, 'Node not found');
      return { field: null, fieldTemplate: null };
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const foundField = (foundNode.data as any).inputs[fieldIdentifier.fieldName];
    if (!foundField) {
      log.warn({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName }, 'Field not found in node');
      return { field: null, fieldTemplate: null };
    }
    // Get the field template from the invocation templates
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const nodeType = (foundNode.data as any).type;
    const template = templates[nodeType];
    if (!template) {
      log.warn({ nodeType }, 'No template found for node type');
      return { field: foundField, fieldTemplate: null };
    }
    const foundFieldTemplate = template.inputs[fieldIdentifier.fieldName];
    if (!foundFieldTemplate) {
      log.warn({ nodeType, fieldName: fieldIdentifier.fieldName }, 'Field template not found');
      return { field: foundField, fieldTemplate: null };
    }
    return { field: foundField, fieldTemplate: foundFieldTemplate };
  }, [workflow, fieldIdentifier, templates]);
  // Get the current value from Redux or fallback to field default
  // NOTE(review): the final '' fallback also applies to numeric/boolean
  // fields; the widgets below coerce it with Number()/Boolean().
  const currentValue = useMemo(() => {
    if (fieldValues && fieldKey in fieldValues) {
      return fieldValues[fieldKey];
    }
    return field?.value ?? fieldTemplate?.default ?? '';
  }, [fieldValues, fieldKey, field, fieldTemplate]);
  // Get field type from the template
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const fieldType = fieldTemplate ? (fieldTemplate as any).type?.name : null;
  // Single write-path for all widgets: stores the value under fieldKey.
  const handleChange = useCallback(
    (value: unknown) => {
      dispatch(canvasWorkflowIntegrationFieldValueChanged({ fieldName: fieldKey, value }));
    },
    [dispatch, fieldKey]
  );
  const handleStringChange = useCallback(
    (e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
      handleChange(e.target.value);
    },
    [handleChange]
  );
  // Parses the input as int or float per the field type; unparseable input
  // (e.g. a cleared field) is stored as 0.
  const handleNumberChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      const val = fieldType === 'IntegerField' ? parseInt(e.target.value, 10) : parseFloat(e.target.value);
      handleChange(isNaN(val) ? 0 : val);
    },
    [handleChange, fieldType]
  );
  const handleBooleanChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      handleChange(e.target.checked);
    },
    [handleChange]
  );
  const handleSelectChange = useCallback(
    (e: ChangeEvent<HTMLSelectElement>) => {
      handleChange(e.target.value);
    },
    [handleChange]
  );
  // SchedulerField handlers
  const handleSchedulerChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!isParameterScheduler(v?.value)) {
        return;
      }
      handleChange(v.value);
    },
    [handleChange]
  );
  const schedulerValue = useMemo(() => SCHEDULER_OPTIONS.find((o) => o.value === currentValue), [currentValue]);
  // BoardField handlers
  const handleBoardChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!v) {
        return;
      }
      // 'auto'/'none' are stored as sentinel strings; real boards as { board_id }.
      const value = v.value === 'auto' || v.value === 'none' ? v.value : { board_id: v.value };
      handleChange(value);
    },
    [handleChange]
  );
  const boardOptions = useMemo<ComboboxOption[]>(() => {
    const _options: ComboboxOption[] = [
      { label: t('common.auto'), value: 'auto' },
      { label: `${t('common.none')} (${t('boards.uncategorized')})`, value: 'none' },
    ];
    if (boardsData) {
      for (const board of boardsData) {
        _options.push({
          label: board.board_name,
          value: board.board_id,
        });
      }
    }
    return _options;
  }, [boardsData, t]);
  // Maps the stored value (sentinel string or { board_id }) back to an option;
  // unknown board IDs fall back to the 'auto' option.
  const boardValue = useMemo(() => {
    const _value = currentValue;
    const autoOption = boardOptions[0];
    const noneOption = boardOptions[1];
    if (!_value || _value === 'auto') {
      return autoOption;
    }
    if (_value === 'none') {
      return noneOption;
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const boardId = typeof _value === 'object' ? (_value as any).board_id : _value;
    const boardOption = boardOptions.find((o) => o.value === boardId);
    return boardOption ?? autoOption;
  }, [currentValue, boardOptions]);
  const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);
  // ModelIdentifierField handlers
  const handleModelChange = useCallback(
    (value: AnyModelConfig | null) => {
      if (!value) {
        return;
      }
      handleChange(value);
    },
    [handleChange]
  );
  // Filters the installed model configs by the template's ui_model_* hints
  // (base, type, variant, format); no hints means all models are offered.
  const modelConfigs = useMemo(() => {
    if (!modelsData) {
      return EMPTY_ARRAY;
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const ui_model_base = fieldTemplate ? (fieldTemplate as any)?.ui_model_base : null;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const ui_model_type = fieldTemplate ? (fieldTemplate as any)?.ui_model_type : null;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const ui_model_variant = fieldTemplate ? (fieldTemplate as any)?.ui_model_variant : null;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const ui_model_format = fieldTemplate ? (fieldTemplate as any)?.ui_model_format : null;
    if (!ui_model_base && !ui_model_type) {
      return modelConfigsAdapterSelectors.selectAll(modelsData);
    }
    return modelConfigsAdapterSelectors.selectAll(modelsData).filter((config) => {
      if (ui_model_base && !ui_model_base.includes(config.base)) {
        return false;
      }
      if (ui_model_type && !ui_model_type.includes(config.type)) {
        return false;
      }
      if (ui_model_variant && 'variant' in config && config.variant && !ui_model_variant.includes(config.variant)) {
        return false;
      }
      if (ui_model_format && !ui_model_format.includes(config.format)) {
        return false;
      }
      return true;
    });
  }, [modelsData, fieldTemplate]);
  // ImageField handler
  const handleImageFieldSelect = useCallback(() => {
    dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey }));
  }, [dispatch, fieldKey]);
  // All hooks are above this point; bail out only after they have run.
  if (!field || !fieldTemplate) {
    log.warn({ fieldIdentifier }, 'Field or template is null - not rendering');
    return null;
  }
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const label = (field as any)?.label || (fieldTemplate as any)?.title || fieldIdentifier.fieldName;
  // Log the entire field structure to understand its shape
  log.debug(
    { fieldType, label, currentValue, fieldStructure: field, fieldTemplateStructure: fieldTemplate },
    'Field info'
  );
  // ImageField - allow user to select which one receives the canvas image
  if (fieldType === 'ImageField') {
    return (
      <ImageFieldComponent
        label={label}
        fieldKey={fieldKey}
        currentValue={currentValue}
        selectedImageFieldKey={selectedImageFieldKey}
        fieldTemplate={fieldTemplate}
        handleImageFieldSelect={handleImageFieldSelect}
        handleChange={handleChange}
      />
    );
  }
  // Render different input types based on field type
  if (fieldType === 'StringField') {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const isTextarea = (fieldTemplate as any)?.ui_component === 'textarea';
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const isRequired = (fieldTemplate as any)?.required ?? false;
    if (isTextarea) {
      return (
        <FormControl isRequired={isRequired}>
          <FormLabel>{label}</FormLabel>
          <Textarea
            value={String(currentValue)}
            onChange={handleStringChange}
            placeholder={label}
            rows={3}
            isRequired={isRequired}
          />
        </FormControl>
      );
    }
    return (
      <FormControl isRequired={isRequired}>
        <FormLabel>{label}</FormLabel>
        <Input value={String(currentValue)} onChange={handleStringChange} placeholder={label} isRequired={isRequired} />
      </FormControl>
    );
  }
  if (fieldType === 'IntegerField' || fieldType === 'FloatField') {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const min = (fieldTemplate as any)?.minimum;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const max = (fieldTemplate as any)?.maximum;
    // Integers step by 1; floats use the template's multipleOf or 0.01.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const step = fieldType === 'IntegerField' ? 1 : ((fieldTemplate as any)?.multipleOf ?? 0.01);
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <Flex gap={2} alignItems="center">
          <Input
            type="number"
            value={Number(currentValue)}
            onChange={handleNumberChange}
            min={min}
            max={max}
            step={step}
          />
        </Flex>
      </FormControl>
    );
  }
  if (fieldType === 'BooleanField') {
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <Switch isChecked={Boolean(currentValue)} onChange={handleBooleanChange} />
      </FormControl>
    );
  }
  if (fieldType === 'EnumField') {
    // Templates carry either an `options` array or a `ui_choice_labels`
    // record; for the record, its keys are used as the option values.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const options = (fieldTemplate as any)?.options ?? (fieldTemplate as any)?.ui_choice_labels ?? [];
    const optionsList = Array.isArray(options) ? options : Object.keys(options);
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <Select value={String(currentValue)} onChange={handleSelectChange}>
          {optionsList.map((option: string) => (
            <option key={option} value={option}>
              {option}
            </option>
          ))}
        </Select>
      </FormControl>
    );
  }
  if (fieldType === 'SchedulerField') {
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <Combobox value={schedulerValue} options={SCHEDULER_OPTIONS} onChange={handleSchedulerChange} />
      </FormControl>
    );
  }
  if (fieldType === 'BoardField') {
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <Combobox
          value={boardValue}
          options={boardOptions}
          onChange={handleBoardChange}
          placeholder={t('boards.selectBoard')}
          noOptionsMessage={noOptionsMessage}
        />
      </FormControl>
    );
  }
  if (fieldType === 'ModelIdentifierField') {
    return (
      <FormControl>
        <FormLabel>{label}</FormLabel>
        <ModelFieldCombobox
          value={currentValue}
          modelConfigs={modelConfigs}
          isLoadingConfigs={isLoadingModels}
          onChange={handleModelChange}
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          required={(fieldTemplate as any).required}
          groupByType
        />
      </FormControl>
    );
  }
  // For other field types, show a read-only message
  log.warn(`Unsupported field type "${fieldType}" for field "${label}" - showing as read-only`);
  return (
    <FormControl>
      <FormLabel>{label}</FormLabel>
      <Input value={`${fieldType} (read-only)`} isReadOnly />
    </FormControl>
  );
});
WorkflowFieldRenderer.displayName = 'WorkflowFieldRenderer';
// Separate component for ImageField to avoid conditional hooks
interface ImageFieldComponentProps {
  // Display label (field label, template title, or field name).
  label: string;
  // `${nodeId}.${fieldName}` key identifying the field in Redux state.
  fieldKey: string;
  // Current field value (from Redux, or the workflow's saved field value).
  currentValue: unknown;
  // Key of the ImageField chosen to receive the canvas image, if any.
  selectedImageFieldKey: string | null;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  fieldTemplate: any;
  // Marks this field as the one that receives the canvas image.
  handleImageFieldSelect: () => void;
  // Writes a new value for this field into Redux state.
  handleChange: (value: unknown) => void;
}
/**
 * Renders one ImageField of the workflow: a radio that marks the field as the
 * receiver of the canvas image, plus — when not selected — an upload button or
 * a thumbnail preview (with dimensions overlay and a clear button) for a
 * manually supplied image.
 */
const ImageFieldComponent = memo(
  ({
    label,
    fieldKey,
    currentValue,
    selectedImageFieldKey,
    fieldTemplate,
    handleImageFieldSelect,
    handleChange,
  }: ImageFieldComponentProps) => {
    const { t } = useTranslation();
    const isSelected = selectedImageFieldKey === fieldKey;
    // Get image from field values (uploaded image) or from workflow field (default/saved image)
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const imageValue = currentValue as any;
    const imageName = imageValue?.image_name;
    // skipToken avoids firing the query when no image is set.
    const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
    const handleImageUpload = useCallback(
      (uploadedImage: ImageDTO) => {
        handleChange(uploadedImage);
      },
      [handleChange]
    );
    // Clearing stores `undefined`, reverting the field to its workflow value.
    const handleImageClear = useCallback(() => {
      handleChange(undefined);
    }, [handleChange]);
    return (
      <FormControl overflow="hidden">
        <Flex alignItems="center" gap={2} mb={2}>
          <Radio isChecked={isSelected} onChange={handleImageFieldSelect} />
          <FormLabel mb={0} cursor="pointer" onClick={handleImageFieldSelect}>
            {label}
          </FormLabel>
        </Flex>
        <Text fontSize="sm" color="base.400" ml={6} mb={2}>
          {isSelected
            ? t('controlLayers.workflowIntegration.imageFieldSelected')
            : t('controlLayers.workflowIntegration.imageFieldNotSelected')}
        </Text>
        {/* Show image upload/preview for non-selected fields */}
        {!isSelected && (
          <Flex ml={6} position="relative" h={32} alignItems="stretch" maxW="calc(100% - 1.5rem)">
            {!imageDTO && (
              <UploadImageIconButton
                w="full"
                h="auto"
                // eslint-disable-next-line @typescript-eslint/no-explicit-any
                isError={(fieldTemplate as any)?.required && !imageValue}
                onUpload={handleImageUpload}
                fontSize={24}
              />
            )}
            {imageDTO && (
              <Flex gap={2} alignItems="center" maxW="full">
                <Flex
                  borderRadius="base"
                  borderWidth={1}
                  borderStyle="solid"
                  overflow="hidden"
                  position="relative"
                  h={32}
                  maxH={32}
                >
                  <DndImage imageDTO={imageDTO} asThumbnail />
                  <Text
                    position="absolute"
                    background="base.900"
                    color="base.50"
                    fontSize="sm"
                    fontWeight="semibold"
                    insetInlineEnd={1}
                    insetBlockEnd={1}
                    opacity={0.7}
                    px={2}
                    borderRadius="base"
                    pointerEvents="none"
                  >{`${imageDTO.width}x${imageDTO.height}`}</Text>
                </Flex>
                <IconButton
                  aria-label={t('common.clearImage', 'Clear image')}
                  icon={<PiTrashSimpleBold />}
                  onClick={handleImageClear}
                  size="sm"
                  variant="ghost"
                  colorScheme="error"
                />
              </Flex>
            )}
          </Flex>
        )}
      </FormControl>
    );
  }
);
ImageFieldComponent.displayName = 'ImageFieldComponent';

View File

@@ -0,0 +1,289 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { logger } from 'app/logging/logger';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { WorkflowFieldRenderer } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer';
import {
canvasWorkflowIntegrationImageFieldSelected,
selectCanvasWorkflowIntegrationFieldValues,
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
selectCanvasWorkflowIntegrationSelectedWorkflowId,
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import {
ContainerContextProvider,
DepthContextProvider,
useContainerContext,
useDepthContext,
} from 'features/nodes/components/sidePanel/builder/contexts';
import { DividerElement } from 'features/nodes/components/sidePanel/builder/DividerElement';
import { HeadingElement } from 'features/nodes/components/sidePanel/builder/HeadingElement';
import { TextElement } from 'features/nodes/components/sidePanel/builder/TextElement';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { FormElement } from 'features/nodes/types/workflow';
import {
CONTAINER_CLASS_NAME,
isContainerElement,
isDividerElement,
isHeadingElement,
isNodeFieldElement,
isTextElement,
ROOT_CONTAINER_CLASS_NAME,
} from 'features/nodes/types/workflow';
import { memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
// Scoped logger shared by all components in this file.
const log = logger('canvas-workflow-integration');
// Styles for the root form container. The data-self-layout attribute set in
// the JSX below switches the flex direction between column and row.
const rootViewModeSx: SystemStyleObject = {
  borderRadius: 'base',
  w: 'full',
  h: 'full',
  gap: 2,
  display: 'flex',
  flex: 1,
  maxW: '768px',
  '&[data-self-layout="column"]': {
    flexDir: 'column',
    alignItems: 'stretch',
  },
  '&[data-self-layout="row"]': {
    flexDir: 'row',
    alignItems: 'flex-start',
  },
};
// Styles for nested form containers. data-self-layout controls this
// container's own flex direction; data-parent-layout controls how it sizes
// itself inside its parent container.
const containerViewModeSx: SystemStyleObject = {
  gap: 2,
  '&[data-self-layout="column"]': {
    flexDir: 'column',
    alignItems: 'stretch',
  },
  '&[data-self-layout="row"]': {
    flexDir: 'row',
    alignItems: 'flex-start',
    overflowX: 'auto',
    overflowY: 'visible',
    h: 'min-content',
    flexShrink: 0,
  },
  '&[data-parent-layout="column"]': {
    w: 'full',
    h: 'min-content',
  },
  '&[data-parent-layout="row"]': {
    flex: '1 1 0',
    minW: 32,
  },
};
/**
 * Renders a read-only preview of the selected workflow's form-builder layout
 * for the canvas workflow integration panel. Loads the workflow's form
 * elements, auto-selects the image field when exactly one ImageField is
 * unfilled, and renders the element tree starting from the root container.
 */
export const WorkflowFormPreview = memo(() => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
  const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
  const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
  const templates = useStore($templates);
  // Non-null assertion is safe because the query is skipped when no workflow is selected.
  const { data: workflow, isLoading } = useGetWorkflowQuery(selectedWorkflowId!, {
    skip: !selectedWorkflowId,
  });
  // Map of element ID -> form element for the selected workflow's form.
  const elements = useMemo((): Record<string, FormElement> => {
    if (!workflow?.workflow.form) {
      return {};
    }
    const els = workflow.workflow.form.elements as Record<string, FormElement>;
    log.debug({ elementCount: Object.keys(els).length, elementIds: Object.keys(els) }, 'Form elements loaded');
    return els;
  }, [workflow]);
  const rootElementId = useMemo((): string => {
    if (!workflow?.workflow.form) {
      return '';
    }
    const rootId = workflow.workflow.form.rootElementId as string;
    log.debug({ rootElementId: rootId }, 'Root element ID');
    return rootId;
  }, [workflow]);
  // Auto-select the image field if there's only one unfilled ImageField
  useEffect(() => {
    // Don't auto-select if user already selected one
    if (selectedImageFieldKey) {
      return;
    }
    if (!workflow?.workflow.nodes || Object.keys(elements).length === 0) {
      return;
    }
    // Collect the keys of all ImageFields that have neither a Redux value nor
    // a value saved in the workflow itself.
    const unfilledImageFieldKeys: string[] = [];
    for (const element of Object.values(elements)) {
      if (!isNodeFieldElement(element)) {
        continue;
      }
      const { fieldIdentifier } = element.data;
      const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const node = workflow.workflow.nodes.find((n: any) => n.data?.id === fieldIdentifier.nodeId);
      if (!node) {
        continue;
      }
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const nodeType = (node.data as any)?.type;
      const template = templates[nodeType];
      if (!template?.inputs) {
        continue;
      }
      const fieldTemplate = template.inputs[fieldIdentifier.fieldName];
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const fieldType = (fieldTemplate as any)?.type?.name;
      if (fieldType !== 'ImageField') {
        continue;
      }
      // Check if the field already has a value
      const hasReduxValue = fieldValues && fieldKey in fieldValues && fieldValues[fieldKey]?.image_name;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const fieldInstance = (node.data as any)?.inputs?.[fieldIdentifier.fieldName];
      const hasWorkflowValue = fieldInstance?.value?.image_name;
      if (!hasReduxValue && !hasWorkflowValue) {
        unfilledImageFieldKeys.push(fieldKey);
      }
    }
    // Only auto-select when the choice is unambiguous.
    if (unfilledImageFieldKeys.length === 1) {
      log.debug({ fieldKey: unfilledImageFieldKeys[0] }, 'Auto-selecting the only unfilled ImageField');
      dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey: unfilledImageFieldKeys[0]! }));
    }
  }, [workflow, elements, templates, selectedImageFieldKey, fieldValues, dispatch]);
  if (isLoading) {
    return (
      <Flex alignItems="center" gap={2}>
        <Spinner size="sm" />
        <Text>{t('controlLayers.workflowIntegration.loadingParameters')}</Text>
      </Flex>
    );
  }
  if (!workflow) {
    return null;
  }
  // If workflow has no form builder, it should have been filtered out
  // This is a fallback in case something went wrong
  if (Object.keys(elements).length === 0 || !rootElementId) {
    return (
      <Text fontSize="sm" color="error.400">
        {t('controlLayers.workflowIntegration.noFormBuilderError')}
      </Text>
    );
  }
  const rootElement = elements[rootElementId];
  if (!rootElement || !isContainerElement(rootElement)) {
    return null;
  }
  const { id, data } = rootElement;
  const { children, layout } = data;
  return (
    <DepthContextProvider depth={0}>
      <ContainerContextProvider id={id} layout={layout}>
        <Box id={id} className={ROOT_CONTAINER_CLASS_NAME} sx={rootViewModeSx} data-self-layout={layout} data-depth={0}>
          {children.map((childId) => (
            <FormElementComponentPreview key={childId} id={childId} elements={elements} />
          ))}
        </Box>
      </ContainerContextProvider>
    </DepthContextProvider>
  );
});
WorkflowFormPreview.displayName = 'WorkflowFormPreview';
// Dispatches a form element to the matching preview renderer via the
// element-type guards. Containers recurse through ContainerElementPreview;
// elements of unknown type are logged and skipped.
const FormElementComponentPreview = memo(({ id, elements }: { id: string; elements: Record<string, FormElement> }) => {
  const element = elements[id];
  if (!element) {
    log.warn({ id }, 'Element not found in elements map');
    return null;
  }
  log.debug({ id, type: element.type }, 'Rendering form element');
  if (isContainerElement(element)) {
    return <ContainerElementPreview el={element} elements={elements} />;
  }
  if (isDividerElement(element)) {
    return <DividerElement id={id} />;
  }
  if (isHeadingElement(element)) {
    return <HeadingElement id={id} />;
  }
  if (isTextElement(element)) {
    return <TextElement id={id} />;
  }
  if (isNodeFieldElement(element)) {
    return <WorkflowFieldRenderer el={element} />;
  }
  // No guard matched: the element type is unknown to this preview.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  log.warn({ id, type: (element as any).type }, 'Unknown element type - not rendering');
  return null;
});
FormElementComponentPreview.displayName = 'FormElementComponentPreview';
// Recursively renders a container form element and its children, advancing
// the depth context and exposing this container's layout to its children.
// Hooks run unconditionally before the type-guard bail-out.
const ContainerElementPreview = memo(({ el, elements }: { el: FormElement; elements: Record<string, FormElement> }) => {
  const { t } = useTranslation();
  const depth = useDepthContext();
  const parentCtx = useContainerContext();
  if (!isContainerElement(el)) {
    return null;
  }
  const { id } = el;
  const { children: childIds, layout } = el.data;
  const isEmpty = childIds.length === 0;
  return (
    <DepthContextProvider depth={depth + 1}>
      <ContainerContextProvider id={id} layout={layout}>
        <Flex
          id={id}
          className={CONTAINER_CLASS_NAME}
          sx={containerViewModeSx}
          data-self-layout={layout}
          data-depth={depth}
          data-parent-layout={parentCtx.layout}
        >
          {childIds.map((childId) => (
            <FormElementComponentPreview key={childId} id={childId} elements={elements} />
          ))}
          {isEmpty && (
            <Flex p={8} w="full" h="full" alignItems="center" justifyContent="center">
              <Text color="base.500" fontSize="sm" fontStyle="oblique 10deg">
                {t('workflows.builder.emptyContainer')}
              </Text>
            </Flex>
          )}
        </Flex>
      </ContainerContextProvider>
    </DepthContextProvider>
  );
});
ContainerElementPreview.displayName = 'ContainerElementPreview';

View File

@@ -0,0 +1,302 @@
import { logger } from 'app/logging/logger';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
canvasWorkflowIntegrationClosed,
canvasWorkflowIntegrationProcessingCompleted,
canvasWorkflowIntegrationProcessingStarted,
selectCanvasWorkflowIntegrationFieldValues,
selectCanvasWorkflowIntegrationSelectedImageFieldKey,
selectCanvasWorkflowIntegrationSelectedWorkflowId,
selectCanvasWorkflowIntegrationSourceEntityIdentifier,
} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { CANVAS_OUTPUT_PREFIX, getPrefixedId } from 'features/nodes/util/graph/graphBuilderUtils';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { useLazyGetWorkflowQuery } from 'services/api/endpoints/workflows';
// Scoped logger for this hook.
const log = logger('canvas-workflow-integration');
/**
 * Hook that executes the selected workflow against a canvas layer.
 *
 * `execute` rasterizes the source canvas entity, injects the resulting image
 * into the workflow's image input field (user-selected, or discovered from
 * the form elements / exposed fields), applies any field values edited in the
 * integration panel, converts the workflow to a graph, and enqueues it with
 * the canvas session as the destination so results land in the staging area.
 *
 * Returns `{ execute, canExecute }`; `canExecute` is true once both a
 * workflow and a source canvas entity are selected.
 */
export const useCanvasWorkflowIntegrationExecute = () => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const canvasManager = useCanvasManager();
  const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId);
  const sourceEntityIdentifier = useAppSelector(selectCanvasWorkflowIntegrationSourceEntityIdentifier);
  const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues);
  const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey);
  const canvasSessionId = useAppSelector(selectCanvasSessionId);
  const [getWorkflow] = useLazyGetWorkflowQuery();
  const canExecute = useMemo(() => {
    return Boolean(selectedWorkflowId && sourceEntityIdentifier);
  }, [selectedWorkflowId, sourceEntityIdentifier]);
  const execute = useCallback(async () => {
    if (!selectedWorkflowId || !sourceEntityIdentifier || !canvasManager) {
      return;
    }
    try {
      dispatch(canvasWorkflowIntegrationProcessingStarted());
      // 1. Extract the canvas layer as an image
      const adapter = canvasManager.getAdapter(sourceEntityIdentifier);
      if (!adapter) {
        throw new Error('Could not find canvas entity adapter');
      }
      const rect = adapter.transformer.getRelativeRect();
      const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } });
      // 2. Fetch the workflow
      const { data: workflow } = await getWorkflow(selectedWorkflowId);
      if (!workflow) {
        throw new Error('Failed to load workflow');
      }
      // 3. Build the workflow graph with the canvas image
      // Use the user-selected ImageField, or find one automatically
      let imageFieldIdentifier: { nodeId: string; fieldName: string } | undefined;
      // Method 1: Use user-selected ImageField (preferred)
      if (selectedImageFieldKey) {
        const [nodeId, fieldName] = selectedImageFieldKey.split('.');
        if (nodeId && fieldName) {
          imageFieldIdentifier = { nodeId, fieldName };
        }
      }
      // Method 2: Search through form elements for an ImageField (fallback)
      if (!imageFieldIdentifier && workflow.workflow.form && workflow.workflow.form.elements) {
        for (const element of Object.values(workflow.workflow.form.elements)) {
          if (element.type !== 'node-field') {
            continue;
          }
          const fieldIdentifier = element.data?.fieldIdentifier;
          if (!fieldIdentifier) {
            continue;
          }
          // @ts-expect-error - node data type is complex
          const node = workflow.workflow.nodes.find((n) => n.data.id === fieldIdentifier.nodeId);
          if (!node) {
            continue;
          }
          // An `image` primitive node is always a valid image input target.
          // @ts-expect-error - node.data type is complex
          if (node.data.type === 'image') {
            imageFieldIdentifier = fieldIdentifier;
            break;
          }
          // Check if field type is ImageField
          // @ts-expect-error - field type is complex
          const field = node.data.inputs[fieldIdentifier.fieldName];
          if (field?.type?.name === 'ImageField') {
            imageFieldIdentifier = fieldIdentifier;
            break;
          }
        }
      }
      // Method 3: Fallback to exposedFields
      if (!imageFieldIdentifier && workflow.workflow.exposedFields) {
        imageFieldIdentifier = workflow.workflow.exposedFields.find((fieldIdentifier) => {
          // @ts-expect-error - node data type is complex
          const node = workflow.workflow.nodes.find((n) => n.data.id === fieldIdentifier.nodeId);
          if (!node) {
            return false;
          }
          // @ts-expect-error - node.data type is complex
          if (node.data.type === 'image') {
            return true;
          }
          // Check if field type is ImageField
          // @ts-expect-error - field type is complex
          const field = node.data.inputs[fieldIdentifier.fieldName];
          return field?.type?.name === 'ImageField';
        });
      }
      if (!imageFieldIdentifier) {
        throw new Error('Workflow does not have an image input field in the Form Builder');
      }
      // Update the workflow nodes with canvas image and user field values
      const updatedWorkflow = {
        ...workflow.workflow,
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        nodes: workflow.workflow.nodes.map((node: any) => {
          const nodeId = node.data.id;
          // Never reassigned, only property-mutated — use const.
          const updatedInputs = { ...node.data.inputs };
          const updatedData = { ...node.data };
          let hasChanges = false;
          // Apply image field if this is the image node
          if (nodeId === imageFieldIdentifier.nodeId) {
            updatedInputs[imageFieldIdentifier.fieldName] = {
              ...updatedInputs[imageFieldIdentifier.fieldName],
              value: imageDTO,
            };
            hasChanges = true;
          }
          // Apply any field values from Redux state
          if (fieldValues) {
            Object.entries(fieldValues).forEach(([fieldKey, value]) => {
              const [fieldNodeId, fieldName] = fieldKey.split('.');
              if (fieldNodeId && fieldName && fieldNodeId === nodeId && updatedInputs[fieldName]) {
                updatedInputs[fieldName] = {
                  ...updatedInputs[fieldName],
                  value: value,
                };
                hasChanges = true;
              }
            });
          }
          // If anything was modified, return updated node
          if (hasChanges) {
            updatedData.inputs = updatedInputs;
            return {
              ...node,
              data: updatedData,
            };
          }
          return node;
        }),
      };
      // Validate that the workflow has a canvas_output node
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const hasCanvasOutputNode = updatedWorkflow.nodes.some((node: any) => node.data?.type === 'canvas_output');
      if (!hasCanvasOutputNode) {
        throw new Error('Workflow does not have a Canvas Output node');
      }
      // 4. Convert workflow to graph format
      const graphNodes: Record<string, unknown> = {};
      const nodeIdMapping: Record<string, string> = {}; // Map original IDs to new IDs
      for (const node of updatedWorkflow.nodes) {
        const nodeData = node.data;
        const isCanvasOutputNode = nodeData.type === 'canvas_output';
        // Prefix canvas_output node IDs so the staging area can find them
        const nodeId = isCanvasOutputNode ? getPrefixedId(CANVAS_OUTPUT_PREFIX) : nodeData.id;
        nodeIdMapping[nodeData.id] = nodeId;
        const invocation: Record<string, unknown> = {
          id: nodeId,
          type: nodeData.type,
          // Canvas output nodes are always intermediate (they go to the staging area, not gallery)
          is_intermediate: isCanvasOutputNode ? true : (nodeData.isIntermediate ?? false),
          use_cache: nodeData.useCache ?? true,
        };
        // Add input values to the invocation
        for (const [fieldName, fieldData] of Object.entries(nodeData.inputs)) {
          const fieldValue = (fieldData as { value?: unknown }).value;
          if (fieldValue === undefined) {
            continue;
          }
          // The frontend stores board fields as 'auto', 'none', or { board_id: string }.
          // The backend expects null or { board_id: string }. Translate accordingly.
          if (fieldName === 'board') {
            if (fieldValue === 'auto' || fieldValue === 'none' || fieldValue === null) {
              continue;
            }
          }
          invocation[fieldName] = fieldValue;
        }
        graphNodes[nodeId] = invocation;
      }
      // Convert edges to graph format, using the node ID mapping
      const edgesArray = updatedWorkflow.edges as Array<{
        source: string;
        target: string;
        sourceHandle: string;
        targetHandle: string;
      }>;
      const graphEdges = edgesArray.map((edge) => ({
        source: {
          node_id: nodeIdMapping[edge.source] || edge.source,
          field: edge.sourceHandle,
        },
        destination: {
          node_id: nodeIdMapping[edge.target] || edge.target,
          field: edge.targetHandle,
        },
      }));
      const graph = {
        id: workflow.workflow.id || workflow.workflow_id || 'temp',
        nodes: graphNodes,
        edges: graphEdges,
      };
      log.debug({ workflowName: workflow.name, destination: canvasSessionId }, 'Enqueueing workflow on canvas');
      await dispatch(
        queueApi.endpoints.enqueueBatch.initiate({
          batch: {
            workflow: updatedWorkflow,
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            graph: graph as any,
            runs: 1,
            origin: 'canvas_workflow_integration',
            destination: canvasSessionId,
          },
          prepend: false,
        })
      ).unwrap();
      // 5. Close the modal and show success message
      // Results will appear in the staging area where user can accept/discard them
      toast({
        status: 'success',
        title: t('controlLayers.workflowIntegration.executionStarted'),
        description: t('controlLayers.workflowIntegration.executionStartedDescription'),
      });
      dispatch(canvasWorkflowIntegrationClosed());
    } catch (error) {
      // Include the caught error in the log context (it was previously
      // discarded), matching the structured logging used elsewhere.
      log.error({ error: error instanceof Error ? error.message : String(error) }, 'Error executing workflow');
      dispatch(canvasWorkflowIntegrationProcessingCompleted());
      toast({
        status: 'error',
        title: t('controlLayers.workflowIntegration.executionFailed'),
        description: error instanceof Error ? error.message : 'Unknown error',
      });
    }
  }, [
    selectedWorkflowId,
    sourceEntityIdentifier,
    canvasManager,
    dispatch,
    getWorkflow,
    t,
    fieldValues,
    selectedImageFieldKey,
    canvasSessionId,
  ]);
  return {
    execute,
    canExecute,
  };
};

View File

@@ -0,0 +1,107 @@
import { logger } from 'app/logging/logger';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback, useEffect, useState } from 'react';
import { workflowsApi } from 'services/api/endpoints/workflows';
import type { paths } from 'services/api/schema';
import { workflowHasImageField } from './workflowHasImageField';
// Scoped logger for this hook.
const log = logger('canvas-workflow-integration');
// Single item from the paginated workflow-list endpoint (OpenAPI-derived).
type WorkflowListItem =
  paths['/api/v1/workflows/']['get']['responses']['200']['content']['application/json']['items'][number];
// Result shape of useFilteredWorkflows.
interface UseFilteredWorkflowsResult {
  // Workflows that contain at least one ImageField.
  filteredWorkflows: WorkflowListItem[];
  // True while a filter pass is in flight.
  isFiltering: boolean;
}
/**
 * Hook that filters workflows to only include those with at least one ImageField.
 *
 * List items do not carry node data, so each workflow is fetched in full via
 * the getWorkflow endpoint and checked with `workflowHasImageField`. A
 * workflow whose fetch or check fails is excluded rather than aborting the
 * whole pass.
 *
 * @param workflows The list of workflows to filter
 * @returns Filtered list of workflows that have ImageFields, plus an
 * `isFiltering` flag that is true while a filter pass is running
 */
export function useFilteredWorkflows(workflows: WorkflowListItem[]): UseFilteredWorkflowsResult {
  const dispatch = useAppDispatch();
  const [filteredWorkflows, setFilteredWorkflows] = useState<WorkflowListItem[]>([]);
  const [isFiltering, setIsFiltering] = useState(false);
  // The pass receives an `isCancelled` probe so a stale in-flight pass cannot
  // overwrite state written by a newer pass, or set state after unmount.
  // Previously there was no cancellation: when `workflows` changed while a
  // pass was still awaiting, two passes could resolve out of order and leave
  // stale results on screen.
  const filterWorkflows = useCallback(
    async (isCancelled: () => boolean) => {
      if (workflows.length === 0) {
        setFilteredWorkflows([]);
        return;
      }
      setIsFiltering(true);
      try {
        // Load all workflows in parallel and check for ImageFields
        const workflowChecks = await Promise.all(
          workflows.map(async (workflow) => {
            try {
              // Fetch the full workflow data using dispatch
              const result = await dispatch(
                workflowsApi.endpoints.getWorkflow.initiate(workflow.workflow_id, {
                  subscribe: false,
                  forceRefetch: false,
                })
              );
              // Get the data from the result
              const data = 'data' in result ? result.data : undefined;
              const hasImageField = workflowHasImageField(data);
              log.debug(
                { workflowId: workflow.workflow_id, name: workflow.name, hasImageField },
                'Checked workflow for ImageField'
              );
              // Clean up the subscription
              if ('unsubscribe' in result && typeof result.unsubscribe === 'function') {
                result.unsubscribe();
              }
              return {
                workflow,
                hasImageField,
              };
            } catch (error) {
              log.error(
                {
                  error: error instanceof Error ? error.message : String(error),
                  workflowId: workflow.workflow_id,
                },
                'Error checking workflow for ImageField'
              );
              // Treat an unreadable workflow as having no ImageField so it is
              // excluded instead of failing the whole pass.
              return {
                workflow,
                hasImageField: false,
              };
            }
          })
        );
        if (isCancelled()) {
          return;
        }
        // Filter to only include workflows with ImageFields
        const filtered = workflowChecks.filter((check) => check.hasImageField).map((check) => check.workflow);
        log.debug({ totalWorkflows: workflows.length, filteredCount: filtered.length }, 'Filtered workflows');
        setFilteredWorkflows(filtered);
      } catch (error) {
        if (isCancelled()) {
          return;
        }
        log.error({ error: error instanceof Error ? error.message : String(error) }, 'Error filtering workflows');
        setFilteredWorkflows([]);
      } finally {
        if (!isCancelled()) {
          setIsFiltering(false);
        }
      }
    },
    [workflows, dispatch]
  );
  useEffect(() => {
    let cancelled = false;
    filterWorkflows(() => cancelled);
    // Cleanup marks the pass stale when `workflows` changes or on unmount.
    return () => {
      cancelled = true;
    };
  }, [filterWorkflows]);
  return {
    filteredWorkflows,
    isFiltering,
  };
}

View File

@@ -0,0 +1,86 @@
import { logger } from 'app/logging/logger';
import { $templates } from 'features/nodes/store/nodesSlice';
import { isNodeFieldElement } from 'features/nodes/types/workflow';
import type { paths } from 'services/api/schema';
const log = logger('canvas-workflow-integration');
type WorkflowResponse =
paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'];
/**
 * Checks if a workflow is compatible with canvas workflow integration.
 * Requirements:
 * 1. Has a Form Builder (allows users to modify parameters)
 * 2. Has a canvas_output node (explicit canvas output target)
 * 3. Has at least one ImageField in the Form Builder (receives the canvas image)
 * @param workflow The workflow to check
 * @returns true if the workflow meets all requirements, false otherwise
 */
export function workflowHasImageField(workflow: WorkflowResponse | undefined): boolean {
  if (!workflow?.workflow) {
    log.debug('No workflow data provided');
    return false;
  }

  const form = workflow.workflow.form;
  // Only workflows with Form Builder are supported
  // Workflows without Form Builder don't allow changing models and other parameters
  if (!form?.elements) {
    log.debug('Workflow has no form builder - excluding from list');
    return false;
  }

  const nodes = workflow.workflow.nodes;
  // Must have a canvas_output node to define where the output image goes
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const hasCanvasOutputNode = nodes?.some((candidate: any) => (candidate.data as any)?.type === 'canvas_output');
  if (!hasCanvasOutputNode) {
    log.debug('Workflow has no canvas_output node - excluding from list');
    return false;
  }

  const templates = $templates.get();
  log.debug('Workflow has form builder and canvas_output node, checking form elements for ImageField');

  for (const [elementId, element] of Object.entries(form.elements)) {
    if (!isNodeFieldElement(element)) {
      continue;
    }
    const { fieldIdentifier } = element.data;
    // Find the node that contains this field
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const owningNode = nodes?.find((candidate: any) => candidate.data?.id === fieldIdentifier.nodeId);
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const nodeType = (owningNode?.data as any)?.type;
    // Resolve the field's template; any missing link in the chain means we skip it
    const fieldTemplate = nodeType ? templates[nodeType]?.inputs?.[fieldIdentifier.fieldName] : undefined;
    if (!fieldTemplate) {
      continue;
    }
    // Check if this is an ImageField
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    if ((fieldTemplate as any).type?.name === 'ImageField') {
      log.debug({ elementId, fieldName: fieldIdentifier.fieldName }, 'Found ImageField in workflow form');
      return true;
    }
  }

  // If we have a form but no ImageFields were found in it, return false
  log.debug('Workflow has form builder but no ImageField found in form elements');
  return false;
}

View File

@@ -6,6 +6,7 @@ import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/c
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { CanvasEntityMenuItemsFilter } from 'features/controlLayers/components/common/CanvasEntityMenuItemsFilter';
import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/components/common/CanvasEntityMenuItemsMergeDown';
import { CanvasEntityMenuItemsRunWorkflow } from 'features/controlLayers/components/common/CanvasEntityMenuItemsRunWorkflow';
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
@@ -25,6 +26,7 @@ export const RasterLayerMenuItems = memo(() => {
</IconMenuItemGroup>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsRunWorkflow />
<CanvasEntityMenuItemsSelectObject />
<RasterLayerMenuItemsAdjustments />
<MenuDivider />

View File

@@ -13,10 +13,10 @@ import {
import { DndImage } from 'features/dnd/DndImage';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import type { S } from 'services/api/types';
import { useOutputImageDTO, useStagingAreaContext } from './context';
import { useStagingAreaContext } from './context';
import { QueueItemNumber } from './QueueItemNumber';
import type { StagingEntry } from './state';
const sx = {
cursor: 'pointer',
@@ -37,21 +37,23 @@ const sx = {
} satisfies SystemStyleObject;
type Props = {
item: S['SessionQueueItem'];
entry: StagingEntry;
index: number;
};
export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
export const QueueItemPreviewMini = memo(({ entry, index }: Props) => {
const ctx = useStagingAreaContext();
const dispatch = useAppDispatch();
const $isSelected = useMemo(() => ctx.buildIsSelectedComputed(item.item_id), [ctx, item.item_id]);
const $isSelected = useMemo(
() => ctx.buildIsSelectedComputed(entry.item.item_id, entry.imageIndex),
[ctx, entry.item.item_id, entry.imageIndex]
);
const isSelected = useStore($isSelected);
const imageDTO = useOutputImageDTO(item.item_id);
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const onClick = useCallback(() => {
ctx.select(item.item_id);
}, [ctx, item.item_id]);
ctx.select(entry.item.item_id, entry.imageIndex);
}, [ctx, entry.item.item_id, entry.imageIndex]);
const onDoubleClick = useCallback(() => {
if (autoSwitch !== 'off') {
@@ -63,8 +65,8 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
}, [autoSwitch, dispatch]);
const onLoad = useCallback(() => {
ctx.onImageLoaded(item.item_id);
}, [ctx, item.item_id]);
ctx.onImageLoaded(entry.item.item_id);
}, [ctx, entry.item.item_id]);
return (
<Flex
@@ -74,11 +76,17 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
onClick={onClick}
onDoubleClick={onDoubleClick}
>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} position="absolute" onLoad={onLoad} />}
<QueueItemProgressImage itemId={item.item_id} position="absolute" />
<QueueItemStatusLabel item={entry.item} position="absolute" margin="auto" />
{entry.imageDTO && <DndImage imageDTO={entry.imageDTO} position="absolute" onLoad={onLoad} />}
<QueueItemProgressImage itemId={entry.item.item_id} position="absolute" />
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
<QueueItemCircularProgress
itemId={entry.item.item_id}
status={entry.item.status}
position="absolute"
top={1}
right={2}
/>
</Flex>
);
});

View File

@@ -8,10 +8,10 @@ import type { CSSProperties, RefObject } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import type { Components, ComputeItemKey, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import type { S } from 'services/api/types';
import { useStagingAreaContext } from './context';
import { getQueueItemElementId, STAGING_AREA_THUMBNAIL_STRIP_HEIGHT } from './shared';
import type { StagingEntry } from './state';
const log = logger('system');
@@ -141,7 +141,7 @@ export const StagingAreaItemsList = memo(() => {
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const items = useStore(ctx.$items);
const entries = useStore(ctx.$entries);
const scrollerRef = useScrollableStagingArea(rootRef);
@@ -172,9 +172,9 @@ export const StagingAreaItemsList = memo(() => {
return (
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<Virtuoso<S['SessionQueueItem']>
<Virtuoso<StagingEntry>
ref={virtuosoRef}
data={items}
data={entries}
horizontalDirection
style={virtuosoStyles}
computeItemKey={computeItemKey}
@@ -183,19 +183,19 @@ export const StagingAreaItemsList = memo(() => {
components={components}
rangeChanged={onRangeChanged}
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], void>['scrollerRef']}
scrollerRef={scrollerRef as VirtuosoProps<StagingEntry, void>['scrollerRef']}
/>
</Box>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
const computeItemKey: ComputeItemKey<S['SessionQueueItem'], void> = (_, item: S['SessionQueueItem']) => {
return item.item_id;
const computeItemKey: ComputeItemKey<StagingEntry, void> = (_, entry: StagingEntry) => {
return `${entry.item.item_id}-${entry.imageIndex}`;
};
const itemContent: ItemContent<S['SessionQueueItem'], void> = (index, item) => (
<QueueItemPreviewMini key={`${item.item_id}-mini`} item={item} index={index} />
const itemContent: ItemContent<StagingEntry, void> = (index, entry) => (
<QueueItemPreviewMini key={`${entry.item.item_id}-${entry.imageIndex}-mini`} entry={entry} index={index} />
);
const listSx = {
@@ -207,7 +207,7 @@ const listSx = {
},
};
const components: Components<S['SessionQueueItem']> = {
const components: Components<StagingEntry> = {
List: forwardRef(({ context: _, ...rest }, ref) => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);

View File

@@ -84,12 +84,13 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
y: y + Math.round((bboxRect.height - scaledHeight) / 2),
};
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
const overrides: Partial<CanvasRasterLayerState> = {
position,
objects: [imageObject],
};
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
store.dispatch(canvasSessionReset());
store.dispatch(
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
@@ -129,12 +130,6 @@ export const useStagingAreaContext = () => {
return ctx;
};
export const useOutputImageDTO = (itemId: number) => {
const ctx = useStagingAreaContext();
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
return allProgressData[itemId]?.imageDTO ?? null;
};
export const useProgressDatum = (itemId: number): ProgressData => {
const ctx = useStagingAreaContext();
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });

View File

@@ -1,7 +1,7 @@
import type { S } from 'services/api/types';
import { describe, expect, it } from 'vitest';
import { getOutputImageName, getProgressMessage, getQueueItemElementId } from './shared';
import { getOutputImageNames, getProgressMessage, getQueueItemElementId } from './shared';
describe('StagingAreaApi Utility Functions', () => {
describe('getProgressMessage', () => {
@@ -34,7 +34,7 @@ describe('StagingAreaApi Utility Functions', () => {
});
});
describe('getOutputImageName', () => {
describe('getOutputImageNames', () => {
it('should extract image name from completed queue item', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
@@ -61,10 +61,10 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe('test-output.png');
expect(getOutputImageNames(queueItem)).toEqual(['test-output.png']);
});
it('should return null when no canvas output node found', () => {
it('should return empty array when no canvas output node found', () => {
const queueItem = {
item_id: 1,
status: 'completed',
@@ -93,10 +93,10 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
expect(getOutputImageNames(queueItem)).toEqual([]);
});
it('should return null when output node has no results', () => {
it('should return empty array when output node has no results', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
@@ -116,10 +116,10 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
expect(getOutputImageNames(queueItem)).toEqual([]);
});
it('should return null when results contain no image fields', () => {
it('should return empty array when results contain no image fields', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
@@ -144,10 +144,10 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
expect(getOutputImageNames(queueItem)).toEqual([]);
});
it('should handle multiple outputs and return first image', () => {
it('should collect images from multiple canvas_output nodes', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
@@ -161,15 +161,17 @@ describe('StagingAreaApi Utility Functions', () => {
session: {
id: 'test-session',
source_prepared_mapping: {
canvas_output: ['output-node-id'],
'canvas_output:abc123': ['output-node-1'],
'canvas_output:def456': ['output-node-2'],
},
results: {
'output-node-id': {
text: 'some text',
first_image: {
'output-node-1': {
image: {
image_name: 'first-image.png',
},
second_image: {
},
'output-node-2': {
image: {
image_name: 'second-image.png',
},
},
@@ -177,8 +179,8 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
const result = getOutputImageName(queueItem);
expect(result).toBe('first-image.png');
const result = getOutputImageNames(queueItem);
expect(result).toEqual(['first-image.png', 'second-image.png']);
});
it('should return first image from image collections', () => {
@@ -226,7 +228,7 @@ describe('StagingAreaApi Utility Functions', () => {
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
expect(getOutputImageNames(queueItem)).toEqual([]);
});
});
});

View File

@@ -16,24 +16,27 @@ export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0p
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
export const STAGING_AREA_THUMBNAIL_STRIP_HEIGHT = '72px';
export const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
isCanvasOutputNodeId(nodeId)
)?.[1][0];
const output = nodeId ? item.session.results[nodeId] : undefined;
export const getOutputImageNames = (item: S['SessionQueueItem']): string[] => {
const imageNames: string[] = [];
if (!output) {
return null;
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
return value.image_name;
for (const [sourceNodeId, preparedNodeIds] of Object.entries(item.session.source_prepared_mapping)) {
if (!isCanvasOutputNodeId(sourceNodeId)) {
continue;
}
const nodeId = preparedNodeIds[0];
const output = nodeId ? item.session.results[nodeId] : undefined;
if (!output) {
continue;
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
imageNames.push(value.image_name);
}
}
if (isImageFieldCollection(value)) {
return value[0]?.image_name ?? null;
}
}
return null;
return imageNames;
};

View File

@@ -41,18 +41,34 @@ describe('StagingAreaApi', () => {
expect(api.$items.get()).toEqual([]);
expect(api.$progressData.get()).toEqual({});
expect(api.$selectedItemId.get()).toBe(null);
expect(api.$selectedImageIndex.get()).toBe(0);
});
});
describe('Computed Values', () => {
it('should compute item count correctly', () => {
it('should compute item count as entry count for single-output items', () => {
expect(api.$itemCount.get()).toBe(0);
const items = [createMockQueueItem({ item_id: 1 })];
api.$items.set(items);
// Item with no imageDTOs produces 1 entry
expect(api.$itemCount.get()).toBe(1);
});
it('should compute item count as entry count for multi-output items', () => {
const items = [createMockQueueItem({ item_id: 1 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [createMockImageDTO({ image_name: 'a.png' }), createMockImageDTO({ image_name: 'b.png' })],
imageLoaded: false,
});
// 2 imageDTOs = 2 entries
expect(api.$itemCount.get()).toBe(2);
});
it('should compute hasItems correctly', () => {
expect(api.$hasItems.get()).toBe(false);
@@ -95,7 +111,7 @@ describe('StagingAreaApi', () => {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO,
imageDTOs: [imageDTO],
imageLoaded: false,
});
@@ -111,6 +127,95 @@ describe('StagingAreaApi', () => {
});
});
describe('Entries Computed', () => {
it('should produce one entry per item when each has 0 or 1 imageDTOs', () => {
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
api.$items.set(items);
const entries = api.$entries.get();
expect(entries).toHaveLength(2);
expect(entries[0]?.item.item_id).toBe(1);
expect(entries[0]?.imageIndex).toBe(0);
expect(entries[0]?.imageDTO).toBe(null);
expect(entries[1]?.item.item_id).toBe(2);
expect(entries[1]?.imageIndex).toBe(0);
});
it('should produce one entry per item when each has exactly 1 imageDTO', () => {
const imageDTO = createMockImageDTO({ image_name: 'img.png' });
const items = [createMockQueueItem({ item_id: 1 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [imageDTO],
imageLoaded: false,
});
const entries = api.$entries.get();
expect(entries).toHaveLength(1);
expect(entries[0]?.imageDTO).toBe(imageDTO);
expect(entries[0]?.imageIndex).toBe(0);
});
it('should produce multiple entries for items with multiple imageDTOs', () => {
const img1 = createMockImageDTO({ image_name: 'a.png' });
const img2 = createMockImageDTO({ image_name: 'b.png' });
const items = [createMockQueueItem({ item_id: 1 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1, img2],
imageLoaded: false,
});
const entries = api.$entries.get();
expect(entries).toHaveLength(2);
expect(entries[0]?.item.item_id).toBe(1);
expect(entries[0]?.imageDTO).toBe(img1);
expect(entries[0]?.imageIndex).toBe(0);
expect(entries[1]?.item.item_id).toBe(1);
expect(entries[1]?.imageDTO).toBe(img2);
expect(entries[1]?.imageIndex).toBe(1);
});
it('should interleave entries from multiple items', () => {
const img1a = createMockImageDTO({ image_name: 'a1.png' });
const img1b = createMockImageDTO({ image_name: 'a2.png' });
const img2 = createMockImageDTO({ image_name: 'b1.png' });
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1a, img1b],
imageLoaded: false,
});
api.$progressData.setKey(2, {
itemId: 2,
progressEvent: null,
progressImage: null,
imageDTOs: [img2],
imageLoaded: false,
});
const entries = api.$entries.get();
expect(entries).toHaveLength(3);
// Item 1 entries first
expect(entries[0]?.item.item_id).toBe(1);
expect(entries[0]?.imageDTO).toBe(img1a);
expect(entries[1]?.item.item_id).toBe(1);
expect(entries[1]?.imageDTO).toBe(img1b);
// Then item 2
expect(entries[2]?.item.item_id).toBe(2);
expect(entries[2]?.imageDTO).toBe(img2);
});
});
describe('Selection Methods', () => {
beforeEach(() => {
const items = [
@@ -124,9 +229,16 @@ describe('StagingAreaApi', () => {
it('should select item by ID', () => {
api.select(2);
expect(api.$selectedItemId.get()).toBe(2);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(mockApp.onSelect).toHaveBeenCalledWith(2);
});
it('should select item by ID and imageIndex', () => {
api.select(2, 1);
expect(api.$selectedItemId.get()).toBe(2);
expect(api.$selectedImageIndex.get()).toBe(1);
});
it('should select next item', () => {
api.$selectedItemId.set(1);
api.selectNext();
@@ -184,6 +296,124 @@ describe('StagingAreaApi', () => {
});
});
describe('Entry-Based Navigation', () => {
it('should navigate through entries of multi-output items', () => {
const img1 = createMockImageDTO({ image_name: 'a.png' });
const img2 = createMockImageDTO({ image_name: 'b.png' });
const items = [createMockQueueItem({ item_id: 1 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1, img2],
imageLoaded: false,
});
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(0);
// Should be on first entry (imageIndex 0)
expect(api.$selectedItem.get()?.imageDTO).toBe(img1);
expect(api.$selectedItem.get()?.index).toBe(0);
// Navigate to next entry (imageIndex 1, same item)
api.selectNext();
expect(api.$selectedItemId.get()).toBe(1);
expect(api.$selectedImageIndex.get()).toBe(1);
expect(api.$selectedItem.get()?.imageDTO).toBe(img2);
expect(api.$selectedItem.get()?.index).toBe(1);
// Navigate past last entry - should wrap to first
api.selectNext();
expect(api.$selectedItemId.get()).toBe(1);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(api.$selectedItem.get()?.imageDTO).toBe(img1);
});
it('should navigate across items and their entries', () => {
const img1a = createMockImageDTO({ image_name: 'a1.png' });
const img1b = createMockImageDTO({ image_name: 'a2.png' });
const img2 = createMockImageDTO({ image_name: 'b1.png' });
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1a, img1b],
imageLoaded: false,
});
api.$progressData.setKey(2, {
itemId: 2,
progressEvent: null,
progressImage: null,
imageDTOs: [img2],
imageLoaded: false,
});
// Start on first entry
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(0);
// entries: [item1/img0, item1/img1, item2/img0]
expect(api.$entries.get()).toHaveLength(3);
// Next -> item1, imageIndex 1
api.selectNext();
expect(api.$selectedItemId.get()).toBe(1);
expect(api.$selectedImageIndex.get()).toBe(1);
// Next -> item2, imageIndex 0
api.selectNext();
expect(api.$selectedItemId.get()).toBe(2);
expect(api.$selectedImageIndex.get()).toBe(0);
// Next -> wraps to item1, imageIndex 0
api.selectNext();
expect(api.$selectedItemId.get()).toBe(1);
expect(api.$selectedImageIndex.get()).toBe(0);
// Prev -> wraps to item2, imageIndex 0
api.selectPrev();
expect(api.$selectedItemId.get()).toBe(2);
expect(api.$selectedImageIndex.get()).toBe(0);
});
it('should select correct entry with selectFirst and selectLast', () => {
const img1a = createMockImageDTO({ image_name: 'a1.png' });
const img1b = createMockImageDTO({ image_name: 'a2.png' });
const img2 = createMockImageDTO({ image_name: 'b1.png' });
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
api.$items.set(items);
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1a, img1b],
imageLoaded: false,
});
api.$progressData.setKey(2, {
itemId: 2,
progressEvent: null,
progressImage: null,
imageDTOs: [img2],
imageLoaded: false,
});
api.selectLast();
expect(api.$selectedItemId.get()).toBe(2);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(api.$selectedItem.get()?.imageDTO).toBe(img2);
api.selectFirst();
expect(api.$selectedItemId.get()).toBe(1);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(api.$selectedItem.get()?.imageDTO).toBe(img1a);
});
});
describe('Discard Methods', () => {
beforeEach(() => {
const items = [
@@ -201,6 +431,7 @@ describe('StagingAreaApi', () => {
api.discardSelected();
expect(api.$selectedItemId.get()).toBe(3);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(mockApp.onDiscard).toHaveBeenCalledWith(selectedItem?.item);
});
@@ -238,6 +469,7 @@ describe('StagingAreaApi', () => {
api.discardAll();
expect(api.$selectedItemId.get()).toBe(null);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(mockApp.onDiscardAll).toHaveBeenCalled();
});
@@ -247,6 +479,14 @@ describe('StagingAreaApi', () => {
api.$selectedItemId.set(1);
expect(api.$discardSelectedIsEnabled.get()).toBe(true);
});
it('should reset selectedImageIndex when discarding', () => {
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(2);
api.discardSelected();
expect(api.$selectedImageIndex.get()).toBe(0);
});
});
describe('Accept Methods', () => {
@@ -256,13 +496,13 @@ describe('StagingAreaApi', () => {
api.$selectedItemId.set(1);
});
it('should accept selected item when image is available', () => {
it('should accept selected item with single imageDTO', () => {
const imageDTO = createMockImageDTO();
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO,
imageDTOs: [imageDTO],
imageLoaded: false,
});
@@ -272,12 +512,34 @@ describe('StagingAreaApi', () => {
expect(mockApp.onAccept).toHaveBeenCalledWith(selectedItem?.item, imageDTO);
});
it('should accept the correct imageDTO from a multi-output entry', () => {
const img1 = createMockImageDTO({ image_name: 'a.png' });
const img2 = createMockImageDTO({ image_name: 'b.png' });
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTOs: [img1, img2],
imageLoaded: false,
});
// Select the second image
api.$selectedImageIndex.set(1);
const selectedItem = api.$selectedItem.get();
expect(selectedItem?.imageDTO).toBe(img2);
api.acceptSelected();
expect(mockApp.onAccept).toHaveBeenCalledWith(selectedItem?.item, img2);
});
it('should do nothing when no image is available', () => {
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
imageLoaded: false,
});
@@ -301,7 +563,7 @@ describe('StagingAreaApi', () => {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO,
imageDTOs: [imageDTO],
imageLoaded: false,
});
@@ -339,7 +601,7 @@ describe('StagingAreaApi', () => {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO: createMockImageDTO(),
imageDTOs: [createMockImageDTO()],
imageLoaded: false,
});
@@ -352,7 +614,7 @@ describe('StagingAreaApi', () => {
const progressData = api.$progressData.get();
expect(progressData[1]?.progressEvent).toBe(progressEvent);
expect(progressData[1]?.imageDTO).toBeTruthy();
expect(progressData[1]?.imageDTOs.length).toBeGreaterThan(0);
});
});
@@ -438,7 +700,7 @@ describe('StagingAreaApi', () => {
await api.onItemsChangedEvent(items);
const progressData = api.$progressData.get();
expect(progressData[1]?.imageDTO).toBe(imageDTO);
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO);
});
it('should handle auto-switch on completion', async () => {
@@ -473,7 +735,7 @@ describe('StagingAreaApi', () => {
itemId: 999,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
imageLoaded: false,
});
@@ -490,7 +752,7 @@ describe('StagingAreaApi', () => {
itemId: 1,
progressEvent: createMockProgressEvent({ item_id: 1 }),
progressImage: null,
imageDTO: createMockImageDTO(),
imageDTOs: [createMockImageDTO()],
imageLoaded: false,
});
@@ -501,7 +763,7 @@ describe('StagingAreaApi', () => {
const progressData = api.$progressData.get();
expect(progressData[1]?.progressEvent).toBe(null);
expect(progressData[1]?.progressImage).toBe(null);
expect(progressData[1]?.imageDTO).toBe(null);
expect(progressData[1]?.imageDTOs).toEqual([]);
});
});
@@ -547,18 +809,34 @@ describe('StagingAreaApi', () => {
});
describe('Utility Methods', () => {
it('should build isSelected computed correctly', () => {
it('should build isSelected computed correctly for default imageIndex', () => {
const isSelected = api.buildIsSelectedComputed(1);
expect(isSelected.get()).toBe(false);
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(0);
expect(isSelected.get()).toBe(true);
});
it('should build isSelected computed correctly with specific imageIndex', () => {
const isSelected0 = api.buildIsSelectedComputed(1, 0);
const isSelected1 = api.buildIsSelectedComputed(1, 1);
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(0);
expect(isSelected0.get()).toBe(true);
expect(isSelected1.get()).toBe(false);
api.$selectedImageIndex.set(1);
expect(isSelected0.get()).toBe(false);
expect(isSelected1.get()).toBe(true);
});
});
describe('Cleanup', () => {
it('should reset all state on cleanup', () => {
api.$selectedItemId.set(1);
api.$selectedImageIndex.set(2);
api.$items.set([createMockQueueItem({ item_id: 1 })]);
api.$lastStartedItemId.set(1);
api.$lastCompletedItemId.set(1);
@@ -566,13 +844,14 @@ describe('StagingAreaApi', () => {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
imageLoaded: false,
});
api.cleanup();
expect(api.$selectedItemId.get()).toBe(null);
expect(api.$selectedImageIndex.get()).toBe(0);
expect(api.$items.get()).toEqual([]);
expect(api.$lastStartedItemId.get()).toBe(null);
expect(api.$lastCompletedItemId.get()).toBe(null);
@@ -622,13 +901,13 @@ describe('StagingAreaApi', () => {
expect(progressData[1]?.progressEvent).toBe(progressEvent);
});
it('should preserve imageDTO when updating progress', () => {
it('should preserve imageDTOs when updating progress', () => {
const imageDTO = createMockImageDTO();
api.$progressData.setKey(1, {
itemId: 1,
progressEvent: null,
progressImage: null,
imageDTO,
imageDTOs: [imageDTO],
imageLoaded: false,
});
@@ -640,7 +919,7 @@ describe('StagingAreaApi', () => {
api.onInvocationProgressEvent(progressEvent);
const progressData = api.$progressData.get();
expect(progressData[1]?.imageDTO).toBe(imageDTO);
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO);
expect(progressData[1]?.progressEvent).toBe(progressEvent);
});
});
@@ -712,6 +991,118 @@ describe('StagingAreaApi', () => {
expect(api.$selectedItem.get()?.item.item_id).toBe(2);
});
it('should not let stale async call overwrite newer data with fewer images', async () => {
const imageDTO1 = createMockImageDTO({ image_name: 'img1.png' });
const imageDTO2 = createMockImageDTO({ image_name: 'img2.png' });
mockApp._setImageDTO('img1.png', imageDTO1);
mockApp._setImageDTO('img2.png', imageDTO2);
// Simulates optimistic update: status=completed but only 1 result in session.results
const itemsStale = [
createMockQueueItem({
item_id: 1,
status: 'completed',
session: {
id: sessionId,
source_prepared_mapping: { 'canvas_output:a': ['node-1'] },
results: { 'node-1': { image: { image_name: 'img1.png' } } },
},
}),
];
// Simulates full refetch: both results available
const itemsFull = [
createMockQueueItem({
item_id: 1,
status: 'completed',
session: {
id: sessionId,
source_prepared_mapping: { 'canvas_output:a': ['node-1'], 'canvas_output:b': ['node-2'] },
results: {
'node-1': { image: { image_name: 'img1.png' } },
'node-2': { image: { image_name: 'img2.png' } },
},
},
}),
];
// Fire both concurrently (stale optimistic update then full refetch)
const promise1 = api.onItemsChangedEvent(itemsStale);
const promise2 = api.onItemsChangedEvent(itemsFull);
await Promise.all([promise1, promise2]);
const progressData = api.$progressData.get();
expect(progressData[1]?.imageDTOs).toHaveLength(2);
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO1);
expect(progressData[1]?.imageDTOs[1]).toBe(imageDTO2);
});
it('should load all images from multiple canvas_output nodes', async () => {
const imageDTO1 = createMockImageDTO({ image_name: 'output1.png' });
const imageDTO2 = createMockImageDTO({ image_name: 'output2.png' });
mockApp._setImageDTO('output1.png', imageDTO1);
mockApp._setImageDTO('output2.png', imageDTO2);
const items = [
createMockQueueItem({
item_id: 1,
status: 'completed',
session: {
id: sessionId,
source_prepared_mapping: {
'canvas_output:abc': ['prepared-1'],
'canvas_output:def': ['prepared-2'],
},
results: {
'prepared-1': { image: { image_name: 'output1.png' } },
'prepared-2': { image: { image_name: 'output2.png' } },
},
},
}),
];
await api.onItemsChangedEvent(items);
const progressData = api.$progressData.get();
expect(progressData[1]?.imageDTOs).toHaveLength(2);
expect(progressData[1]?.imageDTOs[0]).toBe(imageDTO1);
expect(progressData[1]?.imageDTOs[1]).toBe(imageDTO2);
});
it('should create separate entries for multiple canvas_output images', async () => {
const imageDTO1 = createMockImageDTO({ image_name: 'output1.png' });
const imageDTO2 = createMockImageDTO({ image_name: 'output2.png' });
mockApp._setImageDTO('output1.png', imageDTO1);
mockApp._setImageDTO('output2.png', imageDTO2);
const items = [
createMockQueueItem({
item_id: 1,
status: 'completed',
session: {
id: sessionId,
source_prepared_mapping: {
'canvas_output:abc': ['prepared-1'],
'canvas_output:def': ['prepared-2'],
},
results: {
'prepared-1': { image: { image_name: 'output1.png' } },
'prepared-2': { image: { image_name: 'output2.png' } },
},
},
}),
];
await api.onItemsChangedEvent(items);
const entries = api.$entries.get();
expect(entries).toHaveLength(2);
expect(entries[0]?.imageDTO).toBe(imageDTO1);
expect(entries[0]?.imageIndex).toBe(0);
expect(entries[1]?.imageDTO).toBe(imageDTO2);
expect(entries[1]?.imageIndex).toBe(1);
});
it('should handle multiple progress events for same item', () => {
const event1 = createMockProgressEvent({
item_id: 1,
@@ -740,7 +1131,7 @@ describe('StagingAreaApi', () => {
itemId: i,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
imageLoaded: false,
});
}

View File

@@ -6,7 +6,7 @@ import { atom, computed, map } from 'nanostores';
import type { ImageDTO, S } from 'services/api/types';
import { objectEntries } from 'tsafe';
import { getOutputImageName } from './shared';
import { getOutputImageNames } from './shared';
/**
* Interface for the app-level API that the StagingAreaApi depends on.
@@ -34,14 +34,23 @@ export type ProgressData = {
itemId: number;
progressEvent: S['InvocationProgressEvent'] | null;
progressImage: ProgressImage | null;
imageDTO: ImageDTO | null;
imageDTOs: ImageDTO[];
imageLoaded: boolean;
};
/** Combined data for the currently selected item */
/** A single entry in the staging area. Each canvas_output image is a separate entry. */
export type StagingEntry = {
item: S['SessionQueueItem'];
imageDTO: ImageDTO | null;
imageIndex: number;
progressData: ProgressData;
};
/** Combined data for the currently selected entry */
export type SelectedItemData = {
item: S['SessionQueueItem'];
index: number;
imageDTO: ImageDTO | null;
progressData: ProgressData;
};
@@ -50,7 +59,7 @@ export const getInitialProgressData = (itemId: number): ProgressData => ({
itemId,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
imageLoaded: false,
});
type ProgressDataMap = Record<number, ProgressData | undefined>;
@@ -58,8 +67,7 @@ type ProgressDataMap = Record<number, ProgressData | undefined>;
/**
* API for managing the Canvas Staging Area - a view of the image generation queue.
* Provides reactive state management for pending, in-progress, and completed images.
* Users can accept images to place on canvas, discard them, navigate between items,
* and configure auto-switching behavior.
* Each canvas_output node produces a separate entry that can be individually navigated and accepted.
*/
export class StagingAreaApi {
/** The current session ID. */
@@ -71,6 +79,9 @@ export class StagingAreaApi {
/** A set of subscriptions to be cleaned up when we are finished with a session */
_subscriptions = new Set<() => void>();
/** Generation counter to prevent stale async writes in onItemsChangedEvent */
_itemsEventGeneration = 0;
/** Item ID of the last started item. Used for auto-switch on start. */
$lastStartedItemId = atom<number | null>(null);
@@ -86,8 +97,38 @@ export class StagingAreaApi {
/** ID of the currently selected queue item, or null if none selected. */
$selectedItemId = atom<number | null>(null);
/** Total number of items in the queue. */
$itemCount = computed([this.$items], (items) => items.length);
/** Index of the selected image within the selected queue item (for multi-output items). */
$selectedImageIndex = atom<number>(0);
/**
* Flat list of staging entries. Each canvas_output image from a queue item becomes
* a separate entry. Items with 0 or 1 output images produce a single entry.
*/
$entries = computed([this.$items, this.$progressData], (items, progressData) => {
const entries: StagingEntry[] = [];
for (const item of items) {
const datum = progressData[item.item_id] ?? getInitialProgressData(item.item_id);
if (datum.imageDTOs.length <= 1) {
entries.push({
item,
imageDTO: datum.imageDTOs[0] ?? null,
imageIndex: 0,
progressData: datum,
});
} else {
for (let i = 0; i < datum.imageDTOs.length; i++) {
const imageDTO = datum.imageDTOs[i];
if (imageDTO) {
entries.push({ item, imageDTO, imageIndex: i, progressData: datum });
}
}
}
}
return entries;
});
/** Total number of entries (each canvas_output image counts separately). */
$itemCount = computed([this.$entries], (entries) => entries.length);
/** Whether there are any items in the queue. */
$hasItems = computed([this.$items], (items) => items.length > 0);
@@ -97,113 +138,153 @@ export class StagingAreaApi {
items.some((item) => item.status === 'pending' || item.status === 'in_progress')
);
/** The currently selected queue item with its index and progress data, or null if none selected. */
/** The currently selected entry with its global index, or null if none selected. */
$selectedItem = computed(
[this.$items, this.$selectedItemId, this.$progressData],
(items, selectedItemId, progressData) => {
if (items.length === 0) {
[this.$entries, this.$selectedItemId, this.$selectedImageIndex],
(entries, selectedItemId, selectedImageIndex) => {
if (entries.length === 0 || selectedItemId === null) {
return null;
}
if (selectedItemId === null) {
return null;
// Find the entry matching (selectedItemId, selectedImageIndex)
let targetEntry: StagingEntry | null = null;
let globalIndex = -1;
let imageIdxWithinItem = 0;
for (let i = 0; i < entries.length; i++) {
const entry = entries[i]!;
if (entry.item.item_id === selectedItemId) {
if (imageIdxWithinItem === selectedImageIndex) {
targetEntry = entry;
globalIndex = i;
break;
}
imageIdxWithinItem++;
}
}
const item = items.find(({ item_id }) => item_id === selectedItemId);
if (!item) {
// Fallback: select first entry for this item
if (!targetEntry) {
for (let i = 0; i < entries.length; i++) {
const entry = entries[i]!;
if (entry.item.item_id === selectedItemId) {
targetEntry = entry;
globalIndex = i;
break;
}
}
}
if (!targetEntry || globalIndex === -1) {
return null;
}
return {
item,
index: items.findIndex(({ item_id }) => item_id === selectedItemId),
progressData: progressData[selectedItemId] || getInitialProgressData(selectedItemId),
item: targetEntry.item,
index: globalIndex,
imageDTO: targetEntry.imageDTO,
progressData: targetEntry.progressData,
};
}
);
/** The ImageDTO of the currently selected item, or null if none available. */
/** The ImageDTO of the currently selected entry, or null if none available. */
$selectedItemImageDTO = computed([this.$selectedItem], (selectedItem) => {
return selectedItem?.progressData.imageDTO ?? null;
return selectedItem?.imageDTO ?? null;
});
/** The index of the currently selected item, or null if none selected. */
/** The global entry index of the currently selected entry, or null if none selected. */
$selectedItemIndex = computed([this.$selectedItem], (selectedItem) => {
return selectedItem?.index ?? null;
});
/** Selects a queue item by ID. */
select = (itemId: number) => {
/** Selects a queue item by ID, optionally at a specific image index. */
select = (itemId: number, imageIndex: number = 0) => {
this.$selectedItemId.set(itemId);
this.$selectedImageIndex.set(imageIndex);
this._app?.onSelect?.(itemId);
};
/** Selects the next item in the queue, wrapping to the first item if at the end. */
/** Selects the next entry, cycling through all entries across all items. */
selectNext = () => {
const selectedItem = this.$selectedItem.get();
if (selectedItem === null) {
return;
}
const items = this.$items.get();
const nextIndex = (selectedItem.index + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
const entries = this.$entries.get();
if (entries.length <= 1) {
return;
}
this.$selectedItemId.set(nextItem.item_id);
const nextIndex = (selectedItem.index + 1) % entries.length;
const nextEntry = entries[nextIndex];
if (!nextEntry) {
return;
}
this.$selectedItemId.set(nextEntry.item.item_id);
this.$selectedImageIndex.set(nextEntry.imageIndex);
this._app?.onSelectNext?.();
};
/** Selects the previous item in the queue, wrapping to the last item if at the beginning. */
/** Selects the previous entry, cycling through all entries across all items. */
selectPrev = () => {
const selectedItem = this.$selectedItem.get();
if (selectedItem === null) {
return;
}
const items = this.$items.get();
const prevIndex = (selectedItem.index - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
const entries = this.$entries.get();
if (entries.length <= 1) {
return;
}
this.$selectedItemId.set(prevItem.item_id);
const prevIndex = (selectedItem.index - 1 + entries.length) % entries.length;
const prevEntry = entries[prevIndex];
if (!prevEntry) {
return;
}
this.$selectedItemId.set(prevEntry.item.item_id);
this.$selectedImageIndex.set(prevEntry.imageIndex);
this._app?.onSelectPrev?.();
};
/** Selects the first item in the queue. */
/** Selects the first entry. */
selectFirst = () => {
const items = this.$items.get();
const first = items.at(0);
const entries = this.$entries.get();
const first = entries[0];
if (!first) {
return;
}
this.$selectedItemId.set(first.item_id);
this.$selectedItemId.set(first.item.item_id);
this.$selectedImageIndex.set(first.imageIndex);
this._app?.onSelectFirst?.();
};
/** Selects the last item in the queue. */
/** Selects the last entry. */
selectLast = () => {
const items = this.$items.get();
const last = items.at(-1);
const entries = this.$entries.get();
const last = entries.at(-1);
if (!last) {
return;
}
this.$selectedItemId.set(last.item_id);
this.$selectedItemId.set(last.item.item_id);
this.$selectedImageIndex.set(last.imageIndex);
this._app?.onSelectLast?.();
};
/** Discards the currently selected item and selects the next available item. */
/** Discards the queue item of the currently selected entry and selects the next available entry. */
discardSelected = () => {
const selectedItem = this.$selectedItem.get();
if (selectedItem === null) {
return;
}
const items = this.$items.get();
const nextIndex = clamp(selectedItem.index + 1, 0, items.length - 1);
const nextItem = items[nextIndex];
const itemIndex = items.findIndex((i) => i.item_id === selectedItem.item.item_id);
const nextItemIndex = clamp(itemIndex + 1, 0, items.length - 1);
const nextItem = items[nextItemIndex];
if (nextItem) {
this.$selectedItemId.set(nextItem.item_id);
} else {
this.$selectedItemId.set(null);
}
this.$selectedImageIndex.set(0);
this._app?.onDiscard?.(selectedItem.item);
};
@@ -231,30 +312,22 @@ export class StagingAreaApi {
/** Discards all items in the queue. */
discardAll = () => {
this.$selectedItemId.set(null);
this.$selectedImageIndex.set(0);
this._app?.onDiscardAll?.();
};
/** Accepts the currently selected item if an image is available. */
/** Accepts the currently selected entry if an image is available. */
acceptSelected = () => {
const selectedItem = this.$selectedItem.get();
if (selectedItem === null) {
if (selectedItem === null || !selectedItem.imageDTO) {
return;
}
const progressData = this.$progressData.get();
const datum = progressData[selectedItem.item.item_id];
if (!datum || !datum.imageDTO) {
return;
}
this._app?.onAccept?.(selectedItem.item, datum.imageDTO);
this._app?.onAccept?.(selectedItem.item, selectedItem.imageDTO);
};
/** Whether the accept selected action is enabled. */
$acceptSelectedIsEnabled = computed([this.$selectedItem, this.$progressData], (selectedItem, progressData) => {
if (selectedItem === null) {
return false;
}
const datum = progressData[selectedItem.item.item_id];
return !!datum && !!datum.imageDTO;
$acceptSelectedIsEnabled = computed([this.$selectedItem], (selectedItem) => {
return selectedItem !== null && selectedItem.imageDTO !== null;
});
/** Sets the auto-switch mode. */
@@ -297,6 +370,10 @@ export class StagingAreaApi {
* handles auto-selection, and implements auto-switch behavior.
*/
onItemsChangedEvent = async (items: S['SessionQueueItem'][]) => {
// Increment generation counter. If a newer call starts while we're awaiting,
// we'll detect it and avoid overwriting with stale data.
const generation = ++this._itemsEventGeneration;
const oldItems = this.$items.get();
if (items === oldItems) {
@@ -306,9 +383,11 @@ export class StagingAreaApi {
if (items.length === 0) {
// If there are no items, cannot have a selected item.
this.$selectedItemId.set(null);
this.$selectedImageIndex.set(0);
} else if (this.$selectedItemId.get() === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
this.$selectedItemId.set(items[0]?.item_id ?? null);
this.$selectedImageIndex.set(0);
}
const progressData = this.$progressData.get();
@@ -328,7 +407,7 @@ export class StagingAreaApi {
...(datum ?? getInitialProgressData(item.item_id)),
progressEvent: null,
progressImage: null,
imageDTO: null,
imageDTOs: [],
});
continue;
}
@@ -336,31 +415,57 @@ export class StagingAreaApi {
if (item.status === 'in_progress') {
if (this.$lastStartedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_start') {
this.$selectedItemId.set(item.item_id);
this.$selectedImageIndex.set(0);
this.$lastStartedItemId.set(null);
}
continue;
}
if (item.status === 'completed') {
if (datum?.imageDTO) {
const outputImageNames = getOutputImageNames(item);
if (outputImageNames.length === 0) {
continue;
}
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
// Check current progress data (not the snapshot) to account for concurrent updates
const currentDatum = this.$progressData.get()[item.item_id];
if (currentDatum && currentDatum.imageDTOs.length === outputImageNames.length) {
continue;
}
const imageDTO = await this._app?.getImageDTO(outputImageName);
if (!imageDTO) {
const imageDTOs: ImageDTO[] = [];
for (const imageName of outputImageNames) {
const imageDTO = await this._app?.getImageDTO(imageName);
if (imageDTO) {
imageDTOs.push(imageDTO);
}
}
if (imageDTOs.length === 0) {
continue;
}
// After async work, check if a newer event has started processing.
// If so, abort to let the newer call handle the update with fresher data.
if (generation !== this._itemsEventGeneration) {
return;
}
// Re-read progress data to avoid overwriting a better result from a concurrent call
const latestDatum = this.$progressData.get()[item.item_id];
if (latestDatum && latestDatum.imageDTOs.length >= imageDTOs.length) {
continue;
}
this.$progressData.setKey(item.item_id, {
...(datum ?? getInitialProgressData(item.item_id)),
imageDTO,
...(latestDatum ?? getInitialProgressData(item.item_id)),
imageDTOs,
});
}
}
// After async work, check if a newer event has started processing
if (generation !== this._itemsEventGeneration) {
return;
}
const selectedItemId = this.$selectedItemId.get();
if (selectedItemId !== null && !items.find(({ item_id }) => item_id === selectedItemId)) {
// If the selected item no longer exists, select the next best item.
@@ -370,15 +475,18 @@ export class StagingAreaApi {
const nextItem = items[nextItemIndex] ?? items[nextItemIndex - 1];
if (nextItem) {
this.$selectedItemId.set(nextItem.item_id);
this.$selectedImageIndex.set(0);
}
} else {
// Next, if there is an in-progress item, select that.
const inProgressItem = items.find(({ status }) => status === 'in_progress');
if (inProgressItem) {
this.$selectedItemId.set(inProgressItem.item_id);
this.$selectedImageIndex.set(0);
}
// Finally just select the first item.
this.$selectedItemId.set(items[0]?.item_id ?? null);
this.$selectedImageIndex.set(0);
}
}
@@ -393,6 +501,7 @@ export class StagingAreaApi {
// This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
if (this.$lastCompletedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_finish') {
this.$selectedItemId.set(item.item_id);
this.$selectedImageIndex.set(0);
this.$lastCompletedItemId.set(null);
}
const datum = this.$progressData.get()[item.item_id];
@@ -402,20 +511,22 @@ export class StagingAreaApi {
});
};
/** Creates a computed value that returns true if the given item ID is selected. */
buildIsSelectedComputed = (itemId: number) => {
return computed([this.$selectedItemId], (selectedItemId) => {
return selectedItemId === itemId;
/** Creates a computed value that returns true if the given item ID and image index is selected. */
buildIsSelectedComputed = (itemId: number, imageIndex: number = 0) => {
return computed([this.$selectedItemId, this.$selectedImageIndex], (selectedItemId, selectedImageIndex) => {
return selectedItemId === itemId && selectedImageIndex === imageIndex;
});
};
/** Cleans up all state and unsubscribes from all events. */
cleanup = () => {
this._itemsEventGeneration++;
this.$lastStartedItemId.set(null);
this.$lastCompletedItemId.set(null);
this.$items.set([]);
this.$progressData.set({});
this.$selectedItemId.set(null);
this.$selectedImageIndex.set(0);
this._subscriptions.forEach((unsubscribe) => unsubscribe());
this._subscriptions.clear();
};

View File

@@ -0,0 +1,25 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { canvasWorkflowIntegrationOpened } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFlowArrowBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsRunWorkflow = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const onClick = useCallback(() => {
dispatch(canvasWorkflowIntegrationOpened({ sourceEntityIdentifier: entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<MenuItem onClick={onClick} icon={<PiFlowArrowBold />}>
{t('controlLayers.workflowIntegration.runWorkflow')}
</MenuItem>
);
});
CanvasEntityMenuItemsRunWorkflow.displayName = 'CanvasEntityMenuItemsRunWorkflow';

View File

@@ -156,8 +156,8 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
return;
}
if (selectedItem.progressData.imageDTO) {
this.$imageSrc.set({ type: 'imageName', data: selectedItem.progressData.imageDTO.image_name });
if (selectedItem.imageDTO) {
this.$imageSrc.set({ type: 'imageName', data: selectedItem.imageDTO.image_name });
return;
} else if (selectedItem.progressData?.progressImage) {
this.$imageSrc.set({ type: 'dataURL', data: selectedItem.progressData.progressImage.dataURL });

View File

@@ -0,0 +1,134 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { assert } from 'tsafe';
import z from 'zod';
const zCanvasWorkflowIntegrationState = z.object({
_version: z.literal(1),
isOpen: z.boolean(),
selectedWorkflowId: z.string().nullable(),
sourceEntityIdentifier: z
.object({
type: z.enum(['raster_layer', 'control_layer', 'regional_guidance', 'inpaint_mask']),
id: z.string(),
})
.nullable(),
fieldValues: z.record(z.string(), z.any()).nullable(),
// Which ImageField to use for canvas image (format: "nodeId.fieldName")
selectedImageFieldKey: z.string().nullable(),
isProcessing: z.boolean(),
});
type CanvasWorkflowIntegrationState = z.infer<typeof zCanvasWorkflowIntegrationState>;
const getInitialState = (): CanvasWorkflowIntegrationState => ({
_version: 1,
isOpen: false,
selectedWorkflowId: null,
sourceEntityIdentifier: null,
fieldValues: null,
selectedImageFieldKey: null,
isProcessing: false,
});
const slice = createSlice({
name: 'canvasWorkflowIntegration',
initialState: getInitialState(),
reducers: {
canvasWorkflowIntegrationOpened: (
state,
action: PayloadAction<{ sourceEntityIdentifier: CanvasEntityIdentifier }>
) => {
state.isOpen = true;
state.sourceEntityIdentifier = action.payload.sourceEntityIdentifier;
state.selectedWorkflowId = null;
state.fieldValues = null;
},
canvasWorkflowIntegrationClosed: (state) => {
state.isOpen = false;
state.selectedWorkflowId = null;
state.sourceEntityIdentifier = null;
state.fieldValues = null;
state.selectedImageFieldKey = null;
state.isProcessing = false;
},
canvasWorkflowIntegrationWorkflowSelected: (state, action: PayloadAction<{ workflowId: string | null }>) => {
state.selectedWorkflowId = action.payload.workflowId;
// Reset field values when switching workflows
state.fieldValues = null;
state.selectedImageFieldKey = null;
},
canvasWorkflowIntegrationImageFieldSelected: (state, action: PayloadAction<{ fieldKey: string | null }>) => {
state.selectedImageFieldKey = action.payload.fieldKey;
},
canvasWorkflowIntegrationFieldValueChanged: (
state,
action: PayloadAction<{ fieldName: string; value: unknown }>
) => {
if (!state.fieldValues) {
state.fieldValues = {};
}
state.fieldValues[action.payload.fieldName] = action.payload.value;
},
canvasWorkflowIntegrationFieldValuesReset: (state) => {
state.fieldValues = null;
},
canvasWorkflowIntegrationProcessingStarted: (state) => {
state.isProcessing = true;
},
canvasWorkflowIntegrationProcessingCompleted: (state) => {
state.isProcessing = false;
},
},
});
export const {
canvasWorkflowIntegrationOpened,
canvasWorkflowIntegrationClosed,
canvasWorkflowIntegrationWorkflowSelected,
canvasWorkflowIntegrationImageFieldSelected,
canvasWorkflowIntegrationFieldValueChanged,
canvasWorkflowIntegrationProcessingStarted,
canvasWorkflowIntegrationProcessingCompleted,
} = slice.actions;
export const canvasWorkflowIntegrationSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasWorkflowIntegrationState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zCanvasWorkflowIntegrationState.parse(state);
},
persistDenylist: ['isOpen', 'isProcessing', 'sourceEntityIdentifier'],
},
};
const selectCanvasWorkflowIntegrationSlice = (state: RootState) => state.canvasWorkflowIntegration;
const createCanvasWorkflowIntegrationSelector = <T>(selector: Selector<CanvasWorkflowIntegrationState, T>) =>
createSelector(selectCanvasWorkflowIntegrationSlice, selector);
export const selectCanvasWorkflowIntegrationIsOpen = createCanvasWorkflowIntegrationSelector((state) => state.isOpen);
export const selectCanvasWorkflowIntegrationSelectedWorkflowId = createCanvasWorkflowIntegrationSelector(
(state) => state.selectedWorkflowId
);
export const selectCanvasWorkflowIntegrationSourceEntityIdentifier = createCanvasWorkflowIntegrationSelector(
(state) => state.sourceEntityIdentifier
);
export const selectCanvasWorkflowIntegrationFieldValues = createCanvasWorkflowIntegrationSelector(
(state) => state.fieldValues
);
export const selectCanvasWorkflowIntegrationSelectedImageFieldKey = createCanvasWorkflowIntegrationSelector(
(state) => state.selectedImageFieldKey
);
export const selectCanvasWorkflowIntegrationIsProcessing = createCanvasWorkflowIntegrationSelector(
(state) => state.isProcessing
);

View File

@@ -1,6 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
export { getPrefixedId };
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
@@ -205,7 +207,7 @@ export const getInfill = (
assert(false, 'Unknown infill method');
};
const CANVAS_OUTPUT_PREFIX = 'canvas_output';
export const CANVAS_OUTPUT_PREFIX = 'canvas_output';
export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNodes>) => {
return (

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import {
selectAutoSwitch,
selectGalleryView,
@@ -13,8 +14,10 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { LRUCache } from 'lru-cache';
import { LIST_ALL_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories } from 'services/api/util';
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
@@ -217,6 +220,27 @@ export const buildOnInvocationComplete = (
return imageDTOs;
};
const clearCanvasWorkflowIntegrationProcessing = (data: S['InvocationCompleteEvent']) => {
// Check if this is a canvas workflow integration result
// Results go to staging area automatically via destination = canvasSessionId
if (data.origin !== 'canvas_workflow_integration') {
return;
}
// Clear processing state so the modal loading spinner stops
dispatch(canvasWorkflowIntegrationProcessingCompleted());
// Check if this invocation produced an image output
const hasImageOutput = objectEntries(data.result).some(([_name, value]) => {
return isImageField(value) || isImageFieldCollection(value);
});
// Only invalidate if this invocation produced an image - this ensures the staging area
// gets updated immediately when output images are available, without invalidating on every invocation
if (hasImageOutput) {
dispatch(queueApi.util.invalidateTags([{ type: 'SessionQueueItem', id: LIST_ALL_TAG }]));
}
};
return async (data: S['InvocationCompleteEvent']) => {
if (finishedQueueItemIds.has(data.item_id)) {
log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`);
@@ -236,6 +260,10 @@ export const buildOnInvocationComplete = (
upsertExecutionState(_nodeExecutionState.nodeId, _nodeExecutionState);
}
// Clear canvas workflow integration processing state if needed
clearCanvasWorkflowIntegrationProcessing(data);
// Add images to gallery (canvas workflow integration results go to staging area automatically)
await addImagesToGallery(data);
$lastProgressEvent.set(null);

View File

@@ -5,6 +5,7 @@ import type { AppStore } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { forEach, isNil, round } from 'es-toolkit/compat';
import { allEntitiesDeleted, controlLayerRecalled } from 'features/controlLayers/store/canvasSlice';
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
import {
heightChanged,
@@ -160,6 +161,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
};
upsertExecutionState(nes.nodeId, nes);
}
// Clear canvas workflow integration processing state on error
if (data.origin === 'canvas_workflow_integration') {
dispatch(canvasWorkflowIntegrationProcessingCompleted());
}
});
const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, finishedQueueItemIds);
@@ -407,6 +412,25 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
})
);
// Optimistically update the listAllQueueItems cache for this destination so the canvas
// staging area immediately reflects status changes without waiting for a tag-based refetch
if (destination) {
dispatch(
queueApi.util.updateQueryData('listAllQueueItems', { destination }, (draft) => {
const item = draft.find((i) => i.item_id === item_id);
if (item) {
item.status = status;
item.started_at = started_at;
item.updated_at = updated_at;
item.completed_at = completed_at;
item.error_type = error_type;
item.error_message = error_message;
item.error_traceback = error_traceback;
}
})
);
}
// Invalidate caches for things we cannot easily update
// Invalidate SessionQueueStatus to refetch with user-specific counts
const tagsToInvalidate: ApiTagDescription[] = [