From ae42182246dfb28da96798d4e54d2c96109ea4d9 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 03:52:06 +0200 Subject: [PATCH 1/5] fix: detect Z-Image LoRAs with transformer.layers prefix (#8986) OneTrainer exports Z-Image LoRAs with 'transformer.layers.' key prefix instead of 'diffusion_model.layers.'. Add this prefix (and the PEFT-wrapped 'base_model.model.transformer.layers.' variant) to the Z-Image LoRA probe so these models are correctly identified and loaded. --- invokeai/backend/model_manager/configs/lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 1619c9d6f0..791ded2ed0 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -711,6 +711,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) @@ -747,6 +749,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) From f08b8029682fbcbbef0163944b2468cf40a085a6 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 04:04:48 +0200 Subject: [PATCH 2/5] feat: add support for OneTrainer BFL Flux LoRA format (#8984) * feat: add support for OneTrainer BFL Flux LoRA format Newer versions of OneTrainer export Flux LoRAs using BFL internal key names (double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and split QKV projections (qkv.0/1/2, linear1.0/1/2/3). This format was not recognized by any existing detector. Add detection and conversion for this format, merging split QKV and linear1 layers into MergedLayerPatch instances for the fused BFL model. 
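For illustration, a minimal sketch (not part of the patch; the dims are assumptions matching FLUX.1's hidden size of 3072, not values read from a checkpoint) of how the split Q/K/V sub-layers map onto the row ranges of the fused qkv projection that MergedLayerPatch targets:

    # Sketch: each split sub-layer patches a contiguous row range of the
    # fused qkv weight. Dims are illustrative only.
    out_dims = [3072, 3072, 3072]  # Q, K, V output dims (assumed)
    ranges, offset = [], 0
    for d in out_dims:
        ranges.append((offset, offset + d))
        offset += d
    print(ranges)  # [(0, 3072), (3072, 6144), (6144, 9216)]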
* chore ruff --- .../model_manager/load/model_loaders/lora.py | 6 + invokeai/backend/model_manager/taxonomy.py | 1 + ...ux_onetrainer_bfl_lora_conversion_utils.py | 168 ++++++++++++++++++ .../patches/lora_conversions/formats.py | 5 + 4 files changed, 180 insertions(+) create mode 100644 invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index d39982456a..67d862a01d 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -44,6 +44,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, + lora_model_from_flux_onetrainer_bfl_state_dict, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, lora_model_from_flux_onetrainer_state_dict, @@ -128,6 +132,8 @@ class LoRALoader(ModelLoader): model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict): + model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict) elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) elif is_state_dict_likely_flux_control(state_dict=state_dict): diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index c002418a6b..9dc0da7733 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -210,6 +210,7 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" XLabs = "flux.xlabs" BflPeft = "flux.bfl_peft" + OneTrainerBfl = "flux.onetrainer_bfl" AnyVariant: TypeAlias = Union[ diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py new file mode 100644 index 0000000000..b2109222a3 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py @@ -0,0 +1,168 @@ +"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format. + +This format is produced by newer versions of OneTrainer and uses BFL internal key names +(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and +InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha). + +Unlike the standard BFL PEFT format (which uses 'diffusion_model.' 
prefix and lora_A/lora_B), +this format also has split QKV projections: + - double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate) + - double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate) + - single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate) + +Example keys: + transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight + transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight + transformer.double_blocks.0.img_attn.qkv.0.alpha + transformer.single_blocks.0.linear1.3.lora_down.weight + transformer.double_blocks.0.img_mlp.0.lora_down.weight +""" + +import re +from typing import Any, Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw + +_TRANSFORMER_PREFIX = "transformer." + +# Valid LoRA weight suffixes in this format. +_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha") + +# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1" +_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$") + +# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2" +_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$") + + +def is_state_dict_likely_in_flux_onetrainer_bfl_format( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> bool: + """Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format. + + This format uses BFL internal key names with 'transformer.' prefix and split QKV projections. + """ + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + if not str_keys: + return False + + # All keys must start with 'transformer.' + if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys): + return False + + # All keys must end with recognized LoRA suffixes. + if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys): + return False + + # Must have BFL block structure (double_blocks or single_blocks) under transformer prefix. + has_bfl_blocks = any( + k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys + ) + if not has_bfl_blocks: + return False + + # Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats + # that might use transformer. prefix in the future. + has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys) + if not has_split_qkv: + return False + + return True + + +def _split_key(key: str) -> tuple[str, str]: + """Split a key into (layer_name, weight_suffix). + + Handles: + - 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot + - 1-component suffixes: e.g., 'alpha' → split at last dot + """ + if key.endswith(".weight"): + parts = key.rsplit(".", maxsplit=2) + return parts[0], f"{parts[1]}.{parts[2]}" + else: + parts = key.rsplit(".", maxsplit=1) + return parts[0], parts[1] + + +def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw: + """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw. + + Strips the 'transformer.' 
prefix, groups by layer, and merges split QKV/linear1 + layers into MergedLayerPatch instances. + """ + # Step 1: Strip prefix and group by layer name. + grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {} + for key, value in state_dict.items(): + if not isinstance(key, str): + continue + + # Strip 'transformer.' prefix. + key = key[len(_TRANSFORMER_PREFIX) :] + + layer_name, suffix = _split_key(key) + + if layer_name not in grouped_state_dict: + grouped_state_dict[layer_name] = {} + grouped_state_dict[layer_name][suffix] = value + + # Step 2: Build LoRA layers, merging split QKV and linear1. + layers: dict[str, BaseLayerPatch] = {} + + # Identify which layers need merging. + merge_groups: dict[str, list[str]] = {} + standalone_keys: list[str] = [] + + for layer_key in grouped_state_dict: + qkv_match = _SPLIT_QKV_RE.match(layer_key) + linear1_match = _SPLIT_LINEAR1_RE.match(layer_key) + + if qkv_match: + parent = qkv_match.group(1) + if parent not in merge_groups: + merge_groups[parent] = [] + merge_groups[parent].append(layer_key) + elif linear1_match: + parent = linear1_match.group(1) + if parent not in merge_groups: + merge_groups[parent] = [] + merge_groups[parent].append(layer_key) + else: + standalone_keys.append(layer_key) + + # Process standalone layers. + for layer_key in standalone_keys: + layer_sd = grouped_state_dict[layer_key] + layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"] = any_lora_layer_from_state_dict(layer_sd) + + # Process merged layers. + for parent_key, sub_keys in merge_groups.items(): + # Sort by the numeric index at the end (e.g., qkv.0, qkv.1, qkv.2). + sub_keys.sort(key=lambda k: int(k.rsplit(".", maxsplit=1)[1])) + + sub_layers: list[BaseLayerPatch] = [] + sub_ranges: list[Range] = [] + dim_0_offset = 0 + + for sub_key in sub_keys: + layer_sd = grouped_state_dict[sub_key] + sub_layer = any_lora_layer_from_state_dict(layer_sd) + + # Determine the output dimension from the up weight shape. 
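+            # (For a LoRA pair, lora_up.weight has shape (out_features, rank) and
+            # lora_down.weight has shape (rank, in_features), so shape[0] of the
+            # up weight is the number of rows this sub-layer contributes to the
+            # fused projection.)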
+ up_weight = layer_sd["lora_up.weight"] + out_dim = up_weight.shape[0] + + sub_layers.append(sub_layer) + sub_ranges.append(Range(dim_0_offset, dim_0_offset + out_dim)) + dim_0_offset += out_dim + + layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_key}"] = MergedLayerPatch(sub_layers, sub_ranges) + + return ModelPatchRaw(layers=layers) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 0b316602fc..b3e00c288b 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -14,6 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( is_state_dict_likely_in_flux_kohya_format, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, ) @@ -28,6 +31,8 @@ def flux_format_from_state_dict( ) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata): + return FluxLoRAFormat.OneTrainerBfl elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): return FluxLoRAFormat.OneTrainer elif is_state_dict_likely_in_flux_diffusers_format(state_dict): From dbbf28925b7fb8ae06d5e89b407b86aa8dc6a5c4 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 04:31:33 +0200 Subject: [PATCH 3/5] fix: detect FLUX.2 Klein 9B Base variant via filename heuristic (#9011) Klein 9B Base (undistilled) and Klein 9B (distilled) have identical architectures and cannot be distinguished from the state dict alone. Use a filename heuristic ("base" in the name) to detect the Base variant for checkpoint, GGUF, and diffusers format models. Also fixes the incorrect guidance_embeds-based detection for diffusers format, since both variants have guidance_embeds=False. --- .../backend/model_manager/configs/main.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 6f737ceb92..dff887f7d0 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool: return False +def _filename_suggests_base(name: str) -> bool: + """Check if a model name/filename suggests it is a Base (undistilled) variant. + + Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished + from the state dict. We use the filename as a heuristic: filenames containing "base" + (e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model. + """ + return "base" in name.lower() + + def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None: """Determine FLUX.2 variant from state dict. @@ -330,9 +340,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560) - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096) - Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare. 
- We default to Klein9B (distilled) for all 9B models since GGUF models may not - include guidance embedding keys needed to distinguish them. + Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and cannot be distinguished from the state dict alone. This function defaults to Klein9B + for all 9B models. Callers should use filename heuristics to detect Klein9BBase. Supports both BFL format (checkpoint) and diffusers format keys: - BFL format: txt_in.weight (context embedder) @@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N context_in_dim = shape[1] # Determine variant based on context dimension if context_in_dim == KLEIN_9B_CONTEXT_DIM: - # Default to Klein9B (distilled) - the official/common 9B model + # Default to Klein9B - callers use filename heuristics to detect Klein9BBase return Flux2VariantType.Klein9B elif context_in_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B @@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size) - To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled), - we check guidance_embeds: - - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation) - - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference) - - Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False. + Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and both have guidance_embeds=False. We use a filename heuristic to detect Base models. 
""" KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096 @@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json") joint_attention_dim = transformer_config.get("joint_attention_dim", 4096) - guidance_embeds = transformer_config.get("guidance_embeds", False) # Determine variant based on joint_attention_dim if joint_attention_dim == KLEIN_9B_CONTEXT_DIM: - # Check guidance_embeds to distinguish distilled from undistilled - # Klein 9B (distilled): guidance_embeds = False (guidance is baked in) - # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance) - if guidance_embeds: + if _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase - else: - return Flux2VariantType.Klein9B + return Flux2VariantType.Klein9B elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B elif joint_attention_dim > 4096: From 80be1b72822d9c6df38f03adde38d85ab959b6a6 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 05:09:29 +0200 Subject: [PATCH 4/5] fix: correct inaccurate download size estimates in starter models (#8968) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Verified model sizes against Hugging Face repositories and corrected 11 descriptions that had wrong or outdated download size estimates. Key corrections: - T5-XXL base encoder: ~8GB → ~9.5GB - FLUX.2 VAE: ~335MB → ~168MB (was confused with FLUX.1 VAE) - FLUX.1 Krea dev: ~33GB → ~29GB (uses quantized T5, not full) - FLUX.2 Klein 4B/9B Diffusers: ~10GB/~20GB → ~16GB/~35GB - SD3.5 Medium/Large: ~15GB/~19G → ~16GB/~28GB - CogView4: ~29GB → ~31GB - Z-Image Turbo: ~30.6GB → ~33GB - FLUX.1 Kontext/Krea quantized: ~14GB → ~12GB --- .../backend/model_manager/starter_models.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index 9f86f83dc5..3f09ddbe76 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -71,7 +71,7 @@ t5_base_encoder = StarterModel( name="t5_base_encoder", base=BaseModelType.Any, source="InvokeAI/t5-v1_1-xxl::bfloat16", - description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", + description="T5-XXL text encoder (used in FLUX pipelines). ~9.5GB", type=ModelType.T5Encoder, ) @@ -156,7 +156,7 @@ flux_kontext_quantized = StarterModel( name="FLUX.1 Kontext dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf", - description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -164,7 +164,7 @@ flux_krea = StarterModel( name="FLUX.1 Krea dev", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors", - description="FLUX.1 Krea dev. Total size with dependencies: ~33GB", + description="FLUX.1 Krea dev. 
Total size with dependencies: ~29GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -172,7 +172,7 @@ flux_krea_quantized = StarterModel( name="FLUX.1 Krea dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf", - description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -180,7 +180,7 @@ sd35_medium = StarterModel( name="SD3.5 Medium", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-medium", - description="Medium SD3.5 Model: ~15GB", + description="Medium SD3.5 Model: ~16GB", type=ModelType.Main, dependencies=[], ) @@ -188,7 +188,7 @@ sd35_large = StarterModel( name="SD3.5 Large", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-large", - description="Large SD3.5 Model: ~19G", + description="Large SD3.5 Model: ~28GB", type=ModelType.Main, dependencies=[], ) @@ -644,7 +644,7 @@ cogview4 = StarterModel( name="CogView4", base=BaseModelType.CogView4, source="THUDM/CogView4-6B", - description="The base CogView4 model (~29GB).", + description="The base CogView4 model (~31GB).", type=ModelType.Main, ) # endregion @@ -695,7 +695,7 @@ flux2_vae = StarterModel( name="FLUX.2 VAE", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B::vae", - description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB", + description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~168MB", type=ModelType.VAE, ) @@ -719,7 +719,7 @@ flux2_klein_4b = StarterModel( name="FLUX.2 Klein 4B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B", - description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB", + description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~16GB", type=ModelType.Main, ) @@ -745,7 +745,7 @@ flux2_klein_9b = StarterModel( name="FLUX.2 Klein 9B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-9B", - description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB", + description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~35GB", type=ModelType.Main, ) @@ -821,7 +821,7 @@ z_image_turbo = StarterModel( name="Z-Image Turbo", base=BaseModelType.ZImage, source="Tongyi-MAI/Z-Image-Turbo", - description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB", + description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~33GB", type=ModelType.Main, ) From 60d0bcdbc1755dad460eb7123082eb51093bfea8 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 7 Apr 2026 05:25:06 +0200 Subject: [PATCH 5/5] Feature(UI): Canvas Workflow Integration - Run Workflow on Raster Layer (#8665) * feat: Add canvas-workflow integration feature This commit implements a new feature that allows users to run workflows directly from the unified canvas. 
Users can now: - Access a "Run Workflow" option from the canvas layer context menu - Select a workflow with image parameters from a modal dialog - Customize workflow parameters (non-image fields) - Execute the workflow with the current canvas layer as input - Have the result automatically added back to the canvas Key changes: - Added canvasWorkflowIntegrationSlice for state management - Created CanvasWorkflowIntegrationModal and related UI components - Added context menu item to raster layers - Integrated workflow execution with canvas image extraction - Added modal to global modal isolator This integration enhances the canvas by allowing users to leverage custom workflows for advanced image processing directly within the canvas workspace. Implements feature request for deeper workflow-canvas integration. * refactor(ui): simplify canvas workflow integration field rendering - Extract WorkflowFieldRenderer component for individual field rendering - Add WorkflowFormPreview component to handle workflow parameter display - Remove workflow compatibility filtering - allow all workflows - Simplify workflow selector to use flattened workflow list - Add comprehensive field type support (String, Integer, Float, Boolean, Enum, Scheduler, Board, Model, Image, Color) - Implement image field selection UI with radio * feat(ui): add canvas-workflow-integration logging namespace * feat(ui): add workflow filtering for canvas-workflow integration - Add useFilteredWorkflows hook to filter workflows with ImageField inputs - Add workflowHasImageField utility to check for ImageField in Form Builder - Only show workflows that have Form Builder with at least one ImageField - Add loading state while filtering workflows - Improve error messages to clarify Form Builder requirement - Update modal description to mention Form Builder and parameter adjustment - Add fallback error message for workflows without Form Builder * feat(ui): add persistence and migration for canvas workflow integration state - Add _version field (v1) to canvasWorkflowIntegrationState for future migrations - Add persistConfig with migration function to handle version upgrades - Add persistDenylist to exclude transient state (isOpen, isProcessing, sourceEntityIdentifier) - Use es-toolkit isPlainObject and tsafe assert for type-safe migration - Persist selectedWorkflowId and fieldValues across sessions * pnpm fix imports * fix(ui): handle workflow errors in canvas staging area and improve form UX - Clear processing state when workflow execution fails at enqueue time or during invocation, so the modal doesn't get stuck - Optimistically update listAllQueueItems cache on queue item status changes so the staging area immediately exits on failure - Clear processing state on invocation_error for canvas workflow origin - Auto-select the only unfilled ImageField in workflow form - Fix image field overflow and thumbnail sizing in workflow form * feat(ui): add canvas_output node and entry-based staging area Add a dedicated `canvas_output` backend invocation node that explicitly marks which images go to the canvas staging area, replacing the fragile board-based heuristic. Each `canvas_output` node produces a separate navigable entry in the staging area, allowing workflows with multiple outputs to be individually previewed and accepted. 
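As a rough illustration (the actual detection lives in the TypeScript execute hook; this Python sketch and its graph shape are assumptions, not the patch's code), the heuristic change boils down to filtering nodes by an explicit node type instead of guessing from board fields:

    # Sketch: explicit type-based detection of canvas outputs.
    workflow_nodes = [
        {"id": "a", "type": "flux_denoise"},
        {"id": "b", "type": "canvas_output"},
    ]
    canvas_output_ids = [n["id"] for n in workflow_nodes if n["type"] == "canvas_output"]
    # -> ["b"]; each canvas_output node becomes one navigable staging-area entry.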
Key changes: - New `CanvasOutputInvocation` backend node (canvas.py) - Entry-based staging area model where each output image is a separate navigable entry with flat next/prev cycling across all items - Frontend execute hook uses `canvas_output` type detection instead of board field heuristic, with proper board field value translation - Workflow filtering requires both Form Builder and canvas_output node - Updated QueueItemPreviewMini and StagingAreaItemsList for entries - Tests for entry-based navigation, multi-output, and race conditions * Chore pnp run fix * Chore eslint fix * Remove unused useOutputImageDTO export to fix knip lint * Update invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/useCanvasWorkflowIntegrationExecute.tsx Co-authored-by: dunkeroni * move UI text to en.json * fix conflicts merge with main * generate schema * Chore typegen --------- Co-authored-by: Claude Co-authored-by: Lincoln Stein Co-authored-by: dunkeroni --- invokeai/app/invocations/canvas.py | 27 + invokeai/frontend/web/public/locales/en.json | 21 + .../app/components/GlobalModalIsolator.tsx | 2 + .../frontend/web/src/app/logging/logger.ts | 1 + invokeai/frontend/web/src/app/store/store.ts | 3 + .../CanvasWorkflowIntegrationModal.tsx | 93 +++ ...anvasWorkflowIntegrationParameterPanel.tsx | 13 + ...vasWorkflowIntegrationWorkflowSelector.tsx | 92 +++ .../WorkflowFieldRenderer.tsx | 548 ++++++++++++++++++ .../WorkflowFormPreview.tsx | 289 +++++++++ .../useCanvasWorkflowIntegrationExecute.tsx | 302 ++++++++++ .../useFilteredWorkflows.tsx | 107 ++++ .../workflowHasImageField.tsx | 86 +++ .../RasterLayer/RasterLayerMenuItems.tsx | 2 + .../StagingArea/QueueItemPreviewMini.tsx | 36 +- .../StagingArea/StagingAreaItemsList.tsx | 20 +- .../components/StagingArea/context.tsx | 11 +- .../components/StagingArea/shared.test.ts | 38 +- .../components/StagingArea/shared.ts | 29 +- .../components/StagingArea/state.test.ts | 427 +++++++++++++- .../components/StagingArea/state.ts | 257 +++++--- .../CanvasEntityMenuItemsRunWorkflow.tsx | 25 + .../konva/CanvasStagingAreaModule.ts | 4 +- .../store/canvasWorkflowIntegrationSlice.ts | 134 +++++ .../nodes/util/graph/graphBuilderUtils.ts | 4 +- .../frontend/web/src/services/api/schema.ts | 49 +- .../services/events/onInvocationComplete.tsx | 28 + .../src/services/events/setEventListeners.tsx | 24 + 28 files changed, 2510 insertions(+), 162 deletions(-) create mode 100644 invokeai/app/invocations/canvas.py create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/useCanvasWorkflowIntegrationExecute.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/useFilteredWorkflows.tsx create mode 100644 
invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/workflowHasImageField.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityMenuItemsRunWorkflow.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/store/canvasWorkflowIntegrationSlice.ts diff --git a/invokeai/app/invocations/canvas.py b/invokeai/app/invocations/canvas.py new file mode 100644 index 0000000000..cf13c3334f --- /dev/null +++ b/invokeai/app/invocations/canvas.py @@ -0,0 +1,27 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation( + "canvas_output", + title="Canvas Output", + tags=["canvas", "output", "image"], + category="canvas", + version="1.0.0", + use_cache=False, +) +class CanvasOutputInvocation(BaseInvocation): + """Outputs an image to the canvas staging area. + + Use this node in workflows intended for canvas workflow integration. + Connect the final image of your workflow to this node to send it + to the canvas staging area when run via 'Run Workflow on Canvas'.""" + + image: ImageField = InputField(description=FieldDescriptions.image) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image_dto = context.images.save(image=image) + return ImageOutput.build(image_dto) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 047d5a4007..9ba645eef8 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2377,6 +2377,27 @@ "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", "addAdjustments": "Add Adjustments", "removeAdjustments": "Remove Adjustments", + "workflowIntegration": { + "title": "Run Workflow on Canvas", + "description": "Select a workflow with a Canvas Output node and an image parameter to run on the current canvas layer. You can adjust parameters before executing. The result will be added back to the canvas.", + "execute": "Execute Workflow", + "executing": "Executing...", + "runWorkflow": "Run Workflow", + "filteringWorkflows": "Filtering workflows...", + "loadingWorkflows": "Loading workflows...", + "noWorkflowsFound": "No workflows found.", + "noWorkflowsWithImageField": "No compatible workflows found. A workflow needs a Form Builder with an image input field and a Canvas Output node.", + "selectWorkflow": "Select Workflow", + "selectPlaceholder": "Choose a workflow...", + "unnamedWorkflow": "Unnamed Workflow", + "loadingParameters": "Loading workflow parameters...", + "noFormBuilderError": "This workflow has no form builder and cannot be used. 
Please select a different workflow.", + "imageFieldSelected": "This field will receive the canvas image", + "imageFieldNotSelected": "Click to use this field for canvas image", + "executionStarted": "Workflow execution started", + "executionStartedDescription": "The result will appear in the staging area when complete.", + "executionFailed": "Failed to execute workflow" + }, "compositeOperation": { "label": "Blend Mode", "add": "Add Blend Mode", diff --git a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx index 5c1446662e..ef0747707f 100644 --- a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx +++ b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx @@ -1,6 +1,7 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal'; +import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal'; import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { CropImageModal } from 'features/cropper/components/CropImageModal'; import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal'; @@ -51,6 +52,7 @@ export const GlobalModalIsolator = memo(() => { + diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index 6c843068df..d20ef77090 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -16,6 +16,7 @@ const $logger = atom(Roarr.child(BASE_CONTEXT)); export const zLogNamespace = z.enum([ 'canvas', + 'canvas-workflow-integration', 'config', 'dnd', 'events', diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 8f077baaea..f24d2d0105 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -25,6 +25,7 @@ import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSe import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice'; import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { canvasTextSliceConfig } from 'features/controlLayers/store/canvasTextSlice'; +import { canvasWorkflowIntegrationSliceConfig } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice'; import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice'; import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice'; @@ -67,6 +68,7 @@ const SLICE_CONFIGS = { [canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig, [canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig, [canvasSliceConfig.slice.reducerPath]: canvasSliceConfig, + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig, @@ -98,6 +100,7 @@ const ALL_REDUCERS = { canvasSliceConfig.slice.reducer, 
canvasSliceConfig.undoableConfig?.reduxUndoOptions ), + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig.slice.reducer, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer, diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx new file mode 100644 index 0000000000..94a123fa91 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx @@ -0,0 +1,93 @@ +import { + Button, + ButtonGroup, + Flex, + Heading, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Spacer, + Spinner, + Text, +} from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + canvasWorkflowIntegrationClosed, + selectCanvasWorkflowIntegrationIsOpen, + selectCanvasWorkflowIntegrationIsProcessing, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { CanvasWorkflowIntegrationParameterPanel } from './CanvasWorkflowIntegrationParameterPanel'; +import { CanvasWorkflowIntegrationWorkflowSelector } from './CanvasWorkflowIntegrationWorkflowSelector'; +import { useCanvasWorkflowIntegrationExecute } from './useCanvasWorkflowIntegrationExecute'; + +export const CanvasWorkflowIntegrationModal = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const isOpen = useAppSelector(selectCanvasWorkflowIntegrationIsOpen); + const isProcessing = useAppSelector(selectCanvasWorkflowIntegrationIsProcessing); + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + + const { execute, canExecute } = useCanvasWorkflowIntegrationExecute(); + + const onClose = useCallback(() => { + if (!isProcessing) { + dispatch(canvasWorkflowIntegrationClosed()); + } + }, [dispatch, isProcessing]); + + const onExecute = useCallback(() => { + execute(); + }, [execute]); + + return ( + + + + + {t('controlLayers.workflowIntegration.title')} + + + + + + + {t('controlLayers.workflowIntegration.description')} + + + + + {selectedWorkflowId && } + + + + + + + + + + + + + ); +}); + +CanvasWorkflowIntegrationModal.displayName = 'CanvasWorkflowIntegrationModal'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx new file mode 100644 index 0000000000..f59a6c45ed --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx @@ -0,0 +1,13 @@ +import { Box } from '@invoke-ai/ui-library'; +import { WorkflowFormPreview } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview'; +import { memo } from 'react'; + +export const CanvasWorkflowIntegrationParameterPanel = memo(() 
=> { + return ( + + + + ); +}); + +CanvasWorkflowIntegrationParameterPanel.displayName = 'CanvasWorkflowIntegrationParameterPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx new file mode 100644 index 0000000000..30bc60605c --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx @@ -0,0 +1,92 @@ +import { Flex, FormControl, FormLabel, Select, Spinner, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + canvasWorkflowIntegrationWorkflowSelected, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows'; + +import { useFilteredWorkflows } from './useFilteredWorkflows'; + +export const CanvasWorkflowIntegrationWorkflowSelector = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const { data: workflowsData, isLoading } = useListWorkflowsInfiniteInfiniteQuery( + { + per_page: 100, // Get a reasonable number of workflows + page: 0, + }, + { + selectFromResult: ({ data, isLoading }) => ({ + data, + isLoading, + }), + } + ); + + const workflows = useMemo(() => { + if (!workflowsData) { + return []; + } + // Flatten all pages into a single list + return workflowsData.pages.flatMap((page) => page.items); + }, [workflowsData]); + + // Filter workflows to only show those with ImageFields + const { filteredWorkflows, isFiltering } = useFilteredWorkflows(workflows); + + const onChange = useCallback( + (e: ChangeEvent) => { + const workflowId = e.target.value || null; + dispatch(canvasWorkflowIntegrationWorkflowSelected({ workflowId })); + }, + [dispatch] + ); + + if (isLoading || isFiltering) { + return ( + + + + {isFiltering + ? t('controlLayers.workflowIntegration.filteringWorkflows') + : t('controlLayers.workflowIntegration.loadingWorkflows')} + + + ); + } + + if (filteredWorkflows.length === 0) { + return ( + + {workflows.length === 0 + ? 
t('controlLayers.workflowIntegration.noWorkflowsFound') + : t('controlLayers.workflowIntegration.noWorkflowsWithImageField')} + + ); + } + + return ( + + {t('controlLayers.workflowIntegration.selectWorkflow')} + + + ); +}); + +CanvasWorkflowIntegrationWorkflowSelector.displayName = 'CanvasWorkflowIntegrationWorkflowSelector'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx new file mode 100644 index 0000000000..2d91be13bf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx @@ -0,0 +1,548 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { + Combobox, + Flex, + FormControl, + FormLabel, + IconButton, + Input, + Radio, + Select, + Switch, + Text, + Textarea, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { logger } from 'app/logging/logger'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { UploadImageIconButton } from 'common/hooks/useImageUploadButton'; +import { + canvasWorkflowIntegrationFieldValueChanged, + canvasWorkflowIntegrationImageFieldSelected, + selectCanvasWorkflowIntegrationFieldValues, + selectCanvasWorkflowIntegrationSelectedImageFieldKey, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import { DndImage } from 'features/dnd/DndImage'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { $templates } from 'features/nodes/store/nodesSlice'; +import type { NodeFieldElement } from 'features/nodes/types/workflow'; +import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants'; +import { isParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import { useGetWorkflowQuery } from 'services/api/endpoints/workflows'; +import type { AnyModelConfig, ImageDTO } from 'services/api/types'; + +const log = logger('canvas-workflow-integration'); + +interface WorkflowFieldRendererProps { + el: NodeFieldElement; +} + +export const WorkflowFieldRenderer = memo(({ el }: WorkflowFieldRendererProps) => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues); + const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey); + const templates = useStore($templates); + + const { data: workflow } = useGetWorkflowQuery(selectedWorkflowId!, { + skip: !selectedWorkflowId, + }); + + // Load boards and models for BoardField and ModelIdentifierField + const { data: boardsData } = 
useListAllBoardsQuery({ include_archived: true }); + const { data: modelsData, isLoading: isLoadingModels } = useGetModelConfigsQuery(); + + const { fieldIdentifier } = el.data; + const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`; + + log.debug({ fieldIdentifier, fieldKey }, 'Rendering workflow field'); + + // Get the node, field instance, and field template + const { field, fieldTemplate } = useMemo(() => { + if (!workflow?.workflow.nodes) { + log.warn('No workflow nodes found'); + return { field: null, fieldTemplate: null }; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const foundNode = workflow.workflow.nodes.find((n: any) => n.data.id === fieldIdentifier.nodeId); + if (!foundNode) { + log.warn({ nodeId: fieldIdentifier.nodeId }, 'Node not found'); + return { field: null, fieldTemplate: null }; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const foundField = (foundNode.data as any).inputs[fieldIdentifier.fieldName]; + if (!foundField) { + log.warn({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName }, 'Field not found in node'); + return { field: null, fieldTemplate: null }; + } + + // Get the field template from the invocation templates + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const nodeType = (foundNode.data as any).type; + const template = templates[nodeType]; + if (!template) { + log.warn({ nodeType }, 'No template found for node type'); + return { field: foundField, fieldTemplate: null }; + } + + const foundFieldTemplate = template.inputs[fieldIdentifier.fieldName]; + if (!foundFieldTemplate) { + log.warn({ nodeType, fieldName: fieldIdentifier.fieldName }, 'Field template not found'); + return { field: foundField, fieldTemplate: null }; + } + + return { field: foundField, fieldTemplate: foundFieldTemplate }; + }, [workflow, fieldIdentifier, templates]); + + // Get the current value from Redux or fallback to field default + const currentValue = useMemo(() => { + if (fieldValues && fieldKey in fieldValues) { + return fieldValues[fieldKey]; + } + + return field?.value ?? fieldTemplate?.default ?? ''; + }, [fieldValues, fieldKey, field, fieldTemplate]); + + // Get field type from the template + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const fieldType = fieldTemplate ? (fieldTemplate as any).type?.name : null; + + const handleChange = useCallback( + (value: unknown) => { + dispatch(canvasWorkflowIntegrationFieldValueChanged({ fieldName: fieldKey, value })); + }, + [dispatch, fieldKey] + ); + + const handleStringChange = useCallback( + (e: ChangeEvent) => { + handleChange(e.target.value); + }, + [handleChange] + ); + + const handleNumberChange = useCallback( + (e: ChangeEvent) => { + const val = fieldType === 'IntegerField' ? parseInt(e.target.value, 10) : parseFloat(e.target.value); + handleChange(isNaN(val) ? 
0 : val); + }, + [handleChange, fieldType] + ); + + const handleBooleanChange = useCallback( + (e: ChangeEvent) => { + handleChange(e.target.checked); + }, + [handleChange] + ); + + const handleSelectChange = useCallback( + (e: ChangeEvent) => { + handleChange(e.target.value); + }, + [handleChange] + ); + + // SchedulerField handlers + const handleSchedulerChange = useCallback( + (v) => { + if (!isParameterScheduler(v?.value)) { + return; + } + handleChange(v.value); + }, + [handleChange] + ); + + const schedulerValue = useMemo(() => SCHEDULER_OPTIONS.find((o) => o.value === currentValue), [currentValue]); + + // BoardField handlers + const handleBoardChange = useCallback( + (v) => { + if (!v) { + return; + } + const value = v.value === 'auto' || v.value === 'none' ? v.value : { board_id: v.value }; + handleChange(value); + }, + [handleChange] + ); + + const boardOptions = useMemo(() => { + const _options: ComboboxOption[] = [ + { label: t('common.auto'), value: 'auto' }, + { label: `${t('common.none')} (${t('boards.uncategorized')})`, value: 'none' }, + ]; + if (boardsData) { + for (const board of boardsData) { + _options.push({ + label: board.board_name, + value: board.board_id, + }); + } + } + return _options; + }, [boardsData, t]); + + const boardValue = useMemo(() => { + const _value = currentValue; + const autoOption = boardOptions[0]; + const noneOption = boardOptions[1]; + if (!_value || _value === 'auto') { + return autoOption; + } + if (_value === 'none') { + return noneOption; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const boardId = typeof _value === 'object' ? (_value as any).board_id : _value; + const boardOption = boardOptions.find((o) => o.value === boardId); + return boardOption ?? autoOption; + }, [currentValue, boardOptions]); + + const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]); + + // ModelIdentifierField handlers + const handleModelChange = useCallback( + (value: AnyModelConfig | null) => { + if (!value) { + return; + } + handleChange(value); + }, + [handleChange] + ); + + const modelConfigs = useMemo(() => { + if (!modelsData) { + return EMPTY_ARRAY; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const ui_model_base = fieldTemplate ? (fieldTemplate as any)?.ui_model_base : null; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const ui_model_type = fieldTemplate ? (fieldTemplate as any)?.ui_model_type : null; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const ui_model_variant = fieldTemplate ? (fieldTemplate as any)?.ui_model_variant : null; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const ui_model_format = fieldTemplate ? 
(fieldTemplate as any)?.ui_model_format : null; + + if (!ui_model_base && !ui_model_type) { + return modelConfigsAdapterSelectors.selectAll(modelsData); + } + + return modelConfigsAdapterSelectors.selectAll(modelsData).filter((config) => { + if (ui_model_base && !ui_model_base.includes(config.base)) { + return false; + } + if (ui_model_type && !ui_model_type.includes(config.type)) { + return false; + } + if (ui_model_variant && 'variant' in config && config.variant && !ui_model_variant.includes(config.variant)) { + return false; + } + if (ui_model_format && !ui_model_format.includes(config.format)) { + return false; + } + return true; + }); + }, [modelsData, fieldTemplate]); + + // ImageField handler + const handleImageFieldSelect = useCallback(() => { + dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey })); + }, [dispatch, fieldKey]); + + if (!field || !fieldTemplate) { + log.warn({ fieldIdentifier }, 'Field or template is null - not rendering'); + return null; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const label = (field as any)?.label || (fieldTemplate as any)?.title || fieldIdentifier.fieldName; + + // Log the entire field structure to understand its shape + log.debug( + { fieldType, label, currentValue, fieldStructure: field, fieldTemplateStructure: fieldTemplate }, + 'Field info' + ); + + // ImageField - allow user to select which one receives the canvas image + if (fieldType === 'ImageField') { + return ( + + ); + } + + // Render different input types based on field type + if (fieldType === 'StringField') { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isTextarea = (fieldTemplate as any)?.ui_component === 'textarea'; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isRequired = (fieldTemplate as any)?.required ?? false; + + if (isTextarea) { + return ( + + {label} +