diff --git a/invokeai/app/invocations/canvas.py b/invokeai/app/invocations/canvas.py new file mode 100644 index 0000000000..cf13c3334f --- /dev/null +++ b/invokeai/app/invocations/canvas.py @@ -0,0 +1,27 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation( + "canvas_output", + title="Canvas Output", + tags=["canvas", "output", "image"], + category="canvas", + version="1.0.0", + use_cache=False, +) +class CanvasOutputInvocation(BaseInvocation): + """Outputs an image to the canvas staging area. + + Use this node in workflows intended to be run from the canvas. + Connect the final image of your workflow to this node to send it + to the canvas staging area when run via 'Run Workflow on Canvas'.""" + + image: ImageField = InputField(description=FieldDescriptions.image) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image_dto = context.images.save(image=image) + return ImageOutput.build(image_dto) diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 1619c9d6f0..791ded2ed0 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -711,6 +711,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) @@ -747,6 +749,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 6f737ceb92..dff887f7d0 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -323,6 +323,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool: return False +def _filename_suggests_base(name: str) -> bool: + """Check if a model name/filename suggests it is a Base (undistilled) variant. + + Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished + from the state dict. We use the filename as a heuristic: filenames containing "base" + (e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model. + """ + return "base" in name.lower() + + def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None: """Determine FLUX.2 variant from state dict. - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560) - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096) - Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare. - We default to Klein9B (distilled) for all 9B models since GGUF models may not - include guidance embedding keys needed to distinguish them.
+ Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and cannot be distinguished from the state dict alone. This function defaults to Klein9B + for all 9B models. Callers should use filename heuristics to detect Klein9BBase. Supports both BFL format (checkpoint) and diffusers format keys: - BFL format: txt_in.weight (context embedder) @@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N context_in_dim = shape[1] # Determine variant based on context dimension if context_in_dim == KLEIN_9B_CONTEXT_DIM: - # Default to Klein9B (distilled) - the official/common 9B model + # Default to Klein9B - callers use filename heuristics to detect Klein9BBase return Flux2VariantType.Klein9B elif context_in_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B @@ -553,6 +563,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -720,6 +735,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -829,12 +849,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size) - To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled), - we check guidance_embeds: - - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation) - - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference) - - Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False. + Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and both have guidance_embeds=False. We use a filename heuristic to detect Base models. 
""" KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096 @@ -842,17 +858,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json") joint_attention_dim = transformer_config.get("joint_attention_dim", 4096) - guidance_embeds = transformer_config.get("guidance_embeds", False) # Determine variant based on joint_attention_dim if joint_attention_dim == KLEIN_9B_CONTEXT_DIM: - # Check guidance_embeds to distinguish distilled from undistilled - # Klein 9B (distilled): guidance_embeds = False (guidance is baked in) - # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance) - if guidance_embeds: + if _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase - else: - return Flux2VariantType.Klein9B + return Flux2VariantType.Klein9B elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B elif joint_attention_dim > 4096: diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index d39982456a..67d862a01d 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -44,6 +44,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, + lora_model_from_flux_onetrainer_bfl_state_dict, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, lora_model_from_flux_onetrainer_state_dict, @@ -128,6 +132,8 @@ class LoRALoader(ModelLoader): model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict): + model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict) elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) elif is_state_dict_likely_flux_control(state_dict=state_dict): diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index 7fe50d91fc..8b2173a87e 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -81,7 +81,7 @@ t5_base_encoder = StarterModel( name="t5_base_encoder", base=BaseModelType.Any, source="InvokeAI/t5-v1_1-xxl::bfloat16", - description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", + description="T5-XXL text encoder (used in FLUX pipelines). ~9.5GB", type=ModelType.T5Encoder, ) @@ -166,7 +166,7 @@ flux_kontext_quantized = StarterModel( name="FLUX.1 Kontext dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf", - description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Kontext dev quantized (q4_k_m). 
Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -174,7 +174,7 @@ flux_krea = StarterModel( name="FLUX.1 Krea dev", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors", - description="FLUX.1 Krea dev. Total size with dependencies: ~33GB", + description="FLUX.1 Krea dev. Total size with dependencies: ~29GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -182,7 +182,7 @@ flux_krea_quantized = StarterModel( name="FLUX.1 Krea dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf", - description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -190,7 +190,7 @@ sd35_medium = StarterModel( name="SD3.5 Medium", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-medium", - description="Medium SD3.5 Model: ~15GB", + description="Medium SD3.5 Model: ~16GB", type=ModelType.Main, dependencies=[], ) @@ -198,7 +198,7 @@ sd35_large = StarterModel( name="SD3.5 Large", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-large", - description="Large SD3.5 Model: ~19G", + description="Large SD3.5 Model: ~28GB", type=ModelType.Main, dependencies=[], ) @@ -654,7 +654,7 @@ cogview4 = StarterModel( name="CogView4", base=BaseModelType.CogView4, source="THUDM/CogView4-6B", - description="The base CogView4 model (~29GB).", + description="The base CogView4 model (~31GB).", type=ModelType.Main, ) # endregion @@ -705,7 +705,7 @@ flux2_vae = StarterModel( name="FLUX.2 VAE", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B::vae", - description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB", + description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~168MB", type=ModelType.VAE, ) @@ -729,7 +729,7 @@ flux2_klein_4b = StarterModel( name="FLUX.2 Klein 4B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B", - description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB", + description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~16GB", type=ModelType.Main, ) @@ -755,7 +755,7 @@ flux2_klein_9b = StarterModel( name="FLUX.2 Klein 9B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-9B", - description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB", + description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~35GB", type=ModelType.Main, ) @@ -831,7 +831,7 @@ z_image_turbo = StarterModel( name="Z-Image Turbo", base=BaseModelType.ZImage, source="Tongyi-MAI/Z-Image-Turbo", - description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB", + description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). 
~33GB", type=ModelType.Main, ) diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 4bf3461a8b..81232a9adc 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -215,6 +215,7 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" XLabs = "flux.xlabs" BflPeft = "flux.bfl_peft" + OneTrainerBfl = "flux.onetrainer_bfl" AnyVariant: TypeAlias = Union[ diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py new file mode 100644 index 0000000000..b2109222a3 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py @@ -0,0 +1,168 @@ +"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format. + +This format is produced by newer versions of OneTrainer and uses BFL internal key names +(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and +InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha). + +Unlike the standard BFL PEFT format (which uses 'diffusion_model.' prefix and lora_A/lora_B), +this format also has split QKV projections: + - double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate) + - double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate) + - single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate) + +Example keys: + transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight + transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight + transformer.double_blocks.0.img_attn.qkv.0.alpha + transformer.single_blocks.0.linear1.3.lora_down.weight + transformer.double_blocks.0.img_mlp.0.lora_down.weight +""" + +import re +from typing import Any, Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw + +_TRANSFORMER_PREFIX = "transformer." + +# Valid LoRA weight suffixes in this format. +_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha") + +# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1" +_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$") + +# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2" +_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$") + + +def is_state_dict_likely_in_flux_onetrainer_bfl_format( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> bool: + """Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format. + + This format uses BFL internal key names with 'transformer.' prefix and split QKV projections. + """ + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + if not str_keys: + return False + + # All keys must start with 'transformer.' + if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys): + return False + + # All keys must end with recognized LoRA suffixes. 
+    if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys):
+        return False
+
+    # Must have BFL block structure (double_blocks or single_blocks) under transformer prefix.
+    has_bfl_blocks = any(
+        k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys
+    )
+    if not has_bfl_blocks:
+        return False
+
+    # Must have the split QKV/linear1 pattern (qkv.0/1/2 or linear1.0) to distinguish this
+    # format from other formats that might use the transformer. prefix in the future.
+    has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys)
+    if not has_split_qkv:
+        return False
+
+    return True
+
+
+def _split_key(key: str) -> tuple[str, str]:
+    """Split a key into (layer_name, weight_suffix).
+
+    Handles:
+    - 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot
+    - 1-component suffixes: e.g., 'alpha' → split at last dot
+    """
+    if key.endswith(".weight"):
+        parts = key.rsplit(".", maxsplit=2)
+        return parts[0], f"{parts[1]}.{parts[2]}"
+    else:
+        parts = key.rsplit(".", maxsplit=1)
+        return parts[0], parts[1]
+
+
+def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
+    """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw.
+
+    Strips the 'transformer.' prefix, groups by layer, and merges split QKV/linear1
+    layers into MergedLayerPatch instances.
+    """
+    # Step 1: Strip prefix and group by layer name.
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+    for key, value in state_dict.items():
+        if not isinstance(key, str):
+            continue
+
+        # Strip 'transformer.' prefix.
+        key = key[len(_TRANSFORMER_PREFIX) :]
+
+        layer_name, suffix = _split_key(key)
+
+        if layer_name not in grouped_state_dict:
+            grouped_state_dict[layer_name] = {}
+        grouped_state_dict[layer_name][suffix] = value
+
+    # Step 2: Build LoRA layers, merging split QKV and linear1.
+    layers: dict[str, BaseLayerPatch] = {}
+
+    # Identify which layers need merging.
+    merge_groups: dict[str, list[str]] = {}
+    standalone_keys: list[str] = []
+
+    for layer_key in grouped_state_dict:
+        qkv_match = _SPLIT_QKV_RE.match(layer_key)
+        linear1_match = _SPLIT_LINEAR1_RE.match(layer_key)
+
+        if qkv_match:
+            parent = qkv_match.group(1)
+            if parent not in merge_groups:
+                merge_groups[parent] = []
+            merge_groups[parent].append(layer_key)
+        elif linear1_match:
+            parent = linear1_match.group(1)
+            if parent not in merge_groups:
+                merge_groups[parent] = []
+            merge_groups[parent].append(layer_key)
+        else:
+            standalone_keys.append(layer_key)
+
+    # Process standalone layers.
+    for layer_key in standalone_keys:
+        layer_sd = grouped_state_dict[layer_key]
+        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"] = any_lora_layer_from_state_dict(layer_sd)
+
+    # Process merged layers.
+    for parent_key, sub_keys in merge_groups.items():
+        # Sort by the numeric index at the end (e.g., qkv.0, qkv.1, qkv.2).
+        sub_keys.sort(key=lambda k: int(k.rsplit(".", maxsplit=1)[1]))
+
+        sub_layers: list[BaseLayerPatch] = []
+        sub_ranges: list[Range] = []
+        dim_0_offset = 0
+
+        for sub_key in sub_keys:
+            layer_sd = grouped_state_dict[sub_key]
+            sub_layer = any_lora_layer_from_state_dict(layer_sd)
+
+            # Determine the output dimension from the up weight shape.
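+            # Each Range records the [start, end) slice of the fused projection's
+            # output (dim 0) that this sub-layer patches, so MergedLayerPatch can
+            # recombine the split Q/K/V (and MLP) deltas into one patch for the fused layer.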
+ up_weight = layer_sd["lora_up.weight"] + out_dim = up_weight.shape[0] + + sub_layers.append(sub_layer) + sub_ranges.append(Range(dim_0_offset, dim_0_offset + out_dim)) + dim_0_offset += out_dim + + layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_key}"] = MergedLayerPatch(sub_layers, sub_ranges) + + return ModelPatchRaw(layers=layers) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 0b316602fc..b3e00c288b 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -14,6 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( is_state_dict_likely_in_flux_kohya_format, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, ) @@ -28,6 +31,8 @@ def flux_format_from_state_dict( ) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata): + return FluxLoRAFormat.OneTrainerBfl elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): return FluxLoRAFormat.OneTrainer elif is_state_dict_likely_in_flux_diffusers_format(state_dict): diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ef058b09e2..f535642d89 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2415,6 +2415,27 @@ "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", "addAdjustments": "Add Adjustments", "removeAdjustments": "Remove Adjustments", + "workflowIntegration": { + "title": "Run Workflow on Canvas", + "description": "Select a workflow with a Canvas Output node and an image parameter to run on the current canvas layer. You can adjust parameters before executing. The result will be added back to the canvas.", + "execute": "Execute Workflow", + "executing": "Executing...", + "runWorkflow": "Run Workflow", + "filteringWorkflows": "Filtering workflows...", + "loadingWorkflows": "Loading workflows...", + "noWorkflowsFound": "No workflows found.", + "noWorkflowsWithImageField": "No compatible workflows found. A workflow needs a Form Builder with an image input field and a Canvas Output node.", + "selectWorkflow": "Select Workflow", + "selectPlaceholder": "Choose a workflow...", + "unnamedWorkflow": "Unnamed Workflow", + "loadingParameters": "Loading workflow parameters...", + "noFormBuilderError": "This workflow has no form builder and cannot be used. 
Please select a different workflow.", + "imageFieldSelected": "This field will receive the canvas image", + "imageFieldNotSelected": "Click to use this field for canvas image", + "executionStarted": "Workflow execution started", + "executionStartedDescription": "The result will appear in the staging area when complete.", + "executionFailed": "Failed to execute workflow" + }, "compositeOperation": { "label": "Blend Mode", "add": "Add Blend Mode", diff --git a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx index 5c1446662e..ef0747707f 100644 --- a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx +++ b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx @@ -1,6 +1,7 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal'; +import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal'; import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { CropImageModal } from 'features/cropper/components/CropImageModal'; import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal'; @@ -51,6 +52,7 @@ export const GlobalModalIsolator = memo(() => { + diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index 6c843068df..d20ef77090 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -16,6 +16,7 @@ const $logger = atom(Roarr.child(BASE_CONTEXT)); export const zLogNamespace = z.enum([ 'canvas', + 'canvas-workflow-integration', 'config', 'dnd', 'events', diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 8f077baaea..f24d2d0105 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -25,6 +25,7 @@ import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSe import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice'; import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { canvasTextSliceConfig } from 'features/controlLayers/store/canvasTextSlice'; +import { canvasWorkflowIntegrationSliceConfig } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice'; import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice'; import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice'; @@ -67,6 +68,7 @@ const SLICE_CONFIGS = { [canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig, [canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig, [canvasSliceConfig.slice.reducerPath]: canvasSliceConfig, + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig, @@ -98,6 +100,7 @@ const ALL_REDUCERS = { canvasSliceConfig.slice.reducer, 
canvasSliceConfig.undoableConfig?.reduxUndoOptions ), + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig.slice.reducer, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer, diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx new file mode 100644 index 0000000000..94a123fa91 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal.tsx @@ -0,0 +1,93 @@ +import { + Button, + ButtonGroup, + Flex, + Heading, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Spacer, + Spinner, + Text, +} from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + canvasWorkflowIntegrationClosed, + selectCanvasWorkflowIntegrationIsOpen, + selectCanvasWorkflowIntegrationIsProcessing, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { CanvasWorkflowIntegrationParameterPanel } from './CanvasWorkflowIntegrationParameterPanel'; +import { CanvasWorkflowIntegrationWorkflowSelector } from './CanvasWorkflowIntegrationWorkflowSelector'; +import { useCanvasWorkflowIntegrationExecute } from './useCanvasWorkflowIntegrationExecute'; + +export const CanvasWorkflowIntegrationModal = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const isOpen = useAppSelector(selectCanvasWorkflowIntegrationIsOpen); + const isProcessing = useAppSelector(selectCanvasWorkflowIntegrationIsProcessing); + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + + const { execute, canExecute } = useCanvasWorkflowIntegrationExecute(); + + const onClose = useCallback(() => { + if (!isProcessing) { + dispatch(canvasWorkflowIntegrationClosed()); + } + }, [dispatch, isProcessing]); + + const onExecute = useCallback(() => { + execute(); + }, [execute]); + + return ( + + + + + {t('controlLayers.workflowIntegration.title')} + + + + + + + {t('controlLayers.workflowIntegration.description')} + + + + + {selectedWorkflowId && } + + + + + + + + + + + + + ); +}); + +CanvasWorkflowIntegrationModal.displayName = 'CanvasWorkflowIntegrationModal'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx new file mode 100644 index 0000000000..f59a6c45ed --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx @@ -0,0 +1,13 @@ +import { Box } from '@invoke-ai/ui-library'; +import { WorkflowFormPreview } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview'; +import { memo } from 'react'; + +export const CanvasWorkflowIntegrationParameterPanel = memo(() 
=> { + return ( + + + + ); +}); + +CanvasWorkflowIntegrationParameterPanel.displayName = 'CanvasWorkflowIntegrationParameterPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx new file mode 100644 index 0000000000..30bc60605c --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx @@ -0,0 +1,92 @@ +import { Flex, FormControl, FormLabel, Select, Spinner, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + canvasWorkflowIntegrationWorkflowSelected, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows'; + +import { useFilteredWorkflows } from './useFilteredWorkflows'; + +export const CanvasWorkflowIntegrationWorkflowSelector = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const { data: workflowsData, isLoading } = useListWorkflowsInfiniteInfiniteQuery( + { + per_page: 100, // Get a reasonable number of workflows + page: 0, + }, + { + selectFromResult: ({ data, isLoading }) => ({ + data, + isLoading, + }), + } + ); + + const workflows = useMemo(() => { + if (!workflowsData) { + return []; + } + // Flatten all pages into a single list + return workflowsData.pages.flatMap((page) => page.items); + }, [workflowsData]); + + // Filter workflows to only show those with ImageFields + const { filteredWorkflows, isFiltering } = useFilteredWorkflows(workflows); + + const onChange = useCallback( + (e: ChangeEvent) => { + const workflowId = e.target.value || null; + dispatch(canvasWorkflowIntegrationWorkflowSelected({ workflowId })); + }, + [dispatch] + ); + + if (isLoading || isFiltering) { + return ( + + + + {isFiltering + ? t('controlLayers.workflowIntegration.filteringWorkflows') + : t('controlLayers.workflowIntegration.loadingWorkflows')} + + + ); + } + + if (filteredWorkflows.length === 0) { + return ( + + {workflows.length === 0 + ? 
t('controlLayers.workflowIntegration.noWorkflowsFound') + : t('controlLayers.workflowIntegration.noWorkflowsWithImageField')} + + ); + } + + return ( + + {t('controlLayers.workflowIntegration.selectWorkflow')} + + + ); +}); + +CanvasWorkflowIntegrationWorkflowSelector.displayName = 'CanvasWorkflowIntegrationWorkflowSelector'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx new file mode 100644 index 0000000000..2d91be13bf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx @@ -0,0 +1,548 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { + Combobox, + Flex, + FormControl, + FormLabel, + IconButton, + Input, + Radio, + Select, + Switch, + Text, + Textarea, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { logger } from 'app/logging/logger'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { UploadImageIconButton } from 'common/hooks/useImageUploadButton'; +import { + canvasWorkflowIntegrationFieldValueChanged, + canvasWorkflowIntegrationImageFieldSelected, + selectCanvasWorkflowIntegrationFieldValues, + selectCanvasWorkflowIntegrationSelectedImageFieldKey, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import { DndImage } from 'features/dnd/DndImage'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { $templates } from 'features/nodes/store/nodesSlice'; +import type { NodeFieldElement } from 'features/nodes/types/workflow'; +import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants'; +import { isParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import { useGetWorkflowQuery } from 'services/api/endpoints/workflows'; +import type { AnyModelConfig, ImageDTO } from 'services/api/types'; + +const log = logger('canvas-workflow-integration'); + +interface WorkflowFieldRendererProps { + el: NodeFieldElement; +} + +export const WorkflowFieldRenderer = memo(({ el }: WorkflowFieldRendererProps) => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues); + const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey); + const templates = useStore($templates); + + const { data: workflow } = useGetWorkflowQuery(selectedWorkflowId!, { + skip: !selectedWorkflowId, + }); + + // Load boards and models for BoardField and ModelIdentifierField + const { data: boardsData } = 
useListAllBoardsQuery({ include_archived: true });
+  const { data: modelsData, isLoading: isLoadingModels } = useGetModelConfigsQuery();
+
+  const { fieldIdentifier } = el.data;
+  const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`;
+
+  log.debug({ fieldIdentifier, fieldKey }, 'Rendering workflow field');
+
+  // Get the node, field instance, and field template
+  const { field, fieldTemplate } = useMemo(() => {
+    if (!workflow?.workflow.nodes) {
+      log.warn('No workflow nodes found');
+      return { field: null, fieldTemplate: null };
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const foundNode = workflow.workflow.nodes.find((n: any) => n.data.id === fieldIdentifier.nodeId);
+    if (!foundNode) {
+      log.warn({ nodeId: fieldIdentifier.nodeId }, 'Node not found');
+      return { field: null, fieldTemplate: null };
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const foundField = (foundNode.data as any).inputs[fieldIdentifier.fieldName];
+    if (!foundField) {
+      log.warn({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName }, 'Field not found in node');
+      return { field: null, fieldTemplate: null };
+    }
+
+    // Get the field template from the invocation templates
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const nodeType = (foundNode.data as any).type;
+    const template = templates[nodeType];
+    if (!template) {
+      log.warn({ nodeType }, 'No template found for node type');
+      return { field: foundField, fieldTemplate: null };
+    }
+
+    const foundFieldTemplate = template.inputs[fieldIdentifier.fieldName];
+    if (!foundFieldTemplate) {
+      log.warn({ nodeType, fieldName: fieldIdentifier.fieldName }, 'Field template not found');
+      return { field: foundField, fieldTemplate: null };
+    }
+
+    return { field: foundField, fieldTemplate: foundFieldTemplate };
+  }, [workflow, fieldIdentifier, templates]);
+
+  // Get the current value from Redux or fallback to field default
+  const currentValue = useMemo(() => {
+    if (fieldValues && fieldKey in fieldValues) {
+      return fieldValues[fieldKey];
+    }
+
+    return field?.value ?? fieldTemplate?.default ?? '';
+  }, [fieldValues, fieldKey, field, fieldTemplate]);
+
+  // Get field type from the template
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  const fieldType = fieldTemplate ? (fieldTemplate as any).type?.name : null;
+
+  const handleChange = useCallback(
+    (value: unknown) => {
+      dispatch(canvasWorkflowIntegrationFieldValueChanged({ fieldName: fieldKey, value }));
+    },
+    [dispatch, fieldKey]
+  );
+
+  const handleStringChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
+      handleChange(e.target.value);
+    },
+    [handleChange]
+  );
+
+  const handleNumberChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      const val = fieldType === 'IntegerField' ? parseInt(e.target.value, 10) : parseFloat(e.target.value);
+      handleChange(isNaN(val) ? 0 : val);
+    },
+    [handleChange, fieldType]
+  );
+
+  const handleBooleanChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      handleChange(e.target.checked);
+    },
+    [handleChange]
+  );
+
+  const handleSelectChange = useCallback(
+    (e: ChangeEvent<HTMLSelectElement>) => {
+      handleChange(e.target.value);
+    },
+    [handleChange]
+  );
+
+  // SchedulerField handlers
+  const handleSchedulerChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!isParameterScheduler(v?.value)) {
+        return;
+      }
+      handleChange(v.value);
+    },
+    [handleChange]
+  );
+
+  const schedulerValue = useMemo(() => SCHEDULER_OPTIONS.find((o) => o.value === currentValue), [currentValue]);
+
+  // BoardField handlers
+  const handleBoardChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!v) {
+        return;
+      }
+      const value = v.value === 'auto' || v.value === 'none' ? v.value : { board_id: v.value };
+      handleChange(value);
+    },
+    [handleChange]
+  );
+
+  const boardOptions = useMemo(() => {
+    const _options: ComboboxOption[] = [
+      { label: t('common.auto'), value: 'auto' },
+      { label: `${t('common.none')} (${t('boards.uncategorized')})`, value: 'none' },
+    ];
+    if (boardsData) {
+      for (const board of boardsData) {
+        _options.push({
+          label: board.board_name,
+          value: board.board_id,
+        });
+      }
+    }
+    return _options;
+  }, [boardsData, t]);
+
+  const boardValue = useMemo(() => {
+    const _value = currentValue;
+    const autoOption = boardOptions[0];
+    const noneOption = boardOptions[1];
+    if (!_value || _value === 'auto') {
+      return autoOption;
+    }
+    if (_value === 'none') {
+      return noneOption;
+    }
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const boardId = typeof _value === 'object' ? (_value as any).board_id : _value;
+    const boardOption = boardOptions.find((o) => o.value === boardId);
+    return boardOption ?? autoOption;
+  }, [currentValue, boardOptions]);
+
+  const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);
+
+  // ModelIdentifierField handlers
+  const handleModelChange = useCallback(
+    (value: AnyModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      handleChange(value);
+    },
+    [handleChange]
+  );
+
+  const modelConfigs = useMemo(() => {
+    if (!modelsData) {
+      return EMPTY_ARRAY;
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_base = fieldTemplate ? (fieldTemplate as any)?.ui_model_base : null;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_type = fieldTemplate ? (fieldTemplate as any)?.ui_model_type : null;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_variant = fieldTemplate ? (fieldTemplate as any)?.ui_model_variant : null;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_format = fieldTemplate ?
(fieldTemplate as any)?.ui_model_format : null; + + if (!ui_model_base && !ui_model_type) { + return modelConfigsAdapterSelectors.selectAll(modelsData); + } + + return modelConfigsAdapterSelectors.selectAll(modelsData).filter((config) => { + if (ui_model_base && !ui_model_base.includes(config.base)) { + return false; + } + if (ui_model_type && !ui_model_type.includes(config.type)) { + return false; + } + if (ui_model_variant && 'variant' in config && config.variant && !ui_model_variant.includes(config.variant)) { + return false; + } + if (ui_model_format && !ui_model_format.includes(config.format)) { + return false; + } + return true; + }); + }, [modelsData, fieldTemplate]); + + // ImageField handler + const handleImageFieldSelect = useCallback(() => { + dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey })); + }, [dispatch, fieldKey]); + + if (!field || !fieldTemplate) { + log.warn({ fieldIdentifier }, 'Field or template is null - not rendering'); + return null; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const label = (field as any)?.label || (fieldTemplate as any)?.title || fieldIdentifier.fieldName; + + // Log the entire field structure to understand its shape + log.debug( + { fieldType, label, currentValue, fieldStructure: field, fieldTemplateStructure: fieldTemplate }, + 'Field info' + ); + + // ImageField - allow user to select which one receives the canvas image + if (fieldType === 'ImageField') { + return ( + + ); + } + + // Render different input types based on field type + if (fieldType === 'StringField') { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isTextarea = (fieldTemplate as any)?.ui_component === 'textarea'; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isRequired = (fieldTemplate as any)?.required ?? false; + + if (isTextarea) { + return ( + + {label} +