Compare commits

..

1 Commits

Author SHA1 Message Date
Kent Keirsey
21a05f4287 fix controlnets for latest diffusers v 2025-09-17 13:03:11 -04:00
101 changed files with 3249 additions and 1717 deletions

View File

@@ -5,7 +5,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
GlmEncoderField,
ModelIdentifierField,
@@ -14,7 +14,6 @@ from invokeai.app.invocations.model import (
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("cogview4_model_loader_output")
@@ -39,9 +38,8 @@ class CogView4ModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.cogview4_model,
ui_type=UIType.CogView4MainModel,
input=Input.Direct,
ui_model_base=BaseModelType.CogView4,
ui_model_type=ModelType.Main,
)
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:

View File

@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
ImageField,
InputField,
OutputField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
@@ -27,7 +28,6 @@ from invokeai.app.util.controlnet_utils import (
heuristic_resize_fast,
)
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class ControlField(BaseModel):
@@ -63,17 +63,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation(
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
)
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.ControlNet,
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -7,13 +7,6 @@ from pydantic_core import PydanticUndefined
from invokeai.app.util.metaenum import MetaEnum
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
ModelFormat,
ModelType,
ModelVariantType,
)
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -46,9 +39,42 @@ class UIType(str, Enum, metaclass=MetaEnum):
used, and the type will be ignored. They are included here for backwards compatibility.
"""
# region Model Field Types
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
Gemini2_5Model = "Gemini2_5ModelField"
FluxKontextModel = "FluxKontextModelField"
Veo3Model = "Veo3ModelField"
RunwayModel = "RunwayModelField"
# endregion
# region Misc Field Types
Scheduler = "SchedulerField"
Any = "AnyField"
Video = "VideoField"
# endregion
# region Internal Field Types
@@ -98,38 +124,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
MetadataDict = "DEPRECATED_MetadataDict"
# Deprecated Model Field Types - use ui_model_[base|type|variant|format] instead
MainModel = "DEPRECATED_MainModelField"
CogView4MainModel = "DEPRECATED_CogView4MainModelField"
FluxMainModel = "DEPRECATED_FluxMainModelField"
SD3MainModel = "DEPRECATED_SD3MainModelField"
SDXLMainModel = "DEPRECATED_SDXLMainModelField"
SDXLRefinerModel = "DEPRECATED_SDXLRefinerModelField"
ONNXModel = "DEPRECATED_ONNXModelField"
VAEModel = "DEPRECATED_VAEModelField"
FluxVAEModel = "DEPRECATED_FluxVAEModelField"
LoRAModel = "DEPRECATED_LoRAModelField"
ControlNetModel = "DEPRECATED_ControlNetModelField"
IPAdapterModel = "DEPRECATED_IPAdapterModelField"
T2IAdapterModel = "DEPRECATED_T2IAdapterModelField"
T5EncoderModel = "DEPRECATED_T5EncoderModelField"
CLIPEmbedModel = "DEPRECATED_CLIPEmbedModelField"
CLIPLEmbedModel = "DEPRECATED_CLIPLEmbedModelField"
CLIPGEmbedModel = "DEPRECATED_CLIPGEmbedModelField"
SpandrelImageToImageModel = "DEPRECATED_SpandrelImageToImageModelField"
ControlLoRAModel = "DEPRECATED_ControlLoRAModelField"
SigLipModel = "DEPRECATED_SigLipModelField"
FluxReduxModel = "DEPRECATED_FluxReduxModelField"
LlavaOnevisionModel = "DEPRECATED_LLaVAModelField"
Imagen3Model = "DEPRECATED_Imagen3ModelField"
Imagen4Model = "DEPRECATED_Imagen4ModelField"
ChatGPT4oModel = "DEPRECATED_ChatGPT4oModelField"
Gemini2_5Model = "DEPRECATED_Gemini2_5ModelField"
FluxKontextModel = "DEPRECATED_FluxKontextModelField"
Veo3Model = "DEPRECATED_Veo3ModelField"
RunwayModel = "DEPRECATED_RunwayModelField"
# endregion
class UIComponent(str, Enum, metaclass=MetaEnum):
"""
@@ -415,10 +409,6 @@ class InputFieldJSONSchemaExtra(BaseModel):
ui_component: Optional[UIComponent] = None
ui_order: Optional[int] = None
ui_choice_labels: Optional[dict[str, str]] = None
ui_model_base: Optional[list[BaseModelType]] = None
ui_model_type: Optional[list[ModelType]] = None
ui_model_variant: Optional[list[ClipVariantType | ModelVariantType]] = None
ui_model_format: Optional[list[ModelFormat]] = None
model_config = ConfigDict(
validate_assignment=True,
@@ -475,9 +465,9 @@ class OutputFieldJSONSchemaExtra(BaseModel):
"""
field_kind: FieldKind
ui_hidden: bool = False
ui_order: Optional[int] = None
ui_type: Optional[UIType] = None
ui_hidden: bool
ui_type: Optional[UIType]
ui_order: Optional[int]
model_config = ConfigDict(
validate_assignment=True,
@@ -511,63 +501,35 @@ def InputField(
ui_hidden: Optional[bool] = None,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
ui_model_base: Optional[BaseModelType | list[BaseModelType]] = None,
ui_model_type: Optional[ModelType | list[ModelType]] = None,
ui_model_variant: Optional[ClipVariantType | ModelVariantType | list[ClipVariantType | ModelVariantType]] = None,
ui_model_format: Optional[ModelFormat | list[ModelFormat]] = None,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field)
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
that adds a few extra parameters to support graph execution and the node editor UI.
If the field is a `ModelIdentifierField`, use the `ui_model_[base|type|variant|format]` args to filter the model list
in the Workflow Editor. Otherwise, use `ui_type` to provide extra type hints for the UI.
:param Input input: [Input.Any] The kind of input this field requires. \
`Input.Direct` means a value must be provided on instantiation. \
`Input.Connection` means the value must be provided by a connection. \
`Input.Any` means either will do.
Don't use both `ui_type` and `ui_model_[base|type|variant|format]` - if both are provided, a warning will be
logged and `ui_type` will be ignored.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
Args:
input: The kind of input this field requires.
- `Input.Direct` means a value must be provided on instantiation.
- `Input.Connection` means the value must be provided by a connection.
- `Input.Any` means either will do.
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
The UI will always render a suitable component, but sometimes you want something different than the default. \
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
ui_component: Optionally specifies a specific component to use in the UI. The UI will always render a suitable
component, but sometimes you want something different than the default. For example, a `string` field will
default to a single-line input, but you may want a multi-line textarea instead. In this case, you could use
`UIComponent.Textarea`.
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
ui_hidden: Specifies whether or not this field should be hidden in the UI.
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
ui_model_base: Specifies the base model architectures to filter the model list by in the Workflow Editor. For
example, `ui_model_base=BaseModelType.StableDiffusionXL` will show only SDXL architecture models. This arg is
only valid if this Input field is annotated as a `ModelIdentifierField`.
ui_model_type: Specifies the model type(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_type=ModelType.VAE` will show only VAE models. This arg is only valid if this Input field is
annotated as a `ModelIdentifierField`.
ui_model_variant: Specifies the model variant(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_variant=ModelVariantType.Inpainting` will show only inpainting models. This arg is only valid if this
Input field is annotated as a `ModelIdentifierField`.
ui_model_format: Specifies the model format(s) to filter the model list by in the Workflow Editor. For example,
`ui_model_format=ModelFormat.Diffusers` will show only models in the diffusers format. This arg is only valid
if this Input field is annotated as a `ModelIdentifierField`.
ui_choice_labels: Specifies the labels to use for the choices in an enum field. If omitted, the enum values
will be used. This arg is only valid if the field is annotated with as a `Literal`. For example,
`Literal["choice1", "choice2", "choice3"]` with `ui_choice_labels={"choice1": "Choice 1", "choice2": "Choice 2",
"choice3": "Choice 3"}` will render a dropdown with the labels "Choice 1", "Choice 2" and "Choice 3".
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
"""
json_schema_extra_ = InputFieldJSONSchemaExtra(
@@ -576,92 +538,7 @@ def InputField(
)
if ui_type is not None:
if (
ui_model_base is not None
or ui_model_type is not None
or ui_model_variant is not None
or ui_model_format is not None
):
logger.warning("InputField: Use either ui_type or ui_model_[base|type|variant|format]. Ignoring ui_type.")
# Map old-style UIType to new-style ui_model_[base|type|variant|format]
elif ui_type is UIType.MainModel:
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.CogView4MainModel:
json_schema_extra_.ui_model_base = [BaseModelType.CogView4]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.FluxMainModel:
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SD3MainModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusion3]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SDXLMainModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXL]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.SDXLRefinerModel:
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
json_schema_extra_.ui_model_type = [ModelType.Main]
# Think this UIType is unused...?
# elif ui_type is UIType.ONNXModel:
# json_schema_extra_.ui_model_base =
# json_schema_extra_.ui_model_type =
elif ui_type is UIType.VAEModel:
json_schema_extra_.ui_model_type = [ModelType.VAE]
elif ui_type is UIType.FluxVAEModel:
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
json_schema_extra_.ui_model_type = [ModelType.VAE]
elif ui_type is UIType.LoRAModel:
json_schema_extra_.ui_model_type = [ModelType.LoRA]
elif ui_type is UIType.ControlNetModel:
json_schema_extra_.ui_model_type = [ModelType.ControlNet]
elif ui_type is UIType.IPAdapterModel:
json_schema_extra_.ui_model_type = [ModelType.IPAdapter]
elif ui_type is UIType.T2IAdapterModel:
json_schema_extra_.ui_model_type = [ModelType.T2IAdapter]
elif ui_type is UIType.T5EncoderModel:
json_schema_extra_.ui_model_type = [ModelType.T5Encoder]
elif ui_type is UIType.CLIPEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
elif ui_type is UIType.CLIPLEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
json_schema_extra_.ui_model_variant = [ClipVariantType.L]
elif ui_type is UIType.CLIPGEmbedModel:
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
json_schema_extra_.ui_model_variant = [ClipVariantType.G]
elif ui_type is UIType.SpandrelImageToImageModel:
json_schema_extra_.ui_model_type = [ModelType.SpandrelImageToImage]
elif ui_type is UIType.ControlLoRAModel:
json_schema_extra_.ui_model_type = [ModelType.ControlLoRa]
elif ui_type is UIType.SigLipModel:
json_schema_extra_.ui_model_type = [ModelType.SigLIP]
elif ui_type is UIType.FluxReduxModel:
json_schema_extra_.ui_model_type = [ModelType.FluxRedux]
elif ui_type is UIType.LlavaOnevisionModel:
json_schema_extra_.ui_model_type = [ModelType.LlavaOnevision]
elif ui_type is UIType.Imagen3Model:
json_schema_extra_.ui_model_base = [BaseModelType.Imagen3]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Imagen4Model:
json_schema_extra_.ui_model_base = [BaseModelType.Imagen4]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.ChatGPT4oModel:
json_schema_extra_.ui_model_base = [BaseModelType.ChatGPT4o]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Gemini2_5Model:
json_schema_extra_.ui_model_base = [BaseModelType.Gemini2_5]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.FluxKontextModel:
json_schema_extra_.ui_model_base = [BaseModelType.FluxKontext]
json_schema_extra_.ui_model_type = [ModelType.Main]
elif ui_type is UIType.Veo3Model:
json_schema_extra_.ui_model_base = [BaseModelType.Veo3]
json_schema_extra_.ui_model_type = [ModelType.Video]
elif ui_type is UIType.RunwayModel:
json_schema_extra_.ui_model_base = [BaseModelType.Runway]
json_schema_extra_.ui_model_type = [ModelType.Video]
else:
json_schema_extra_.ui_type = ui_type
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
@@ -670,26 +547,6 @@ def InputField(
json_schema_extra_.ui_order = ui_order
if ui_choice_labels is not None:
json_schema_extra_.ui_choice_labels = ui_choice_labels
if ui_model_base is not None:
if isinstance(ui_model_base, list):
json_schema_extra_.ui_model_base = ui_model_base
else:
json_schema_extra_.ui_model_base = [ui_model_base]
if ui_model_type is not None:
if isinstance(ui_model_type, list):
json_schema_extra_.ui_model_type = ui_model_type
else:
json_schema_extra_.ui_model_type = [ui_model_type]
if ui_model_variant is not None:
if isinstance(ui_model_variant, list):
json_schema_extra_.ui_model_variant = ui_model_variant
else:
json_schema_extra_.ui_model_variant = [ui_model_variant]
if ui_model_format is not None:
if isinstance(ui_model_format, list):
json_schema_extra_.ui_model_format = ui_model_format
else:
json_schema_extra_.ui_model_format = [ui_model_format]
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
@@ -791,20 +648,20 @@ def OutputField(
"""
Creates an output field for an invocation output.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
Args:
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
ui_hidden: Specifies whether or not this field should be hidden in the UI.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
"""
return Field(
default=default,
title=title,
@@ -822,9 +679,9 @@ def OutputField(
min_length=min_length,
max_length=max_length,
json_schema_extra=OutputFieldJSONSchemaExtra(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_type=ui_type,
field_kind=FieldKind.Output,
).model_dump(exclude_none=True),
)

View File

@@ -4,10 +4,9 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("flux_control_lora_loader_output")
@@ -30,10 +29,7 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.control_lora_model,
title="Control LoRA",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.ControlLoRa,
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
)
image: ImageField = InputField(description="The image to encode.")
weight: float = InputField(description="The weight of the LoRA.", default=1.0)

View File

@@ -6,12 +6,11 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class FluxControlNetField(BaseModel):
@@ -58,9 +57,7 @@ class FluxControlNetInvocation(BaseInvocation):
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.ControlNet,
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -5,7 +5,7 @@ from pydantic import field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.fields import InputField, UIType
from invokeai.app.invocations.ip_adapter import (
CLIP_VISION_MODEL_MAP,
IPAdapterField,
@@ -20,7 +20,6 @@ from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation(
@@ -37,10 +36,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.IPAdapter,
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
)
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")

View File

@@ -6,10 +6,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
@invocation_output("flux_lora_loader_output")
@@ -36,10 +36,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.LoRA,
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(

View File

@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
@@ -17,7 +17,7 @@ from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("flux_model_loader_output")
@@ -46,30 +46,23 @@ class FluxModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.Main,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
input=Input.Direct,
title="T5 Encoder",
ui_model_type=ModelType.T5Encoder,
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
ui_model_type=ModelType.CLIPEmbed,
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model,
title="VAE",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.VAE,
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
@@ -63,8 +64,7 @@ class FluxReduxInvocation(BaseInvocation):
redux_model: ModelIdentifierField = InputField(
description="The FLUX Redux model to use.",
title="FLUX Redux Model",
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.FluxRedux,
ui_type=UIType.FluxReduxModel,
)
downsampling_factor: int = InputField(
ge=1,

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -85,8 +85,7 @@ class IPAdapterInvocation(BaseInvocation):
description="The IP-Adapter model.",
title="IP-Adapter Model",
ui_order=-1,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.IPAdapter,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",

View File

@@ -6,12 +6,11 @@ from pydantic import field_validator
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.util.devices import TorchDevice
@@ -35,7 +34,7 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
vllm_model: ModelIdentifierField = InputField(
title="LLaVA Model Type",
description=FieldDescriptions.vllm_model,
ui_model_type=ModelType.LlavaOnevision,
ui_type=UIType.LlavaOnevisionModel,
)
@field_validator("images", mode="before")

View File

@@ -53,7 +53,7 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.version import __version__
@@ -473,6 +473,7 @@ class MetadataToModelOutput(BaseInvocationOutput):
model: ModelIdentifierField = OutputField(
description=FieldDescriptions.main_model,
title="Model",
ui_type=UIType.MainModel,
)
name: str = OutputField(description="Model Name", title="Name")
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@@ -487,6 +488,7 @@ class MetadataToSDXLModelOutput(BaseInvocationOutput):
model: ModelIdentifierField = OutputField(
description=FieldDescriptions.main_model,
title="Model",
ui_type=UIType.SDXLMainModel,
)
name: str = OutputField(description="Model Name", title="Name")
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@@ -517,7 +519,8 @@ class MetadataToModelInvocation(BaseInvocation, WithMetadata):
input=Input.Direct,
)
default_value: ModelIdentifierField = InputField(
description="The default model to use if not found in the metadata", ui_model_type=ModelType.Main
description="The default model to use if not found in the metadata",
ui_type=UIType.MainModel,
)
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
@@ -572,8 +575,7 @@ class MetadataToSDXLModelInvocation(BaseInvocation, WithMetadata):
)
default_value: ModelIdentifierField = InputField(
description="The default SDXL Model to use if not found in the metadata",
ui_model_type=ModelType.Main,
ui_model_base=BaseModelType.StableDiffusionXL,
ui_type=UIType.SDXLMainModel,
)
_validate_custom_label = model_validator(mode="after")(validate_custom_label)

View File

@@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -145,7 +145,7 @@ class ModelIdentifierInvocation(BaseInvocation):
@invocation(
"main_model_loader",
title="Main Model - SD1.5, SD2",
title="Main Model - SD1.5",
tags=["model"],
category="model",
version="1.0.4",
@@ -153,11 +153,7 @@ class ModelIdentifierInvocation(BaseInvocation):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
ui_model_type=ModelType.Main,
)
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@@ -191,10 +187,7 @@ class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.StableDiffusion1,
ui_model_type=ModelType.LoRA,
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -257,9 +250,7 @@ class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_type=ModelType.LoRA,
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -341,10 +332,7 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.StableDiffusionXL,
ui_model_type=ModelType.LoRA,
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -485,26 +473,13 @@ class SDXLLoRACollectionLoader(BaseInvocation):
@invocation(
"vae_loader",
title="VAE Model - SD1.5, SD2, SDXL, SD3, FLUX",
tags=["vae", "model"],
category="model",
version="1.0.4",
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
)
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model,
title="VAE",
ui_model_base=[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusion3,
BaseModelType.Flux,
],
ui_model_type=ModelType.VAE,
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@@ -6,14 +6,14 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ClipVariantType, ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("sd3_model_loader_output")
@@ -39,43 +39,36 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
ui_model_base=BaseModelType.StableDiffusion3,
ui_model_type=ModelType.Main,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
ui_model_type=ModelType.T5Encoder,
)
clip_l_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPLEmbedModel,
input=Input.Direct,
title="CLIP L Encoder",
default=None,
ui_model_type=ModelType.CLIPEmbed,
ui_model_variant=ClipVariantType.L,
)
clip_g_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_g_model,
ui_type=UIType.CLIPGEmbedModel,
input=Input.Direct,
title="CLIP G Encoder",
default=None,
ui_model_type=ModelType.CLIPEmbed,
ui_model_variant=ClipVariantType.G,
)
vae_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.vae_model,
title="VAE",
default=None,
ui_model_base=BaseModelType.StableDiffusion3,
ui_model_type=ModelType.VAE,
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:

View File

@@ -1,8 +1,8 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import SubModelType
@invocation_output("sdxl_model_loader_output")
@@ -29,9 +29,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model,
ui_model_base=BaseModelType.StableDiffusionXL,
ui_model_type=ModelType.Main,
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@@ -69,9 +67,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
ui_model_base=BaseModelType.StableDiffusionXLRefiner,
ui_model_type=ModelType.Main,
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?

View File

@@ -11,6 +11,7 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
@@ -18,7 +19,6 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@@ -33,7 +33,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_to_image_model: ModelIdentifierField = InputField(
title="Image-to-Image Model",
description=FieldDescriptions.spandrel_image_to_image_model,
ui_model_type=ModelType.SpandrelImageToImage,
ui_type=UIType.SpandrelImageToImageModel,
)
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."

View File

@@ -8,12 +8,11 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class T2IAdapterField(BaseModel):
@@ -61,8 +60,7 @@ class T2IAdapterInvocation(BaseInvocation):
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
ui_order=-1,
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
ui_model_type=ModelType.T2IAdapter,
ui_type=UIType.T2IAdapterModel,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"

View File

@@ -17,7 +17,6 @@ from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.fields import ImageField
from invokeai.app.services.shared.field_identifier import FieldIdentifier
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
@@ -210,6 +209,13 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
return None
class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItem(BaseModel):
"""Session queue item without the full graph. Used for serialization."""

View File

@@ -1,14 +0,0 @@
from typing import Literal
from pydantic import BaseModel, Field
class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
__all__ = ["FieldIdentifier"]

View File

@@ -5,7 +5,6 @@ from typing import Any, Optional, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
from invokeai.app.services.shared.field_identifier import FieldIdentifier
from invokeai.app.util.metaenum import MetaEnum
__workflow_meta_version__ = semver.Version.parse("1.0.0")
@@ -60,10 +59,6 @@ class WorkflowWithoutID(BaseModel):
tags: str = Field(description="The tags of the workflow.")
notes: str = Field(description="The notes of the workflow.")
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
output_fields: list[FieldIdentifier] | None = Field(
default=None,
description="The fields designated as output fields for the workflow.",
)
meta: WorkflowMeta = Field(description="The meta of the workflow.")
# TODO(psyche): nodes, edges and form are very loosely typed - they are strictly modeled and checked on the frontend.
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
@@ -126,29 +121,6 @@ class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
description: str = Field(description="The description of the workflow.")
category: WorkflowCategory = Field(description="The description of the workflow.")
tags: str = Field(description="The tags of the workflow.")
has_valid_image_output_field: bool = Field(
default=False,
description="True when the workflow exposes exactly one output field and it is an image output.",
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordListItemDTO":
workflow_data = data.pop("workflow", None)
has_valid_output = False
if workflow_data:
if isinstance(workflow_data, (str, bytes, bytearray)):
workflow = WorkflowValidator.validate_json(workflow_data)
else:
workflow = WorkflowValidator.validate_python(workflow_data)
output_fields = workflow.output_fields or []
image_outputs = [f for f in output_fields if f.kind == "output" and f.field_name.startswith("image")]
if len(image_outputs) == 1:
has_valid_output = True
data = dict(data)
data["has_valid_image_output_field"] = has_valid_output
return WorkflowRecordListItemDTOValidator.validate_python(data)
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)

View File

@@ -12,6 +12,7 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowRecordOrderBy,
WorkflowValidator,
WorkflowWithoutID,
@@ -122,8 +123,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
created_at,
updated_at,
opened_at,
tags,
workflow
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
@@ -204,7 +204,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTO.from_dict(dict(row)) for row in rows]
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]

View File

@@ -207,24 +207,15 @@ class IPAdapterPlusXL(IPAdapterPlus):
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {
"ip_adapter": {},
"image_proj": {},
"adapter_modules": {}, # added for noobai-mark-ipa
"image_proj_model": {}, # added for noobai-mark-ipa
}
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
elif key.startswith("image_proj_model."):
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = model[key]
elif key.startswith("image_proj."):
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("adapter_modules."):
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = model[key]
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:

View File

@@ -5,7 +5,11 @@ import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.controlnets.controlnet import (
ControlNetConditioningEmbedding,
ControlNetOutput,
zero_module,
)
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
@@ -775,7 +779,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel
# Patch both the new and legacy module paths for compatibility
try:
diffusers.models.controlnets.controlnet.ControlNetModel = ControlNetModel
except Exception:
# Fallback for environments still exposing the legacy path
try:
diffusers.models.controlnet.ControlNetModel = ControlNetModel
except Exception:
pass
# patch LoRACompatibleConv to use original Conv2D forward function

View File

@@ -1952,7 +1952,6 @@
"private": "Private",
"shared": "Shared",
"published": "Published",
"validImageOutput": "Valid Image Output",
"browseWorkflows": "Browse Workflows",
"deselectAll": "Deselect All",
"recommended": "Recommended For You",
@@ -2463,18 +2462,6 @@
"cr": "Cr (YCbCr)"
}
},
"triggerWorkflow": {
"menuItem": "Trigger Workflow",
"heading": "Trigger Workflow",
"openLibrary": "Select Workflow",
"noWorkflowSelected": "No workflow selected",
"loading": "Loading workflow...",
"apply": "Apply",
"applying": "Applying",
"cancel": "Cancel",
"enqueued": "Workflow enqueued",
"enqueueFailed": "Unable to enqueue workflow"
},
"transform": {
"transform": "Transform",
"fitToBbox": "Fit to Bbox",

View File

@@ -131,8 +131,7 @@
"notInstalled": "Non $t(common.installed)",
"prevPage": "Pagina precedente",
"nextPage": "Pagina successiva",
"resetToDefaults": "Ripristina impostazioni predefinite",
"crop": "Ritaglia"
"resetToDefaults": "Ripristina impostazioni predefinite"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -279,14 +278,6 @@
"selectVideoTab": {
"title": "Seleziona la scheda Video",
"desc": "Seleziona la scheda Video."
},
"promptHistoryPrev": {
"title": "Prompt precedente nella cronologia",
"desc": "Quando il prompt è attivo, passa al prompt precedente (più vecchio) nella cronologia."
},
"promptHistoryNext": {
"title": "Prossimo prompt nella cronologia",
"desc": "Quando il prompt è attivo, passa al prompt successivo (più recente) nella cronologia."
}
},
"hotkeys": "Tasti di scelta rapida",
@@ -884,8 +875,7 @@
"video": "Video",
"resolution": "Risoluzione",
"downloadImage": "Scarica l'immagine",
"showOptionsPanel": "Mostra pannello laterale (O o T)",
"startingFrameImageAspectRatioWarning": "Le proporzioni dell'immagine non corrispondono alle proporzioni del video ({{videoAspectRatio}}). Ciò potrebbe causare ritagli imprevisti durante la generazione del video."
"showOptionsPanel": "Mostra pannello laterale (O o T)"
},
"settings": {
"models": "Modelli",
@@ -2105,10 +2095,7 @@
"generateFromImage": "Genera prompt dall'immagine",
"resultTitle": "Espansione del prompt completata",
"resultSubtitle": "Scegli come gestire il prompt espanso:",
"insert": "Inserisci",
"noPromptHistory": "Nessuna cronologia di prompt registrata.",
"noMatchingPrompts": "Nessun prompt corrispondente nella cronologia.",
"toSwitchBetweenPrompts": "per passare da un prompt all'altro."
"insert": "Inserisci"
},
"controlLayers": {
"addLayer": "Aggiungi Livello",
@@ -2804,8 +2791,7 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"items": [
"Seleziona oggetto v2: selezione degli oggetti migliorata con input di punti e riquadri o prompt di testo.",
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello.",
"Cronologia prompt: rivedi e richiama rapidamente i tuoi ultimi 100 prompt."
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello."
],
"watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente"
},

View File

@@ -9,7 +9,6 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
import { CanvasEntityMenuItemsTriggerWorkflow } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTriggerWorkflow';
import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
@@ -25,7 +24,6 @@ export const RasterLayerMenuItems = memo(() => {
</IconMenuItemGroup>
<CanvasEntityMenuItemsTransform />
<CanvasEntityMenuItemsFilter />
<CanvasEntityMenuItemsTriggerWorkflow />
<CanvasEntityMenuItemsSelectObject />
<RasterLayerMenuItemsAdjustments />
<MenuDivider />

View File

@@ -1,120 +0,0 @@
import { Button, ButtonGroup, Flex, Heading, Spacer, Spinner, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useWorkflowTriggerApply } from 'features/controlLayers/hooks/useWorkflowTriggerApply';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import { parseAndMigrateWorkflow } from 'features/nodes/util/workflow/migrations';
import { toast } from 'features/toast/toast';
import {
setWorkflowLibraryBrowseIntent,
setWorkflowLibraryTriggerIntent,
} from 'features/workflowLibrary/store/workflowLibraryIntent';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBooksDuotone, PiPlayBold, PiXBold } from 'react-icons/pi';
import { useLazyGetWorkflowQuery } from 'services/api/endpoints/workflows';
import type { S } from 'services/api/types';
const TriggerWorkflowContent = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const workflowLibraryModal = useWorkflowLibraryModal();
const [fetchWorkflow, { isFetching: isFetchingWorkflow }] = useLazyGetWorkflowQuery();
const isBusy = useCanvasIsBusy();
const { apply, isApplying, selectedWorkflow, selectedWorkflowName } = useWorkflowTriggerApply();
const onSelectWorkflow = useCallback(
(workflow: S['WorkflowRecordListItemWithThumbnailDTO']) => {
const handleSelection = async () => {
try {
const res = await fetchWorkflow(workflow.workflow_id).unwrap();
const migratedWorkflow = parseAndMigrateWorkflow(res.workflow);
canvasManager.stateApi.setWorkflowTriggerSelection({
workflow: migratedWorkflow,
workflowId: workflow.workflow_id,
workflowName: migratedWorkflow.name ?? workflow.name,
});
} catch {
toast({
status: 'error',
title: t('workflows.problemRetrievingWorkflow'),
});
} finally {
setWorkflowLibraryBrowseIntent();
}
};
void handleSelection();
},
[canvasManager.stateApi, fetchWorkflow, t]
);
const openWorkflowLibrary = useCallback(() => {
setWorkflowLibraryTriggerIntent((workflow) => {
onSelectWorkflow(workflow);
workflowLibraryModal.close();
});
workflowLibraryModal.open();
}, [onSelectWorkflow, workflowLibraryModal]);
const isApplyDisabled = useMemo(() => {
return !selectedWorkflow || isApplying || isFetchingWorkflow || isBusy;
}, [selectedWorkflow, isApplying, isFetchingWorkflow, isBusy]);
const cancel = useCallback(() => {
setWorkflowLibraryBrowseIntent();
canvasManager.stateApi.cancelWorkflowTrigger();
}, [canvasManager.stateApi]);
return (
<Flex bg="base.800" borderRadius="base" p={4} flexDir="column" gap={4} w={420} shadow="dark-lg">
<Flex w="full" gap={4} alignItems="center">
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.triggerWorkflow.heading')}
</Heading>
<Spacer />
</Flex>
<Flex flexDir="column" gap={2}>
<Text color="base.400">{selectedWorkflowName ?? t('controlLayers.triggerWorkflow.noWorkflowSelected')}</Text>
{isFetchingWorkflow && (
<Flex alignItems="center" gap={2} color="base.500">
<Spinner size="sm" />
<Text>{t('controlLayers.triggerWorkflow.loading')}</Text>
</Flex>
)}
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button
variant="ghost"
leftIcon={<PiBooksDuotone />}
onClick={openWorkflowLibrary}
isDisabled={isApplying || isFetchingWorkflow || isBusy}
>
{t('controlLayers.triggerWorkflow.openLibrary')}
</Button>
<Spacer />
<Button variant="ghost" leftIcon={<PiPlayBold />} onClick={apply} isDisabled={isApplyDisabled}>
{isApplying ? t('controlLayers.triggerWorkflow.applying') : t('controlLayers.triggerWorkflow.apply')}
{isApplying && <Spinner size="sm" ml={2} />}
</Button>
<Button variant="ghost" leftIcon={<PiXBold />} onClick={cancel} isDisabled={isApplying}>
{t('controlLayers.triggerWorkflow.cancel')}
</Button>
</ButtonGroup>
</Flex>
);
});
TriggerWorkflowContent.displayName = 'TriggerWorkflowContent';
export const TriggerWorkflow = () => {
const canvasManager = useCanvasManager();
const state = useStore(canvasManager.stateApi.$workflowTrigger);
if (!state) {
return null;
}
return <TriggerWorkflowContent />;
};
TriggerWorkflow.displayName = 'TriggerWorkflow';

View File

@@ -1,20 +0,0 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityWorkflowTrigger } from 'features/controlLayers/hooks/useEntityWorkflowTrigger';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlayCircleBold } from 'react-icons/pi';
export const CanvasEntityMenuItemsTriggerWorkflow = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const trigger = useEntityWorkflowTrigger(entityIdentifier);
return (
<MenuItem onClick={trigger.start} icon={<PiPlayCircleBold />} isDisabled={trigger.isDisabled}>
{t('controlLayers.triggerWorkflow.menuItem')}
</MenuItem>
);
});
CanvasEntityMenuItemsTriggerWorkflow.displayName = 'CanvasEntityMenuItemsTriggerWorkflow';

View File

@@ -1,62 +0,0 @@
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
import { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRasterLayerEntityIdentifier } from 'features/controlLayers/store/types';
import { useCallback, useMemo } from 'react';
export const useEntityWorkflowTrigger = (entityIdentifier: CanvasEntityIdentifier | null) => {
const canvasManager = useCanvasManager();
const adapter = useEntityAdapterSafe(entityIdentifier);
const isBusy = useCanvasIsBusy();
const isLocked = useEntityIsLocked(entityIdentifier);
const isEmpty = useEntityIsEmpty(entityIdentifier);
const workflowTrigger = useStore(canvasManager.stateApi.$workflowTrigger);
const isDisabled = useMemo(() => {
if (!entityIdentifier) {
return true;
}
if (!isRasterLayerEntityIdentifier(entityIdentifier)) {
return true;
}
if (!adapter) {
return true;
}
if (!(adapter instanceof CanvasEntityAdapterRasterLayer)) {
return true;
}
if (isBusy) {
return true;
}
if (isLocked) {
return true;
}
if (isEmpty) {
return true;
}
if (workflowTrigger && workflowTrigger.adapter !== adapter) {
return true;
}
return false;
}, [adapter, entityIdentifier, isBusy, isEmpty, isLocked, workflowTrigger]);
const start = useCallback(() => {
if (isDisabled) {
return;
}
if (!entityIdentifier || !isRasterLayerEntityIdentifier(entityIdentifier) || !adapter) {
return;
}
if (!(adapter instanceof CanvasEntityAdapterRasterLayer)) {
return;
}
canvasManager.stateApi.startWorkflowTrigger(adapter);
}, [adapter, canvasManager.stateApi, entityIdentifier, isDisabled]);
return { isDisabled, start } as const;
};

View File

@@ -1,142 +0,0 @@
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { buildInvocationGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils';
import { toast } from 'features/toast/toast';
import { setWorkflowLibraryBrowseIntent } from 'features/workflowLibrary/store/workflowLibraryIntent';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { EnqueueBatchArg } from 'services/api/types';
export const useWorkflowTriggerApply = () => {
const canvasManager = useCanvasManager();
const workflowState = useStore(canvasManager.stateApi.$workflowTrigger);
const selection = workflowState?.selection ?? null;
const selectedWorkflow = selection?.workflow ?? null;
const selectedWorkflowName = selection?.workflowName ?? null;
const canvasSessionId = useAppSelector(selectCanvasSessionId);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [isApplying, setIsApplying] = useState(false);
const templates = useStore($templates);
const apply = useCallback(async () => {
if (!selection || !selectedWorkflow || !canvasSessionId) {
return;
}
setIsApplying(true);
try {
const workflow = deepClone(selectedWorkflow);
const outputFields = workflow.output_fields ?? [];
if (outputFields.length === 0) {
toast({
status: 'error',
title: t('controlLayers.triggerWorkflow.noWorkflowSelected'),
});
return;
}
const outputField = outputFields[0]!;
const originalNodeId = outputField.nodeId;
const fieldName = outputField.fieldName ?? outputField.field_name;
if (!originalNodeId || !fieldName) {
toast({
status: 'error',
title: t('controlLayers.triggerWorkflow.noWorkflowSelected'),
});
return;
}
const outputNodeId = `${CANVAS_OUTPUT_PREFIX}:${selection.workflowId}`;
workflow.output_fields = [
{
...outputField,
kind: 'output',
nodeId: outputNodeId,
node_id: outputNodeId,
fieldName,
field_name: fieldName,
userLabel: outputField.userLabel ?? outputField.user_label ?? null,
user_label: outputField.userLabel ?? outputField.user_label ?? null,
},
];
workflow.nodes = workflow.nodes.map((node) => {
if (node.id !== originalNodeId || node.type !== 'invocation') {
return node;
}
return {
...node,
id: outputNodeId,
data: { ...node.data, id: outputNodeId },
};
});
workflow.edges = workflow.edges.map((edge) => {
if (edge.type !== 'default') {
return edge;
}
return {
...edge,
source: edge.source === originalNodeId ? outputNodeId : edge.source,
target: edge.target === originalNodeId ? outputNodeId : edge.target,
};
});
const graph = buildInvocationGraph({
nodes: workflow.nodes,
edges: workflow.edges,
templates,
graphId: workflow.id,
});
const workflowPayload = deepClone(workflow);
delete workflowPayload.id;
const enqueueArg: EnqueueBatchArg = {
batch: {
graph,
workflow: workflowPayload,
runs: 1,
origin: 'canvas',
destination: canvasSessionId,
},
};
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueArg, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
await req.unwrap();
toast({
status: 'success',
title: t('controlLayers.triggerWorkflow.enqueued'),
});
setWorkflowLibraryBrowseIntent();
canvasManager.stateApi.cancelWorkflowTrigger();
} catch {
toast({
status: 'error',
title: t('controlLayers.triggerWorkflow.enqueueFailed'),
});
} finally {
setIsApplying(false);
}
}, [canvasManager.stateApi, canvasSessionId, dispatch, selection, selectedWorkflow, t, templates]);
return useMemo(
() => ({ apply, isApplying, selection, selectedWorkflow, selectedWorkflowName }) as const,
[apply, isApplying, selection, selectedWorkflow, selectedWorkflowName]
);
};

View File

@@ -50,7 +50,6 @@ import type {
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import { zImageOutput } from 'features/nodes/types/common';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
@@ -62,17 +61,6 @@ import type { Param0 } from 'tsafe';
import type { CanvasEntityAdapter } from './CanvasEntity/types';
export type WorkflowTriggerSelection = {
workflow: WorkflowV4;
workflowId: string;
workflowName: string;
};
type WorkflowTriggerState = {
adapter: CanvasEntityAdapterRasterLayer;
selection: WorkflowTriggerSelection | null;
};
export class CanvasStateApiModule extends CanvasModuleBase {
readonly type = 'state_api';
readonly id: string;
@@ -469,26 +457,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
}
};
startWorkflowTrigger = (adapter: CanvasEntityAdapterRasterLayer) => {
const current = this.$workflowTrigger.get();
if (current && current.adapter !== adapter) {
this.$workflowTrigger.set(null);
}
this.$workflowTrigger.set({ adapter, selection: null });
};
cancelWorkflowTrigger = () => {
this.$workflowTrigger.set(null);
};
setWorkflowTriggerSelection = (selection: WorkflowTriggerSelection) => {
const current = this.$workflowTrigger.get();
if (!current) {
return;
}
this.$workflowTrigger.set({ adapter: current.adapter, selection });
};
/**
* The entity adapter being filtered, if any.
*/
@@ -529,16 +497,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
*/
$isSegmenting = computed(this.$segmentingAdapter, (segmentingAdapter) => Boolean(segmentingAdapter));
/**
* The entity adapter currently triggering a workflow, if any.
*/
$workflowTrigger = atom<WorkflowTriggerState | null>(null);
/**
* Whether a workflow trigger is active.
*/
$isWorkflowTriggering = computed(this.$workflowTrigger, (trigger) => Boolean(trigger));
/**
* Whether the space key is currently pressed.
*/

View File

@@ -18,7 +18,6 @@ const modelPaneSx: SystemStyleObject = {
},
h: 'full',
minWidth: '300px',
overflow: 'auto',
};
export const ModelPane = memo(() => {

View File

@@ -40,7 +40,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
}, [modelConfig.base, modelConfig.type]);
return (
<Flex flexDir="column" gap={4} h="full">
<Flex flexDir="column" gap={4}>
<ModelHeader modelConfig={modelConfig}>
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
<ModelConvertButton modelConfig={modelConfig} />
@@ -48,7 +48,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
<ModelEditButton />
</ModelHeader>
<Divider />
<Flex flexDir="column" gap={4}>
<Flex flexDir="column" h="full" gap={4}>
<Box>
<SimpleGrid columns={2} gap={4}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />

View File

@@ -1,10 +1,14 @@
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -23,8 +27,22 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isCogView4MainModelFieldInputInstance,
isCogView4MainModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
isControlLoRAModelFieldInputTemplate,
isControlNetModelFieldInputInstance,
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
@@ -33,28 +51,68 @@ import {
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxKontextModelFieldInputInstance,
isFluxKontextModelFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxReduxModelFieldInputInstance,
isFluxReduxModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isImageGeneratorFieldInputInstance,
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isImagen4ModelFieldInputInstance,
isImagen4ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIntegerGeneratorFieldInputInstance,
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isModelIdentifierFieldInputInstance,
isModelIdentifierFieldInputTemplate,
isRunwayModelFieldInputInstance,
isRunwayModelFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
isSigLipModelFieldInputInstance,
isSigLipModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStringGeneratorFieldInputInstance,
isStringGeneratorFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
isT5EncoderModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
isVeo3ModelFieldInputInstance,
isVeo3ModelFieldInputTemplate,
} from 'features/nodes/types/field';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo } from 'react';
@@ -63,10 +121,33 @@ import { assert } from 'tsafe';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import RunwayModelFieldInputComponent from './inputs/RunwayModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
import Veo3ModelFieldInputComponent from './inputs/Veo3ModelFieldInputComponent';
type Props = {
nodeId: string;
@@ -206,6 +287,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isMainModelFieldInputTemplate(template)) {
if (!isMainModelFieldInputInstance(field)) {
return null;
}
return <MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isModelIdentifierFieldInputTemplate(template)) {
if (!isModelIdentifierFieldInputInstance(field)) {
return null;
@@ -213,6 +301,159 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSDXLRefinerModelFieldInputTemplate(template)) {
if (!isSDXLRefinerModelFieldInputInstance(field)) {
return null;
}
return <RefinerModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isVAEModelFieldInputTemplate(template)) {
if (!isVAEModelFieldInputInstance(field)) {
return null;
}
return <VAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isT5EncoderModelFieldInputTemplate(template)) {
if (!isT5EncoderModelFieldInputInstance(field)) {
return null;
}
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPEmbedModelFieldInputTemplate(template)) {
if (!isCLIPEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPLEmbedModelFieldInputTemplate(template)) {
if (!isCLIPLEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCLIPGEmbedModelFieldInputTemplate(template)) {
if (!isCLIPGEmbedModelFieldInputInstance(field)) {
return null;
}
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isControlLoRAModelFieldInputTemplate(template)) {
if (!isControlLoRAModelFieldInputInstance(field)) {
return null;
}
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;
}
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLoRAModelFieldInputTemplate(template)) {
if (!isLoRAModelFieldInputInstance(field)) {
return null;
}
return <LoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isControlNetModelFieldInputTemplate(template)) {
if (!isControlNetModelFieldInputInstance(field)) {
return null;
}
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isIPAdapterModelFieldInputTemplate(template)) {
if (!isIPAdapterModelFieldInputInstance(field)) {
return null;
}
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isT2IAdapterModelFieldInputTemplate(template)) {
if (!isT2IAdapterModelFieldInputInstance(field)) {
return null;
}
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSpandrelImageToImageModelFieldInputTemplate(template)) {
if (!isSpandrelImageToImageModelFieldInputInstance(field)) {
return null;
}
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSigLipModelFieldInputTemplate(template)) {
if (!isSigLipModelFieldInputInstance(field)) {
return null;
}
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxReduxModelFieldInputTemplate(template)) {
if (!isFluxReduxModelFieldInputInstance(field)) {
return null;
}
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen3ModelFieldInputTemplate(template)) {
if (!isImagen3ModelFieldInputInstance(field)) {
return null;
}
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen4ModelFieldInputTemplate(template)) {
if (!isImagen4ModelFieldInputInstance(field)) {
return null;
}
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxKontextModelFieldInputTemplate(template)) {
if (!isFluxKontextModelFieldInputInstance(field)) {
return null;
}
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
}
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isVeo3ModelFieldInputTemplate(template)) {
if (!isVeo3ModelFieldInputInstance(field)) {
return null;
}
return <Veo3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isRunwayModelFieldInputTemplate(template)) {
if (!isRunwayModelFieldInputInstance(field)) {
return null;
}
return <RunwayModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isColorFieldInputTemplate(template)) {
if (!isColorFieldInputInstance(field)) {
return null;
@@ -220,6 +461,34 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ColorFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxMainModelFieldInputTemplate(template)) {
if (!isFluxMainModelFieldInputInstance(field)) {
return null;
}
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;
}
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isCogView4MainModelFieldInputTemplate(template)) {
if (!isCogView4MainModelFieldInputInstance(field)) {
return null;
}
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSDXLMainModelFieldInputTemplate(template)) {
if (!isSDXLMainModelFieldInputInstance(field)) {
return null;
}
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSchedulerFieldInputTemplate(template)) {
if (!isSchedulerFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPGEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPGEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPLEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(CLIPLEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const ChatGPT4oModelFieldInputComponent = (
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldChatGPT4oModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ChatGPT4oModelFieldInputComponent);

View File

@@ -0,0 +1,63 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type {
CogView4MainModelFieldInputInstance,
CogView4MainModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useCogView4Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
const CogView4MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCogView4Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(CogView4MainModelFieldInputComponent);

View File

@@ -0,0 +1,48 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
ControlLoRAModelFieldInputInstance,
ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import type { ControlLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;
const ControlLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlLoRAModel();
const onChange = useCallback(
(value: ControlLoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldControlLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ControlLoRAModelFieldInputComponent);

View File

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate>;
const ControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
const onChange = useCallback(
(value: ControlNetModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldControlNetModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(ControlNetModelFieldInputComponent);

View File

@@ -0,0 +1,49 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FluxKontextModelFieldInputInstance,
FluxKontextModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxKontextModelFieldInputComponent = (
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxKontextModels();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxKontextModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxKontextModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
import type { FLUXReduxModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxReduxModelFieldInputComponent = (
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const onChange = useCallback(
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxReduxModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxReduxModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxVAEModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(FluxVAEModelFieldInputComponent);

View File

@@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const IPAdapterModelFieldInputComponent = (
props: FieldComponentProps<IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldIPAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(IPAdapterModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen3Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen3ModelFieldInputComponent = (
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen3Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen3ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen3ModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen4Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen4ModelFieldInputComponent = (
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen4Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen4ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen4ModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
import type { LlavaOnevisionConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
const LLaVAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLLaVAModels();
const onChange = useCallback(
(value: LlavaOnevisionConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldLLaVAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(LLaVAModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate>;
const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLoRAModels();
const onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(LoRAModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInputTemplate>;
const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(MainModelFieldInputComponent);

View File

@@ -12,7 +12,7 @@ import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const onChange = useCallback(
@@ -36,31 +36,8 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
return EMPTY_ARRAY;
}
if (!fieldTemplate.ui_model_base && !fieldTemplate.ui_model_type) {
return modelConfigsAdapterSelectors.selectAll(data);
}
return modelConfigsAdapterSelectors.selectAll(data).filter((config) => {
if (fieldTemplate.ui_model_base && !fieldTemplate.ui_model_base.includes(config.base)) {
return false;
}
if (fieldTemplate.ui_model_type && !fieldTemplate.ui_model_type.includes(config.type)) {
return false;
}
if (
fieldTemplate.ui_model_variant &&
'variant' in config &&
config.variant &&
!fieldTemplate.ui_model_variant.includes(config.variant)
) {
return false;
}
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
return false;
}
return true;
});
}, [data, fieldTemplate]);
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
return (
<ModelFieldCombobox

View File

@@ -0,0 +1,47 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SDXLRefinerModelFieldInputInstance,
SDXLRefinerModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefinerModelFieldInputTemplate>;
const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRefinerModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldRefinerModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(RefinerModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldRunwayModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useRunwayModels } from 'services/api/hooks/modelsByType';
import type { VideoApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const RunwayModelFieldInputComponent = (
props: FieldComponentProps<RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useRunwayModels();
const onChange = useCallback(
(value: VideoApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldRunwayModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(RunwayModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate>;
const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSDXLModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SDXLMainModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSigLipModels } from 'services/api/hooks/modelsByType';
import type { SigLipModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SigLipModelFieldInputComponent = (
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSigLipModels();
const onChange = useCallback(
(value: SigLipModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSigLipModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SigLipModelFieldInputComponent);

View File

@@ -0,0 +1,49 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SpandrelImageToImageModelFieldInputComponent = (
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSpandrelImageToImageModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(SpandrelImageToImageModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const T2IAdapterModelFieldInputComponent = (
props: FieldComponentProps<T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT2IAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(T2IAdapterModelFieldInputComponent);

View File

@@ -0,0 +1,43 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT5EncoderValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(T5EncoderModelFieldInputComponent);

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputTemplate>;
const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldVaeModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(VAEModelFieldInputComponent);

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldVeo3ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Veo3ModelFieldInputInstance, Veo3ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVeo3Models } from 'services/api/hooks/modelsByType';
import type { VideoApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Veo3ModelFieldInputComponent = (
props: FieldComponentProps<Veo3ModelFieldInputInstance, Veo3ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVeo3Models();
const onChange = useCallback(
(value: VideoApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldVeo3ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Veo3ModelFieldInputComponent);

View File

@@ -7,14 +7,14 @@ import { debounce } from 'es-toolkit/compat';
import { getInitialWorkflow } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice, selectWorkflowId } from 'features/nodes/store/selectors';
import type { NodesState } from 'features/nodes/store/types';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow';
import { atom, computed } from 'nanostores';
import { useEffect, useMemo } from 'react';
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
import stableHash from 'stable-hash';
const $maybePreviewWorkflow = atom<WorkflowV4 | null>(null);
const $maybePreviewWorkflow = atom<WorkflowV3 | null>(null);
export const $previewWorkflow = computed(
$maybePreviewWorkflow,
(maybePreviewWorkflow) => maybePreviewWorkflow ?? EMPTY_OBJECT

View File

@@ -17,8 +17,7 @@ import {
selectWorkflowLibraryView,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { setWorkflowLibraryBrowseIntent } from 'features/workflowLibrary/store/workflowLibraryIntent';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetCountsByCategoryQuery } from 'services/api/endpoints/workflows';
@@ -30,12 +29,8 @@ export const WorkflowLibraryModal = memo(() => {
const { t } = useTranslation();
const workflowLibraryModal = useWorkflowLibraryModal();
const didSync = useSyncInitialWorkflowLibraryCategories();
const handleClose = useCallback(() => {
setWorkflowLibraryBrowseIntent();
workflowLibraryModal.close();
}, [workflowLibraryModal]);
return (
<Modal isOpen={workflowLibraryModal.isOpen} onClose={handleClose} isCentered>
<Modal isOpen={workflowLibraryModal.isOpen} onClose={workflowLibraryModal.close} isCentered>
<ModalOverlay />
<ModalContent
w="calc(100% - var(--invoke-sizes-40))"
@@ -44,7 +39,7 @@ export const WorkflowLibraryModal = memo(() => {
maxH="calc(100% - var(--invoke-sizes-40))"
>
<ModalHeader>{t('workflows.workflowLibrary')}</ModalHeader>
<ModalCloseButton onClick={handleClose} />
<ModalCloseButton />
<ModalBody pb={6}>
{didSync && (
<Flex gap={4} h="100%">

View File

@@ -1,17 +1,11 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
import { selectWorkflowId } from 'features/nodes/store/selectors';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import {
$workflowLibraryIntent,
setWorkflowLibraryBrowseIntent,
} from 'features/workflowLibrary/store/workflowLibraryIntent';
import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -42,27 +36,12 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
const dispatch = useAppDispatch();
const workflowId = useAppSelector(selectWorkflowId);
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const workflowLibraryModal = useWorkflowLibraryModal();
const workflowLibraryIntent = useStore($workflowLibraryIntent);
const isActive = useMemo(() => {
return workflowId === workflow.workflow_id;
}, [workflowId, workflow.workflow_id]);
const isTriggerMode = workflowLibraryIntent.mode === 'trigger-workflow';
const isDisabled = isTriggerMode && !workflow.has_valid_image_output_field;
const handleClickLoad = useCallback(() => {
if (isTriggerMode) {
if (isDisabled) {
return;
}
workflowLibraryIntent.onSelect(workflow);
setWorkflowLibraryBrowseIntent();
workflowLibraryModal.close();
return;
}
loadWorkflowWithDialog({
type: 'library',
data: workflow.workflow_id,
@@ -70,32 +49,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
dispatch(workflowModeChanged('view'));
},
});
}, [
dispatch,
isDisabled,
isTriggerMode,
loadWorkflowWithDialog,
workflow,
workflowLibraryIntent,
workflowLibraryModal,
]);
}, [dispatch, loadWorkflowWithDialog, workflow.workflow_id]);
return (
<Flex
position="relative"
role="button"
onClick={handleClickLoad}
aria-disabled={isDisabled}
bg="base.750"
borderRadius="base"
w="full"
alignItems="stretch"
sx={{
...sx,
cursor: isDisabled ? 'not-allowed' : 'pointer',
opacity: isDisabled ? 0.5 : 1,
pointerEvents: isDisabled ? 'none' : 'auto',
}}
sx={sx}
gap={2}
>
<Flex p={2} pr={0}>
@@ -141,18 +106,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
{t('workflows.builder.published')}
</Badge>
)}
{workflow.has_valid_image_output_field && (
<Badge
color="invokeGreen.400"
borderColor="invokeGreen.700"
borderWidth={1}
bg="transparent"
flexShrink={0}
variant="subtle"
>
{t('workflows.validImageOutput')}
</Badge>
)}
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
{workflow.category === 'default' && (
<Image

View File

@@ -1,268 +0,0 @@
import { Button, ButtonGroup, Divider, Flex, Spacer, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { InvocationNodeContextProvider } from 'features/nodes/components/flow/nodes/Invocation/context';
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
import {
$isInPublishFlow,
$isSelectingOutputNode,
$outputNodeId,
} from 'features/nodes/components/sidePanel/workflow/publish';
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
import { $templates, workflowOutputFieldsChanged } from 'features/nodes/store/nodesSlice';
import { selectNodes, selectNodesSlice, selectWorkflowOutputFields } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { WorkflowOutputField } from 'features/nodes/types/workflow';
import { toast } from 'features/toast/toast';
import { useCallback, useEffect, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowLineRightBold, PiTrashBold } from 'react-icons/pi';
type OutputDetail = { nodeId: string; fieldName: string; userLabel: string | null };
const buildOutputDetails = (outputs: WorkflowOutputField[], templates: Templates, nodes: AnyNode[]): OutputDetail[] => {
const details: OutputDetail[] = [];
for (const output of outputs) {
if (details.length >= 1) {
break;
}
const node = nodes.find((n) => n.id === output.nodeId);
if (!isInvocationNode(node)) {
continue;
}
const template = templates[node.data.type];
const fieldTemplate = template?.outputs?.[output.fieldName];
if (!fieldTemplate || fieldTemplate.type.name !== 'ImageField' || fieldTemplate.type.cardinality === 'COLLECTION') {
continue;
}
details.push({
nodeId: output.nodeId,
fieldName: output.fieldName,
userLabel: output.userLabel ?? fieldTemplate.title ?? null,
});
}
return details;
};
const WorkflowOutputFieldsTab = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const store = useAppStore();
const outputFields = useAppSelector(selectWorkflowOutputFields);
const nodes = useAppSelector(selectNodes);
const templates = useStore($templates);
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const isInPublishFlow = useStore($isInPublishFlow);
const selectionSourceRef = useRef<'workflow-output-tab' | null>(null);
const previousOutputsRef = useRef<WorkflowOutputField[]>(outputFields);
const outputDetails = useMemo(
() => buildOutputDetails(outputFields, templates, nodes),
[outputFields, templates, nodes]
);
const revertSelection = useCallback(
(outputsToRestore: WorkflowOutputField[]) => {
dispatch(
workflowOutputFieldsChanged(
outputsToRestore.map(({ nodeId, fieldName, userLabel }) => ({
nodeId,
fieldName,
userLabel,
}))
)
);
const restoredNodeId = outputsToRestore[0]?.nodeId ?? null;
$outputNodeId.set(restoredNodeId);
},
[dispatch]
);
useEffect(() => {
previousOutputsRef.current = outputFields;
}, [outputFields]);
useEffect(() => {
if (isInPublishFlow || isSelectingOutputNode) {
return;
}
const currentNodeId = outputFields[0]?.nodeId ?? null;
if ($outputNodeId.get() !== currentNodeId) {
$outputNodeId.set(currentNodeId);
}
}, [isInPublishFlow, isSelectingOutputNode, outputFields]);
useEffect(() => {
if (selectionSourceRef.current !== 'workflow-output-tab') {
return;
}
if (isSelectingOutputNode) {
return;
}
const selectedNodeId = outputNodeId;
selectionSourceRef.current = null;
if (!selectedNodeId) {
return;
}
const state = store.getState();
const nodesState = selectNodesSlice(state);
const node = nodesState.nodes.find((n) => n.id === selectedNodeId);
if (!isInvocationNode(node)) {
toast({
status: 'error',
title: t('workflows.builder.outputFieldSelectInvalid', { defaultValue: 'Please select an invocation node.' }),
});
revertSelection(previousOutputsRef.current);
return;
}
const template = templates[node.data.type];
if (!template) {
toast({
status: 'error',
title: t('workflows.builder.outputFieldMissingTemplate', {
defaultValue: 'Unable to select outputs for this node.',
}),
});
revertSelection(previousOutputsRef.current);
return;
}
const imageOutputEntry = Object.entries(template.outputs).find(([, output]) => {
return output.type.name === 'ImageField' && output.type.cardinality !== 'COLLECTION';
});
if (!imageOutputEntry) {
toast({
status: 'error',
title: t('workflows.builder.outputFieldMustBeImage', {
defaultValue: 'Selected node must have an image output.',
}),
});
revertSelection(previousOutputsRef.current);
return;
}
const [fieldName, fieldTemplate] = imageOutputEntry;
dispatch(
workflowOutputFieldsChanged([
{
nodeId: selectedNodeId,
fieldName,
userLabel: fieldTemplate.title ?? null,
},
])
);
}, [dispatch, isSelectingOutputNode, outputNodeId, revertSelection, store, templates, t]);
const handleSelectNodeClick = useCallback(() => {
selectionSourceRef.current = 'workflow-output-tab';
previousOutputsRef.current = outputFields;
$outputNodeId.set(null);
$isSelectingOutputNode.set(true);
}, [outputFields]);
const handleClear = useCallback(() => {
previousOutputsRef.current = [];
dispatch(workflowOutputFieldsChanged([]));
$outputNodeId.set(null);
}, [dispatch]);
return (
<Flex flexDir="column" gap={2} h="full">
<Flex alignItems="center">
<Text fontWeight="semibold">{t('workflows.builder.outputFieldsTab', 'Output Fields')}</Text>
<Spacer />
<ButtonGroup size="sm" variant="ghost" isAttached={false}>
{outputDetails.length > 0 && (
<Button leftIcon={<PiTrashBold />} onClick={handleClear}>
{t('common.clear')}
</Button>
)}
<Button
leftIcon={<PiArrowLineRightBold />}
onClick={handleSelectNodeClick}
isDisabled={isSelectingOutputNode}
>
{isSelectingOutputNode
? t('workflows.builder.selectingOutputNode')
: outputDetails.length > 0
? t('workflows.builder.changeOutputNode')
: t('workflows.builder.selectOutputNode')}
</Button>
</ButtonGroup>
</Flex>
<Divider />
{outputDetails.length === 0 ? (
<Text color="warning.300" fontWeight="semibold">
{outputFields.length === 0
? t('workflows.builder.noOutputNodeSelected')
: t('workflows.builder.outputFieldPending', {
defaultValue: 'Selected output is unavailable. Check that the node still exists.',
})}
</Text>
) : (
outputDetails.map((detail) => (
<InvocationNodeContextProvider nodeId={detail.nodeId} key={`${detail.nodeId}-${detail.fieldName}`}>
<SelectedOutputPreview
nodeId={detail.nodeId}
fieldName={detail.fieldName}
fallbackLabel={detail.userLabel ?? detail.fieldName}
/>
</InvocationNodeContextProvider>
))
)}
</Flex>
);
};
export default WorkflowOutputFieldsTab;
const SelectedOutputPreview = ({
nodeId,
fieldName,
fallbackLabel,
}: {
nodeId: string;
fieldName: string;
fallbackLabel: string;
}) => {
const mouseOverFormField = useMouseOverFormField(nodeId);
const nodeUserTitle = useNodeUserTitleOrThrow();
const nodeTemplateTitle = useNodeTemplateTitleOrThrow();
const fieldTemplate = useOutputFieldTemplate(fieldName);
const zoomToNode = useZoomToNode(nodeId);
return (
<Flex
flexDir="column"
position="relative"
p={2}
borderRadius="base"
borderWidth={1}
gap={1}
onMouseOver={mouseOverFormField.handleMouseOver}
onMouseOut={mouseOverFormField.handleMouseOut}
onClick={zoomToNode}
>
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate?.title ?? fallbackLabel}`}</Text>
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
<NodeFieldElementOverlay nodeId={nodeId} />
</Flex>
);
};

View File

@@ -8,7 +8,6 @@ import { useTranslation } from 'react-i18next';
import WorkflowGeneralTab from './WorkflowGeneralTab';
import WorkflowJSONTab from './WorkflowJSONTab';
import WorkflowOutputFieldsTab from './WorkflowOutputFieldsTab';
const WorkflowFieldsLinearViewPanel = () => {
const { t } = useTranslation();
@@ -18,7 +17,6 @@ const WorkflowFieldsLinearViewPanel = () => {
<TabList>
<Tab>{t('workflows.builder.builder')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>{t('workflows.builder.outputFieldsTab', 'Output Fields')}</Tab>
<Tab>JSON</Tab>
<Spacer />
{allowPublishWorkflows && <StartPublishFlowButton />}
@@ -31,9 +29,6 @@ const WorkflowFieldsLinearViewPanel = () => {
<TabPanel h="full" p={0}>
<WorkflowGeneralTab />
</TabPanel>
<TabPanel h="full" p={0}>
<WorkflowOutputFieldsTab />
</TabPanel>
<TabPanel h="full" p={0}>
<WorkflowJSONTab />
</TabPanel>

View File

@@ -2,16 +2,20 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { isSingleOrCollection, isStatefulFieldType } from 'features/nodes/types/field';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { useMemo } from 'react';
const isConnectionInputField = (field: FieldInputTemplate) => {
return (field.input === 'connection' && !isSingleOrCollection(field.type)) || !isStatefulFieldType(field.type);
return (
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
);
};
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
return (
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && isStatefulFieldType(field.type)
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
field.type.name in TEMPLATE_BUILDER_MAP
);
};

View File

@@ -24,44 +24,90 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
BooleanFieldValue,
ChatGPT4oModelFieldValue,
CLIPEmbedModelFieldValue,
CLIPGEmbedModelFieldValue,
CLIPLEmbedModelFieldValue,
ColorFieldValue,
ControlLoRAModelFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
FieldValue,
FloatFieldValue,
FloatGeneratorFieldValue,
FluxKontextModelFieldValue,
FluxReduxModelFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
Imagen4ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
IPAdapterModelFieldValue,
LLaVAModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
ModelIdentifierFieldValue,
RunwayModelFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
SigLipModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
StringGeneratorFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
Veo3ModelFieldValue,
} from 'features/nodes/types/field';
import {
zBoardFieldValue,
zBooleanFieldValue,
zChatGPT4oModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zColorFieldValue,
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFloatGeneratorFieldValue,
zFluxKontextModelFieldValue,
zFluxReduxModelFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
zIPAdapterModelFieldValue,
zLLaVAModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
zModelIdentifierFieldValue,
zRunwayModelFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
zSigLipModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zStringGeneratorFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
zVeo3ModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
@@ -74,8 +120,7 @@ import type {
NodeFieldElement,
TextElement,
WorkflowCategory,
WorkflowOutputField,
WorkflowV4,
WorkflowV3,
} from 'features/nodes/types/workflow';
import {
getDefaultForm,
@@ -102,8 +147,7 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
tags: '',
notes: '',
exposedFields: [],
output_fields: [] as WorkflowOutputField[],
meta: { version: '4.0.0', category: 'user' },
meta: { version: '3.0.0', category: 'user' },
form: getDefaultForm(),
nodes: [],
edges: [],
@@ -195,9 +239,6 @@ const slice = createSlice({
);
if (didNodesChange) {
const remainingNodeIds = new Set(state.nodes.map((node) => node.id));
state.output_fields = state.output_fields.filter((field) => remainingNodeIds.has(field.nodeId));
const edgeChanges: EdgeChange<AnyEdge>[] = [];
for (const e of state.edges) {
const sourceExists = state.nodes.some((n) => n.id === e.source);
@@ -452,9 +493,81 @@ const slice = createSlice({
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
fieldValueReducer(state, action, zColorFieldValue);
},
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue);
},
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
fieldValueReducer(state, action, zModelIdentifierFieldValue);
},
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
},
fieldVaeModelValueChanged: (state, action: FieldValueAction<VAEModelFieldValue>) => {
fieldValueReducer(state, action, zVAEModelFieldValue);
},
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
fieldValueReducer(state, action, zLoRAModelFieldValue);
},
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
fieldValueReducer(state, action, zLLaVAModelFieldValue);
},
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
fieldValueReducer(state, action, zControlNetModelFieldValue);
},
fieldIPAdapterModelValueChanged: (state, action: FieldValueAction<IPAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zIPAdapterModelFieldValue);
},
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
},
fieldSpandrelImageToImageModelValueChanged: (
state,
action: FieldValueAction<SpandrelImageToImageModelFieldValue>
) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
},
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
},
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
},
fieldCLIPLEmbedValueChanged: (state, action: FieldValueAction<CLIPLEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPLEmbedModelFieldValue);
},
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldControlLoRAModelValueChanged: (state, action: FieldValueAction<ControlLoRAModelFieldValue>) => {
fieldValueReducer(state, action, zControlLoRAModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
fieldSigLipModelValueChanged: (state, action: FieldValueAction<SigLipModelFieldValue>) => {
fieldValueReducer(state, action, zSigLipModelFieldValue);
},
fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
fieldValueReducer(state, action, zFluxReduxModelFieldValue);
},
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen4ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
fieldVeo3ModelValueChanged: (state, action: FieldValueAction<Veo3ModelFieldValue>) => {
fieldValueReducer(state, action, zVeo3ModelFieldValue);
},
fieldRunwayModelValueChanged: (state, action: FieldValueAction<RunwayModelFieldValue>) => {
fieldValueReducer(state, action, zRunwayModelFieldValue);
},
fieldFluxKontextModelValueChanged: (state, action: FieldValueAction<FluxKontextModelFieldValue>) => {
fieldValueReducer(state, action, zFluxKontextModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -517,43 +630,6 @@ const slice = createSlice({
workflowContactChanged: (state, action: PayloadAction<string>) => {
state.contact = action.payload;
},
workflowOutputFieldsChanged: (
state,
action: PayloadAction<Array<{ nodeId: string; fieldName: string; userLabel?: string | null }>>
) => {
const validOutputs: WorkflowOutputField[] = [];
for (const identifier of action.payload) {
if (validOutputs.length >= 1) {
break; // only allow one output field
}
const nodeId = identifier.nodeId;
const fieldName = identifier.fieldName;
if (!nodeId || !fieldName) {
continue;
}
const node = state.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
continue;
}
const label = identifier.userLabel ?? null;
validOutputs.push({
nodeId,
fieldName,
kind: 'output',
userLabel: label,
node_id: nodeId,
field_name: fieldName,
user_label: label,
});
}
state.output_fields = validOutputs;
},
workflowIDChanged: (state, action: PayloadAction<string>) => {
state.id = action.payload;
},
@@ -606,7 +682,7 @@ const slice = createSlice({
const { formFieldInitialValues } = action.payload;
state.formFieldInitialValues = formFieldInitialValues;
},
workflowLoaded: (state, action: PayloadAction<WorkflowV4>) => {
workflowLoaded: (state, action: PayloadAction<WorkflowV3>) => {
const { nodes, edges, is_published: _is_published, ...workflowExtra } = action.payload;
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
@@ -630,22 +706,45 @@ export const {
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldColorValueChanged,
fieldControlNetModelValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldImageCollectionValueChanged,
fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldLLaVAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldIntegerValueChanged,
fieldFloatValueChanged,
fieldFloatCollectionValueChanged,
fieldIntegerCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldControlLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFluxKontextModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,
fieldImageGeneratorValueChanged,
fieldVeo3ModelValueChanged,
fieldRunwayModelValueChanged,
fieldDescriptionChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
@@ -663,7 +762,6 @@ export const {
workflowNotesChanged,
workflowVersionChanged,
workflowContactChanged,
workflowOutputFieldsChanged,
workflowIDChanged,
formReset,
formElementAdded,

View File

@@ -76,7 +76,6 @@ export const selectWorkflowContact = createNodesSelector((workflow) => workflow.
export const selectWorkflowTags = createNodesSelector((workflow) => workflow.tags);
export const selectWorkflowVersion = createNodesSelector((workflow) => workflow.version);
export const selectWorkflowForm = createNodesSelector((workflow) => workflow.form);
export const selectWorkflowOutputFields = createNodesSelector((workflow) => workflow.output_fields);
export const selectFormRootElementId = createNodesSelector((workflow) => {
return workflow.form.rootElementId;

View File

@@ -1,7 +1,7 @@
import type { HandleType } from '@xyflow/react';
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
import { zWorkflowV4 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import z from 'zod';
export type Templates = Record<string, InvocationTemplate>;
@@ -21,6 +21,6 @@ export const zNodesState = z.object({
nodes: z.array(zAnyNode),
edges: z.array(zAnyEdge),
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
...zWorkflowV4.omit({ nodes: true, edges: true, is_published: true }).shape,
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
});
export type NodesState = z.infer<typeof zNodesState>;

View File

@@ -40,7 +40,7 @@ describe(areTypesEqual.name, () => {
},
};
const targetType: FieldType = {
name: 'StringField',
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
@@ -54,7 +54,7 @@ describe(areTypesEqual.name, () => {
it('should handle equal original source type and target type', () => {
const sourceType: FieldType = {
name: 'FloatField',
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
@@ -78,7 +78,7 @@ describe(areTypesEqual.name, () => {
it('should handle equal original source type and original target type', () => {
const sourceType: FieldType = {
name: 'IntegerField',
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
@@ -88,7 +88,7 @@ describe(areTypesEqual.name, () => {
},
};
const targetType: FieldType = {
name: 'StringField',
name: 'LoRAModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {

View File

@@ -247,12 +247,17 @@ export const main_model_loader: InvocationTemplate = {
fieldKind: 'input',
input: 'direct',
ui_hidden: false,
ui_model_base: ['sd-1'],
ui_model_type: ['main'],
ui_type: 'MainModelField',
type: {
name: 'ModelIdentifierField',
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'ModelIdentifierField',
cardinality: 'SINGLE',
batch: false,
},
},
},
},
@@ -791,8 +796,7 @@ export const schema = {
input: 'direct',
orig_required: true,
ui_hidden: false,
ui_model_base: ['sd-1'],
ui_model_type: ['main'],
ui_type: 'MainModelField',
},
type: {
type: 'string',

View File

@@ -12,15 +12,11 @@ import type {
SchedulerField,
SubModelType,
T2IAdapterField,
zClipVariantType,
zModelFormat,
zModelVariantType,
} from 'features/nodes/types/common';
import type { Invocation, S } from 'services/api/types';
import type { Invocation, ModelType, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type z from 'zod';
/**
* These types originate from the server and are recreated as zod schemas manually, for use at runtime.
@@ -42,9 +38,7 @@ describe('Common types', () => {
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());
// Misc types
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());

View File

@@ -73,7 +73,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
export const zBaseModelType = z.enum([
const zBaseModel = z.enum([
'any',
'sd-1',
'sd-2',
@@ -90,7 +90,7 @@ export const zBaseModelType = z.enum([
'veo3',
'runway',
]);
export type BaseModelType = z.infer<typeof zBaseModelType>;
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum([
'sd-1',
'sd-2',
@@ -143,31 +143,11 @@ const zSubModelType = z.enum([
'safety_checker',
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zClipVariantType = z.enum(['large', 'gigantic']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export const zModelFormat = z.enum([
'omi',
'diffusers',
'checkpoint',
'lycoris',
'onnx',
'olive',
'embedding_file',
'embedding_folder',
'invokeai',
't5_encoder',
'bnb_quantized_int8b',
'bnb_quantized_nf4b',
'gguf_quantized',
'api',
]);
export const zModelIdentifierField = z.object({
key: z.string().min(1),
hash: z.string().min(1),
name: z.string().min(1),
base: zBaseModelType,
base: zBaseModel,
type: zModelType,
submodel_type: zSubModelType.nullish(),
});

View File

@@ -9,18 +9,7 @@ import { assert } from 'tsafe';
import { z } from 'zod';
import type { ImageField } from './common';
import {
zBaseModelType,
zBoardField,
zClipVariantType,
zColorField,
zImageField,
zModelFormat,
zModelIdentifierField,
zModelType,
zModelVariantType,
zSchedulerField,
} from './common';
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
/**
* zod schemas & inferred types for fields.
@@ -71,10 +60,6 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
default: z.undefined(),
ui_component: zFieldUIComponent.nullish(),
ui_choice_labels: z.record(z.string(), z.string()).nullish(),
ui_model_base: z.array(zBaseModelType).nullish(),
ui_model_type: z.array(zModelType).nullish(),
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
ui_model_format: z.array(zModelFormat).nullish(),
});
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
fieldKind: z.literal('output'),
@@ -176,10 +161,118 @@ const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
originalType: zStatelessFieldType.optional(),
});
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zModelIdentifierFieldType = zFieldTypeBase.extend({
name: z.literal('ModelIdentifierField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCogView4MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('CogView4MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zLLaVAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LLaVAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
originalType: zStatelessFieldType.optional(),
});
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT5EncoderModelFieldType = zFieldTypeBase.extend({
name: z.literal('T5EncoderModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPLEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPLEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlLoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSigLipModelFieldType = zFieldTypeBase.extend({
name: z.literal('SigLipModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxReduxModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxReduxModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen4ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen4ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVeo3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Veo3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zRunwayModelFieldType = zFieldTypeBase.extend({
name: z.literal('RunwayModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxKontextModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxKontextModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -209,6 +302,33 @@ const zStatefulFieldType = z.union([
zImageFieldType,
zBoardFieldType,
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSD3MainModelFieldType,
zCogView4MainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zLLaVAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zControlLoRAModelFieldType,
zFluxVAEModelFieldType,
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zImagen4ModelFieldType,
zChatGPT4oModelFieldType,
zFluxKontextModelFieldType,
zVeo3ModelFieldType,
zRunwayModelFieldType,
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
@@ -224,7 +344,36 @@ const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
const modelFieldTypeNames = [
// Stateful model fields
zModelIdentifierFieldType.shape.name.value,
zMainModelFieldType.shape.name.value,
zSDXLMainModelFieldType.shape.name.value,
zSD3MainModelFieldType.shape.name.value,
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
zLLaVAModelFieldType.shape.name.value,
zControlNetModelFieldType.shape.name.value,
zIPAdapterModelFieldType.shape.name.value,
zT2IAdapterModelFieldType.shape.name.value,
zSpandrelImageToImageModelFieldType.shape.name.value,
zT5EncoderModelFieldType.shape.name.value,
zCLIPEmbedModelFieldType.shape.name.value,
zCLIPLEmbedModelFieldType.shape.name.value,
zCLIPGEmbedModelFieldType.shape.name.value,
zControlLoRAModelFieldType.shape.name.value,
zFluxVAEModelFieldType.shape.name.value,
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zImagen4ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
zFluxKontextModelFieldType.shape.name.value,
zVeo3ModelFieldType.shape.name.value,
zRunwayModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
'CLIPField',
@@ -630,6 +779,26 @@ export const isColorFieldInputInstance = buildInstanceTypeGuard(zColorFieldInput
export const isColorFieldInputTemplate = buildTemplateTypeGuard<ColorFieldInputTemplate>('ColorField');
// #endregion
// #region MainModelField
export const zMainModelFieldValue = zModelIdentifierField.optional();
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zMainModelFieldValue,
});
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zMainModelFieldType,
originalType: zFieldType.optional(),
default: zMainModelFieldValue,
});
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zMainModelFieldType,
});
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
export const isMainModelFieldInputInstance = buildInstanceTypeGuard(zMainModelFieldInputInstance);
export const isMainModelFieldInputTemplate = buildTemplateTypeGuard<MainModelFieldInputTemplate>('MainModelField');
// #endregion
// #region ModelIdentifierField
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -651,6 +820,507 @@ export const isModelIdentifierFieldInputTemplate =
buildTemplateTypeGuard<ModelIdentifierFieldInputTemplate>('ModelIdentifierField');
// #endregion
// #region SDXLMainModelField
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLMainModelFieldType,
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
export const isSDXLMainModelFieldInputInstance = buildInstanceTypeGuard(zSDXLMainModelFieldInputInstance);
export const isSDXLMainModelFieldInputTemplate =
buildTemplateTypeGuard<SDXLMainModelFieldInputTemplate>('SDXLMainModelField');
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = buildInstanceTypeGuard(zSD3MainModelFieldInputInstance);
export const isSD3MainModelFieldInputTemplate =
buildTemplateTypeGuard<SD3MainModelFieldInputTemplate>('SD3MainModelField');
// #endregion
// #region CogView4MainModelField
const zCogView4MainModelFieldValue = zMainModelFieldValue;
const zCogView4MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCogView4MainModelFieldValue,
});
const zCogView4MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCogView4MainModelFieldType,
originalType: zFieldType.optional(),
default: zCogView4MainModelFieldValue,
});
const zCogView4MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zCogView4MainModelFieldType,
});
export type CogView4MainModelFieldInputInstance = z.infer<typeof zCogView4MainModelFieldInputInstance>;
export type CogView4MainModelFieldInputTemplate = z.infer<typeof zCogView4MainModelFieldInputTemplate>;
export const isCogView4MainModelFieldInputInstance = buildInstanceTypeGuard(zCogView4MainModelFieldInputInstance);
export const isCogView4MainModelFieldInputTemplate =
buildTemplateTypeGuard<CogView4MainModelFieldInputTemplate>('CogView4MainModelField');
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxMainModelFieldType,
originalType: zFieldType.optional(),
default: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFluxMainModelFieldType,
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = buildInstanceTypeGuard(zFluxMainModelFieldInputInstance);
export const isFluxMainModelFieldInputTemplate =
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSDXLRefinerModelFieldValue,
});
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLRefinerModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLRefinerModelFieldValue,
});
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSDXLRefinerModelFieldType,
});
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
export const isSDXLRefinerModelFieldInputInstance = buildInstanceTypeGuard(zSDXLRefinerModelFieldInputInstance);
export const isSDXLRefinerModelFieldInputTemplate =
buildTemplateTypeGuard<SDXLRefinerModelFieldInputTemplate>('SDXLRefinerModelField');
// #endregion
// #region VAEModelField
export const zVAEModelFieldValue = zModelIdentifierField.optional();
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVAEModelFieldValue,
});
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zVAEModelFieldType,
originalType: zFieldType.optional(),
default: zVAEModelFieldValue,
});
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zVAEModelFieldType,
});
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
export const isVAEModelFieldInputInstance = buildInstanceTypeGuard(zVAEModelFieldInputInstance);
export const isVAEModelFieldInputTemplate = buildTemplateTypeGuard<VAEModelFieldInputTemplate>('VAEModelField');
// #endregion
// #region LoRAModelField
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLoRAModelFieldValue,
});
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zLoRAModelFieldValue,
});
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLoRAModelFieldType,
});
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFieldInputInstance);
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
// #endregion
// #region LLaVAModelField
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLLaVAModelFieldValue,
});
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLLaVAModelFieldType,
originalType: zFieldType.optional(),
default: zLLaVAModelFieldValue,
});
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLLaVAModelFieldType,
});
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
// #endregion
// #region ControlNetModelField
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlNetModelFieldValue,
});
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zControlNetModelFieldType,
originalType: zFieldType.optional(),
default: zControlNetModelFieldValue,
});
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zControlNetModelFieldType,
});
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
export const isControlNetModelFieldInputInstance = buildInstanceTypeGuard(zControlNetModelFieldInputInstance);
export const isControlNetModelFieldInputTemplate =
buildTemplateTypeGuard<ControlNetModelFieldInputTemplate>('ControlNetModelField');
// #endregion
// #region IPAdapterModelField
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIPAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIPAdapterModelFieldType,
});
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
export const isIPAdapterModelFieldInputInstance = buildInstanceTypeGuard(zIPAdapterModelFieldInputInstance);
export const isIPAdapterModelFieldInputTemplate =
buildTemplateTypeGuard<IPAdapterModelFieldInputTemplate>('IPAdapterModelField');
// #endregion
// #region T2IAdapterField
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT2IAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zT2IAdapterModelFieldType,
});
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
export const isT2IAdapterModelFieldInputInstance = buildInstanceTypeGuard(zT2IAdapterModelFieldInputInstance);
export const isT2IAdapterModelFieldInputTemplate =
buildTemplateTypeGuard<T2IAdapterModelFieldInputTemplate>('T2IAdapterModelField');
// #endregion
// #region SpandrelModelToModelField
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
originalType: zFieldType.optional(),
default: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
});
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = buildInstanceTypeGuard(
zSpandrelImageToImageModelFieldInputInstance
);
export const isSpandrelImageToImageModelFieldInputTemplate =
buildTemplateTypeGuard<SpandrelImageToImageModelFieldInputTemplate>('SpandrelImageToImageModelField');
// #endregion
// #region T5EncoderModelField
export const zT5EncoderModelFieldValue = zModelIdentifierField.optional();
const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT5EncoderModelFieldValue,
});
const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT5EncoderModelFieldType,
originalType: zFieldType.optional(),
default: zT5EncoderModelFieldValue,
});
export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>;
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = buildInstanceTypeGuard(zT5EncoderModelFieldInputInstance);
export const isT5EncoderModelFieldInputTemplate =
buildTemplateTypeGuard<T5EncoderModelFieldInputTemplate>('T5EncoderModelField');
// #endregion
// #region FluxVAEModelField
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxVAEModelFieldValue,
});
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxVAEModelFieldType,
originalType: zFieldType.optional(),
default: zFluxVAEModelFieldValue,
});
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = buildInstanceTypeGuard(zFluxVAEModelFieldInputInstance);
export const isFluxVAEModelFieldInputTemplate =
buildTemplateTypeGuard<FluxVAEModelFieldInputTemplate>('FluxVAEModelField');
// #endregion
// #region CLIPEmbedModelField
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPEmbedModelFieldValue,
});
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPEmbedModelFieldValue,
});
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPEmbedModelFieldInputInstance);
export const isCLIPEmbedModelFieldInputTemplate =
buildTemplateTypeGuard<CLIPEmbedModelFieldInputTemplate>('CLIPEmbedModelField');
// #endregion
// #region CLIPLEmbedModelField
export const zCLIPLEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPLEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPLEmbedModelFieldValue,
});
const zCLIPLEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPLEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPLEmbedModelFieldValue,
});
export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPLEmbedModelFieldInputInstance);
export const isCLIPLEmbedModelFieldInputTemplate =
buildTemplateTypeGuard<CLIPLEmbedModelFieldInputTemplate>('CLIPLEmbedModelField');
// #endregion
// #region CLIPGEmbedModelField
export const zCLIPGEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPGEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPGEmbedModelFieldValue,
});
const zCLIPGEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPGEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPGEmbedModelFieldValue,
});
export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPGEmbedModelFieldInputInstance);
export const isCLIPGEmbedModelFieldInputTemplate =
buildTemplateTypeGuard<CLIPGEmbedModelFieldInputTemplate>('CLIPGEmbedModelField');
// #endregion
// #region ControlLoRAModelField
export const zControlLoRAModelFieldValue = zModelIdentifierField.optional();
const zControlLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlLoRAModelFieldValue,
});
const zControlLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zControlLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zControlLoRAModelFieldValue,
});
export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
export const isControlLoRAModelFieldInputInstance = buildInstanceTypeGuard(zControlLoRAModelFieldInputInstance);
export const isControlLoRAModelFieldInputTemplate =
buildTemplateTypeGuard<ControlLoRAModelFieldInputTemplate>('ControlLoRAModelField');
// #endregion
// #region SigLipModelField
export const zSigLipModelFieldValue = zModelIdentifierField.optional();
const zSigLipModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSigLipModelFieldValue,
});
const zSigLipModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSigLipModelFieldType,
originalType: zFieldType.optional(),
default: zSigLipModelFieldValue,
});
export type SigLipModelFieldValue = z.infer<typeof zSigLipModelFieldValue>;
export type SigLipModelFieldInputInstance = z.infer<typeof zSigLipModelFieldInputInstance>;
export type SigLipModelFieldInputTemplate = z.infer<typeof zSigLipModelFieldInputTemplate>;
export const isSigLipModelFieldInputInstance = buildInstanceTypeGuard(zSigLipModelFieldInputInstance);
export const isSigLipModelFieldInputTemplate =
buildTemplateTypeGuard<SigLipModelFieldInputTemplate>('SigLipModelField');
// #endregion
// #region FluxReduxModelField
export const zFluxReduxModelFieldValue = zModelIdentifierField.optional();
const zFluxReduxModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxReduxModelFieldValue,
});
const zFluxReduxModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxReduxModelFieldType,
originalType: zFieldType.optional(),
default: zFluxReduxModelFieldValue,
});
export type FluxReduxModelFieldValue = z.infer<typeof zFluxReduxModelFieldValue>;
export type FluxReduxModelFieldInputInstance = z.infer<typeof zFluxReduxModelFieldInputInstance>;
export type FluxReduxModelFieldInputTemplate = z.infer<typeof zFluxReduxModelFieldInputTemplate>;
export const isFluxReduxModelFieldInputInstance = buildInstanceTypeGuard(zFluxReduxModelFieldInputInstance);
export const isFluxReduxModelFieldInputTemplate =
buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
// #endregion
// #region Imagen3ModelField
export const zImagen3ModelFieldValue = zModelIdentifierField.optional();
const zImagen3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen3ModelFieldValue,
});
const zImagen3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen3ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen3ModelFieldValue,
});
export type Imagen3ModelFieldValue = z.infer<typeof zImagen3ModelFieldValue>;
export type Imagen3ModelFieldInputInstance = z.infer<typeof zImagen3ModelFieldInputInstance>;
export type Imagen3ModelFieldInputTemplate = z.infer<typeof zImagen3ModelFieldInputTemplate>;
export const isImagen3ModelFieldInputInstance = buildInstanceTypeGuard(zImagen3ModelFieldInputInstance);
export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion
// #region Imagen4ModelField
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen4ModelFieldValue,
});
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen4ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen4ModelFieldValue,
});
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
export const isImagen4ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
// #endregion
// #region FluxKontextModelField
export const zFluxKontextModelFieldValue = zModelIdentifierField.optional();
const zFluxKontextModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxKontextModelFieldValue,
});
const zFluxKontextModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxKontextModelFieldType,
originalType: zFieldType.optional(),
default: zFluxKontextModelFieldValue,
});
export type FluxKontextModelFieldValue = z.infer<typeof zFluxKontextModelFieldValue>;
export type FluxKontextModelFieldInputInstance = z.infer<typeof zFluxKontextModelFieldInputInstance>;
export type FluxKontextModelFieldInputTemplate = z.infer<typeof zFluxKontextModelFieldInputTemplate>;
export const isFluxKontextModelFieldInputInstance = buildInstanceTypeGuard(zFluxKontextModelFieldInputInstance);
export const isFluxKontextModelFieldInputTemplate =
buildTemplateTypeGuard<FluxKontextModelFieldInputTemplate>('FluxKontextModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zChatGPT4oModelFieldValue,
});
const zChatGPT4oModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zChatGPT4oModelFieldType,
originalType: zFieldType.optional(),
default: zChatGPT4oModelFieldValue,
});
export type ChatGPT4oModelFieldValue = z.infer<typeof zChatGPT4oModelFieldValue>;
export type ChatGPT4oModelFieldInputInstance = z.infer<typeof zChatGPT4oModelFieldInputInstance>;
export type ChatGPT4oModelFieldInputTemplate = z.infer<typeof zChatGPT4oModelFieldInputTemplate>;
export const isChatGPT4oModelFieldInputInstance = buildInstanceTypeGuard(zChatGPT4oModelFieldInputInstance);
export const isChatGPT4oModelFieldInputTemplate =
buildTemplateTypeGuard<ChatGPT4oModelFieldInputTemplate>('ChatGPT4oModelField');
// #endregion
// #region Veo3ModelField
export const zVeo3ModelFieldValue = zModelIdentifierField.optional();
const zVeo3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVeo3ModelFieldValue,
});
const zVeo3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zVeo3ModelFieldType,
originalType: zFieldType.optional(),
default: zVeo3ModelFieldValue,
});
export type Veo3ModelFieldValue = z.infer<typeof zVeo3ModelFieldValue>;
export type Veo3ModelFieldInputInstance = z.infer<typeof zVeo3ModelFieldInputInstance>;
export type Veo3ModelFieldInputTemplate = z.infer<typeof zVeo3ModelFieldInputTemplate>;
export const isVeo3ModelFieldInputInstance = buildInstanceTypeGuard(zVeo3ModelFieldInputInstance);
export const isVeo3ModelFieldInputTemplate = buildTemplateTypeGuard<Veo3ModelFieldInputTemplate>('Veo3ModelField');
// #endregion
// #region RunwayModelField
export const zRunwayModelFieldValue = zModelIdentifierField.optional();
const zRunwayModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zRunwayModelFieldValue,
});
const zRunwayModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zRunwayModelFieldType,
originalType: zFieldType.optional(),
default: zRunwayModelFieldValue,
});
export type RunwayModelFieldValue = z.infer<typeof zRunwayModelFieldValue>;
export type RunwayModelFieldInputInstance = z.infer<typeof zRunwayModelFieldInputInstance>;
export type RunwayModelFieldInputTemplate = z.infer<typeof zRunwayModelFieldInputTemplate>;
export const isRunwayModelFieldInputInstance = buildInstanceTypeGuard(zRunwayModelFieldInputInstance);
export const isRunwayModelFieldInputTemplate =
buildTemplateTypeGuard<RunwayModelFieldInputTemplate>('RunwayModelField');
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1261,6 +1931,31 @@ export const zStatefulFieldValue = z.union([
zImageFieldCollectionValue,
zBoardFieldValue,
zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zLLaVAModelFieldValue,
zControlNetModelFieldValue,
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zControlLoRAModelFieldValue,
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zFluxKontextModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
zFloatGeneratorFieldValue,
@@ -1288,6 +1983,22 @@ const zStatefulFieldInputInstance = z.union([
zImageFieldCollectionInputInstance,
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zLLaVAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zT5EncoderModelFieldInputInstance,
zFluxVAEModelFieldInputInstance,
zCLIPEmbedModelFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
zFloatGeneratorFieldInputInstance,
@@ -1314,6 +2025,33 @@ const zStatefulFieldInputTemplate = z.union([
zImageFieldCollectionInputTemplate,
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zLLaVAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zControlLoRAModelFieldInputTemplate,
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zImagen4ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zFluxKontextModelFieldInputTemplate,
zVeo3ModelFieldInputTemplate,
zRunwayModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
@@ -1341,6 +2079,19 @@ const zStatefulFieldOutputTemplate = z.union([
zImageFieldCollectionOutputTemplate,
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zLLaVAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate,
zSpandrelImageToImageModelFieldOutputTemplate,
zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate,
zFloatGeneratorFieldOutputTemplate,

View File

@@ -1,5 +1,5 @@
import type { XYPosition as ReactFlowXYPosition } from '@xyflow/react';
import type { WorkflowCategory, WorkflowV4, XYPosition } from 'features/nodes/types/workflow';
import type { WorkflowCategory, WorkflowV3, XYPosition } from 'features/nodes/types/workflow';
import type { S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
@@ -14,5 +14,5 @@ import { describe, test } from 'vitest';
describe('Workflow types', () => {
test('XYPosition', () => assert<Equals<XYPosition, ReactFlowXYPosition>>());
test('WorkflowCategory', () => assert<Equals<WorkflowCategory, S['WorkflowCategory']>>());
test('WorkflowV4', () => assert<Extends<SetRequired<WorkflowV4, 'id'>, S['Workflow']>>());
test('WorkflowV3', () => assert<Extends<SetRequired<WorkflowV3, 'id'>, S['Workflow']>>());
});

View File

@@ -57,41 +57,6 @@ const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]);
// #endregion
export type WorkflowOutputField = {
kind: 'output';
nodeId: string;
fieldName: string;
userLabel: string | null;
node_id: string;
field_name: string;
user_label: string | null;
};
const zWorkflowOutputField = z
.object({
kind: z.literal('output').default('output'),
nodeId: z.string().trim().min(1).optional(),
node_id: z.string().trim().min(1),
fieldName: z.string().trim().min(1).optional(),
field_name: z.string().trim().min(1),
userLabel: z.string().nullable().optional(),
user_label: z.string().nullable().optional(),
})
.transform((field) => {
const nodeId = field.nodeId ?? field.node_id;
const fieldName = field.fieldName ?? field.field_name;
const userLabel = field.userLabel ?? field.user_label ?? null;
return {
kind: 'output',
nodeId,
fieldName,
userLabel,
node_id: nodeId,
field_name: fieldName,
user_label: userLabel,
} as WorkflowOutputField;
});
// #region Workflow Builder
const zElementId = z.string().trim().min(1);
export type ElementId = z.infer<typeof zElementId>;
@@ -398,7 +363,7 @@ const zValidatedBuilderForm = zBuilderForm
//# endregion
// #region Workflow
const workflowBaseShape = {
export const zWorkflowV3 = z.object({
id: z.string().min(1).optional(),
name: z.string(),
author: z.string(),
@@ -410,27 +375,13 @@ const workflowBaseShape = {
nodes: z.array(zWorkflowNode),
edges: z.array(zWorkflowEdge),
exposedFields: z.array(zFieldIdentifier),
// Use the validated form schema!
form: zValidatedBuilderForm,
is_published: z.boolean().nullish(),
} as const;
export const zWorkflowV3 = z.object({
...workflowBaseShape,
meta: z.object({
category: zWorkflowCategory.default('user'),
version: z.literal('3.0.0'),
}),
// Use the validated form schema!
form: zValidatedBuilderForm,
is_published: z.boolean().nullish(),
});
export type WorkflowV3 = z.infer<typeof zWorkflowV3>;
export const zWorkflowV4 = z.object({
...workflowBaseShape,
meta: z.object({
category: zWorkflowCategory.default('user'),
version: z.literal('4.0.0'),
}),
output_fields: z.array(zWorkflowOutputField).max(1).default([]),
});
export type WorkflowV4 = z.infer<typeof zWorkflowV4>;
// #endregion

View File

@@ -7,15 +7,12 @@ import type { Templates } from 'features/nodes/store/types';
import type { BoardField } from 'features/nodes/types/common';
import type { BoardFieldInputInstance } from 'features/nodes/types/field';
import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isBatchNodeType, isGeneratorNodeType } from 'features/nodes/types/invocation';
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const log = logger('workflows');
type BoardFieldResolver = (field: BoardFieldInputInstance) => BoardField | undefined;
const getBoardField = (field: BoardFieldInputInstance, state: RootState): BoardField | undefined => {
// Translate the UI value to the graph value. See note in BoardFieldInputComponent for more info.
const { value } = field;
@@ -37,70 +34,19 @@ const getBoardField = (field: BoardFieldInputInstance, state: RootState): BoardF
return value;
};
const defaultBoardFieldResolver: BoardFieldResolver = (field) => {
const { value } = field;
if (!value || value === 'none' || value === 'auto') {
return undefined;
}
return value;
};
/**
* Builds a graph from the node editor state.
*/
export const buildNodesGraph = (state: RootState, templates: Templates): Required<Graph> => {
const { nodes, edges } = selectNodesSlice(state);
type NodeLike = {
id: string;
type?: string;
data?: unknown;
};
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter(isExecutableNode);
type InvocationNodeLike = NodeLike & {
data: InvocationNodeData;
};
const isInvocationNodeLike = (node: NodeLike): node is InvocationNodeLike => {
if (node.type !== 'invocation') {
return false;
}
if (!node.data || typeof node.data !== 'object') {
return false;
}
const data = node.data as Partial<InvocationNodeData>;
return Boolean(data.inputs && data.type && data.useCache !== undefined && data.isIntermediate !== undefined);
};
type EdgeLike = {
type?: string;
source: string;
target: string;
sourceHandle?: string | null;
targetHandle?: string | null;
};
type BuildInvocationGraphArgs = {
nodes: NodeLike[];
edges: EdgeLike[];
templates: Templates;
resolveBoardField?: BoardFieldResolver;
graphId?: string;
};
export const buildInvocationGraph = ({
nodes,
edges,
templates,
resolveBoardField = defaultBoardFieldResolver,
graphId,
}: BuildInvocationGraphArgs): Required<Graph> => {
const invocationNodes = nodes.filter(isInvocationNodeLike);
const executableNodes = invocationNodes.filter((node) => {
const nodeType = node.data.type;
return !isBatchNodeType(nodeType) && !isGeneratorNodeType(nodeType);
});
const parsedNodes = executableNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
const { id, data } = node;
const { type } = data;
const { type, inputs, isIntermediate } = data;
const nodeTemplate = templates[type];
if (!nodeTemplate) {
@@ -108,8 +54,9 @@ export const buildInvocationGraph = ({
return nodesAccumulator;
}
// Transform each node's inputs to simple key-value pairs
const transformedInputs = reduce(
data.inputs,
inputs,
(inputsAccumulator, input, name) => {
const fieldTemplate = nodeTemplate.inputs[name];
if (!fieldTemplate) {
@@ -117,7 +64,7 @@ export const buildInvocationGraph = ({
return inputsAccumulator;
}
if (isBoardFieldInputTemplate(fieldTemplate) && isBoardFieldInputInstance(input)) {
inputsAccumulator[name] = resolveBoardField(input);
inputsAccumulator[name] = getBoardField(input, state);
} else {
inputsAccumulator[name] = input.value;
}
@@ -127,15 +74,18 @@ export const buildInvocationGraph = ({
{} as Record<Exclude<string, 'id' | 'type'>, unknown>
);
transformedInputs['use_cache'] = data.useCache;
// add reserved use_cache
transformedInputs['use_cache'] = node.data.useCache;
// Build this specific node
const graphNode = {
type,
id,
...transformedInputs,
is_intermediate: data.isIntermediate,
} as AnyInvocation;
is_intermediate: isIntermediate,
};
// Add it to the nodes object
Object.assign(nodesAccumulator, {
[id]: graphNode,
});
@@ -143,59 +93,58 @@ export const buildInvocationGraph = ({
return nodesAccumulator;
}, {});
const executableNodeIds = executableNodes.map(({ id }) => id);
const filteredNodeIds = filteredNodes.map(({ id }) => id);
const parsedEdges = edges
// skip out the "dummy" edges between collapsed nodes
const filteredEdges = edges
.filter((edge) => edge.type !== 'collapsed')
.filter((edge) => executableNodeIds.includes(edge.source) && executableNodeIds.includes(edge.target))
.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
if (!sourceHandle || !targetHandle) {
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
return edgesAccumulator;
}
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle,
},
destination: {
node_id: target,
field: targetHandle,
},
});
// Reduce the node editor edges into invocation graph edges
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
if (!sourceHandle || !targetHandle) {
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
return edgesAccumulator;
}, []);
parsedEdges.forEach((edge) => {
const destinationNode = parsedNodes[edge.destination.node_id];
if (!destinationNode) {
return;
}
// Format the edges and add to the edges array
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle,
},
destination: {
node_id: target,
field: targetHandle,
},
});
return edgesAccumulator;
}, []);
/**
* Omit all inputs that have edges connected.
*
* Fixes edge case where the user has connected an input, but also provided an invalid explicit,
* value.
*
* In this edge case, pydantic will invalidate the node based on the invalid explicit value,
* even though the actual value that will be used comes from the connection.
*/
parsedEdges.forEach((edge) => {
const destination_node = parsedNodes[edge.destination.node_id];
const field = edge.destination.field;
parsedNodes[edge.destination.node_id] = omit(destinationNode, field) as AnyInvocation;
parsedNodes[edge.destination.node_id] = omit(destination_node, field) as AnyInvocation;
});
return {
id: graphId ?? uuidv4(),
// Assemble!
const graph = {
id: uuidv4(),
nodes: parsedNodes,
edges: parsedEdges,
};
};
/**
* Builds a graph from the node editor state.
*/
export const buildNodesGraph = (state: RootState, templates: Templates): Required<Graph> => {
const { nodes, edges } = selectNodesSlice(state);
return buildInvocationGraph({
nodes,
edges,
templates,
resolveBoardField: (field) => getBoardField(field, state),
});
return graph;
};

View File

@@ -205,7 +205,7 @@ export const getInfill = (
assert(false, 'Unknown infill method');
};
export const CANVAS_OUTPUT_PREFIX = 'canvas_output';
const CANVAS_OUTPUT_PREFIX = 'canvas_output';
export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNodes>) => {
return (

View File

@@ -9,9 +9,36 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
FloatField: 0,
ImageField: undefined,
IntegerField: 0,
IPAdapterModelField: undefined,
LoRAModelField: undefined,
LLaVAModelField: undefined,
ModelIdentifierField: undefined,
MainModelField: undefined,
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
SpandrelImageToImageModelField: undefined,
VAEModelField: undefined,
ControlNetModelField: undefined,
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
ControlLoRAModelField: undefined,
SigLipModelField: undefined,
FluxReduxModelField: undefined,
Imagen3ModelField: undefined,
Imagen4ModelField: undefined,
ChatGPT4oModelField: undefined,
FluxKontextModelField: undefined,
Veo3ModelField: undefined,
RunwayModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,

View File

@@ -3,26 +3,53 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
ChatGPT4oModelFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
CogView4MainModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlLoRAModelFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldCollectionInputTemplate,
FloatFieldInputTemplate,
FloatGeneratorFieldInputTemplate,
FluxKontextModelFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxReduxModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldInputTemplate,
ImageGeneratorFieldInputTemplate,
Imagen3ModelFieldInputTemplate,
Imagen4ModelFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
LLaVAModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
RunwayModelFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SigLipModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
StringGeneratorFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
Veo3ModelFieldInputTemplate,
} from 'features/nodes/types/field';
import {
getFloatGeneratorArithmeticSequenceDefaults,
@@ -275,6 +302,373 @@ const buildModelIdentifierFieldInputTemplate: FieldInputTemplateBuilder<ModelIde
return template;
};
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SDXLMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCogView4MainModelFieldInputTemplate: FieldInputTemplateBuilder<CogView4MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CogView4MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SDXLRefinerModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: VAEModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5EncoderModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: T5EncoderModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCLIPLEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPLEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPLEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPGEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<ControlLoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: ControlLoRAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLLaVAModelFieldInputTemplate: FieldInputTemplateBuilder<LLaVAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: LLaVAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxVAEModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: LoRAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlNetModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: ControlNetModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapterModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: IPAdapterModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapterModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: T2IAdapterModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
SpandrelImageToImageModelFieldInputTemplate
> = ({ schemaObject, baseField, fieldType }) => {
const template: SpandrelImageToImageModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSigLipModelFieldInputTemplate: FieldInputTemplateBuilder<SigLipModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SigLipModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxReduxModelFieldInputTemplate: FieldInputTemplateBuilder<FluxReduxModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxReduxModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImagen3ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen3ModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: Imagen3ModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4ModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: Imagen4ModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxKontextModelFieldInputTemplate: FieldInputTemplateBuilder<FluxKontextModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxKontextModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVeo3ModelFieldInputTemplate: FieldInputTemplateBuilder<Veo3ModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: Veo3ModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRunwayModelFieldInputTemplate: FieldInputTemplateBuilder<RunwayModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: RunwayModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: ChatGPT4oModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -449,41 +843,56 @@ const buildImageGeneratorFieldInputTemplate: FieldInputTemplateBuilder<ImageGene
return template;
};
const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
BoardField: buildBoardFieldInputTemplate,
BooleanField: buildBooleanFieldInputTemplate,
ColorField: buildColorFieldInputTemplate,
ControlNetModelField: buildControlNetModelFieldInputTemplate,
EnumField: buildEnumFieldInputTemplate,
FloatField: buildFloatFieldInputTemplate,
ImageField: buildImageFieldInputTemplate,
IntegerField: buildIntegerFieldInputTemplate,
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
LoRAModelField: buildLoRAModelFieldInputTemplate,
LLaVAModelField: buildLLaVAModelFieldInputTemplate,
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
SigLipModelField: buildSigLipModelFieldInputTemplate,
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
FluxKontextModelField: buildFluxKontextModelFieldInputTemplate,
Veo3ModelField: buildVeo3ModelFieldInputTemplate,
RunwayModelField: buildRunwayModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
ImageGeneratorField: buildImageGeneratorFieldInputTemplate,
};
} as const;
export const buildFieldInputTemplate = (
fieldSchema: InvocationFieldSchema,
fieldName: string,
fieldType: FieldType
): FieldInputTemplate => {
const {
input,
ui_hidden,
ui_component,
ui_type,
ui_order,
ui_choice_labels,
orig_required: required,
ui_model_base,
ui_model_type,
ui_model_variant,
ui_model_format,
} = fieldSchema;
const { input, ui_hidden, ui_component, ui_type, ui_order, ui_choice_labels, orig_required: required } = fieldSchema;
// This is the base field template that is common to all fields. The builder function will add all other
// properties to this template.
@@ -499,10 +908,6 @@ export const buildFieldInputTemplate = (
ui_type,
ui_order,
ui_choice_labels,
ui_model_base,
ui_model_type,
ui_model_variant,
ui_model_format,
};
if (isStatefulFieldType(fieldType)) {

View File

@@ -6,8 +6,8 @@ import { pick } from 'es-toolkit/compat';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import { zWorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import i18n from 'i18n';
import { useCallback } from 'react';
import { fromZodError } from 'zod-validation-error/v4';
@@ -23,27 +23,21 @@ const workflowKeys = [
'tags',
'notes',
'exposedFields',
'output_fields',
'meta',
'id',
'form',
] satisfies (keyof WorkflowV4)[];
] satisfies (keyof WorkflowV3)[];
export const buildWorkflowFast = (nodesState: NodesState): WorkflowV4 => {
export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => {
const { nodes, edges, ...rest } = nodesState;
const clonedWorkflow = pick(rest, workflowKeys);
const newWorkflow: WorkflowV4 = {
const newWorkflow: WorkflowV3 = {
...clonedWorkflow,
nodes: [],
edges: [],
};
newWorkflow.meta = {
...newWorkflow.meta,
version: '4.0.0',
};
for (const node of nodes) {
if (isInvocationNode(node) && node.type) {
const { id, type, data, position } = node;
@@ -67,12 +61,12 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV4 => {
return deepClone(newWorkflow);
};
export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV4 | null => {
export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV3 | null => {
// builds what really, really should be a valid workflow
const workflowToValidate = buildWorkflowFast(nodesState);
// but bc we are storing this in the DB, let's be extra sure
const result = zWorkflowV4.safeParse(workflowToValidate);
const result = zWorkflowV3.safeParse(workflowToValidate);
if (!result.success) {
const { message } = fromZodError(result.error, {
@@ -86,7 +80,7 @@ export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV4
return result.data;
};
export const useBuildWorkflowFast = (): (() => WorkflowV4) => {
export const useBuildWorkflowFast = (): (() => WorkflowV3) => {
const store = useAppStore();
const buildWorkflow = useCallback(() => {
const nodesState = selectNodesSlice(store.getState());

View File

@@ -4,7 +4,7 @@ import { forEach } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { getDefaultForm } from 'features/nodes/types/workflow';
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
import type { NonNullableGraph } from 'services/api/types';
@@ -19,24 +19,23 @@ const log = logger('workflows');
* @param autoLayout Whether to auto-layout the nodes using `dagre`. If false, nodes will be simply stacked on top of one another with an offset.
* @returns The workflow.
*/
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV4 => {
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
const templates = $templates.get();
// Initialize the workflow
const workflow: WorkflowV4 = {
const workflow: WorkflowV3 = {
name: '',
author: '',
contact: '',
description: '',
meta: {
category: 'user',
version: '4.0.0',
version: '3.0.0',
},
notes: '',
tags: '',
version: '',
exposedFields: [],
output_fields: [],
edges: [],
nodes: [],
form: getDefaultForm(),

View File

@@ -10,8 +10,8 @@ import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1';
import type { StatelessFieldType } from 'features/nodes/types/v2/field';
import type { WorkflowV2 } from 'features/nodes/types/v2/workflow';
import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
import type { WorkflowOutputField, WorkflowV3, WorkflowV4 } from 'features/nodes/types/workflow';
import { zWorkflowV3, zWorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import { t } from 'i18next';
import { z } from 'zod';
@@ -76,72 +76,12 @@ const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
return zWorkflowV3.parse(workflowToMigrate);
};
const normalizeOutputField = (field: unknown): WorkflowOutputField | null => {
if (!field || typeof field !== 'object') {
return null;
}
const maybeField = field as Record<string, unknown>;
const nodeId =
typeof maybeField.nodeId === 'string'
? maybeField.nodeId
: typeof maybeField.node_id === 'string'
? maybeField.node_id
: null;
const fieldName =
typeof maybeField.fieldName === 'string'
? maybeField.fieldName
: typeof maybeField.field_name === 'string'
? maybeField.field_name
: null;
if (!nodeId || !fieldName) {
return null;
}
const userLabel =
typeof maybeField.userLabel === 'string'
? maybeField.userLabel
: typeof maybeField.userLabel === 'object'
? null
: typeof maybeField.user_label === 'string'
? maybeField.user_label
: null;
return {
kind: 'output',
nodeId,
fieldName,
userLabel,
node_id: nodeId,
field_name: fieldName,
user_label: userLabel,
} satisfies WorkflowOutputField;
};
const migrateV3toV4 = (workflowToMigrate: WorkflowV3 | Record<string, unknown>): WorkflowV4 => {
const rawOutputs = Array.isArray((workflowToMigrate as Record<string, unknown>).output_fields)
? ((workflowToMigrate as Record<string, unknown>).output_fields as unknown[])
: [];
const normalizedOutputs = rawOutputs
.map(normalizeOutputField)
.filter((field): field is WorkflowOutputField => field !== null)
.slice(0, 1);
const migrated = {
...(workflowToMigrate as Record<string, unknown>),
output_fields: normalizedOutputs,
} as WorkflowV4;
migrated.meta.version = '4.0.0';
return zWorkflowV4.parse(migrated);
};
/**
* Parses a workflow and migrates it to the latest version if necessary.
*
* This function will return a new workflow object, so the original workflow is not modified.
*/
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV4 => {
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
if (!workflowVersionResult.success) {
@@ -160,13 +100,8 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV4 => {
workflow = migrateV2toV3(v2);
}
if (get(workflow, 'meta.version') === '3.0.0') {
const v3 = zWorkflowV3.parse(workflow);
workflow = migrateV3toV4(v3);
}
// We should now have a V4 workflow
const migratedWorkflow = zWorkflowV4.parse(workflow);
// We should now have a V3 workflow
const migratedWorkflow = zWorkflowV3.parse(workflow);
return migratedWorkflow;
};

View File

@@ -1,13 +1,13 @@
import { get } from 'es-toolkit/compat';
import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { getDefaultForm } from 'features/nodes/types/workflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { describe, expect, it } from 'vitest';
//TODO(psyche): Test workflow validation for form builder fields
describe('validateWorkflow', () => {
const getWorkflow = (): WorkflowV4 => ({
const getWorkflow = (): WorkflowV3 => ({
name: '',
author: '',
description: '',
@@ -16,9 +16,8 @@ describe('validateWorkflow', () => {
tags: '',
notes: '',
exposedFields: [],
output_fields: [],
form: getDefaultForm(),
meta: { version: '4.0.0', category: 'user' },
meta: { version: '3.0.0', category: 'user' },
nodes: [
{
id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad',

View File

@@ -8,7 +8,7 @@ import {
isModelFieldType,
isModelIdentifierFieldInputInstance,
} from 'features/nodes/types/field';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import {
buildNodeFieldElement,
getDefaultForm,
@@ -36,7 +36,7 @@ type ValidateWorkflowArgs = {
};
type ValidateWorkflowResult = {
workflow: WorkflowV4;
workflow: WorkflowV3;
warnings: WorkflowWarning[];
};
@@ -233,65 +233,6 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
// Remove invalid edges
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
if (_workflow.output_fields.length > 0) {
const validOutputs: typeof _workflow.output_fields = [];
for (const output of _workflow.output_fields) {
const node = nodes.find(({ id }) => id === output.nodeId);
if (!node) {
warnings.push({
message: t('workflows.builder.outputFieldMissingNode', {
defaultValue: 'Removed workflow output referencing a missing node.',
}),
data: parseify(output),
});
continue;
}
if (!isWorkflowInvocationNode(node)) {
warnings.push({
message: t('workflows.builder.outputFieldInvalidNode', {
defaultValue: 'Removed workflow output because the selected node is not an invocation node.',
}),
data: parseify({ output, node }),
});
continue;
}
const template = templates[node.data.type];
if (!template) {
warnings.push({
message: t('workflows.builder.outputFieldMissingTemplate', {
defaultValue: 'Removed workflow output because the node template is missing.',
}),
data: parseify({ output, node }),
});
continue;
}
const field = template.outputs[output.fieldName];
if (!field) {
warnings.push({
message: t('workflows.builder.outputFieldMissingHandle', {
defaultValue: 'Removed workflow output because the node no longer exposes that field.',
}),
data: parseify({ output, node, template }),
});
continue;
}
if (!(field.type.name === 'ImageField' && field.type.cardinality !== 'COLLECTION')) {
warnings.push({
message: t('workflows.builder.outputFieldMustBeImage', {
defaultValue: 'Removed workflow output because it is not an image field.',
}),
data: parseify({ output, node, field }),
});
continue;
}
validOutputs.push(output);
if (validOutputs.length >= 1) {
break;
}
}
_workflow.output_fields = validOutputs.slice(0, 1);
}
// Migrated exposed fields to form elements if they exist and the form does not
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!
if (_workflow.exposedFields.length > 0 && getIsFormEmpty(_workflow.form)) {

View File

@@ -16,7 +16,6 @@ import { SelectObject } from 'features/controlLayers/components/SelectObject/Sel
import { StagingAreaContextProvider } from 'features/controlLayers/components/StagingArea/context';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { TriggerWorkflow } from 'features/controlLayers/components/TriggerWorkflow/TriggerWorkflow';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
@@ -113,7 +112,6 @@ export const CanvasWorkspacePanel = memo(() => {
<Flex position="absolute" bottom={4}>
<CanvasManagerProviderGate>
<Filter />
<TriggerWorkflow />
<Transform />
<SelectObject />
</CanvasManagerProviderGate>

View File

@@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useLoadWorkflowFromFile } from 'features/workflowLibrary/hooks/useLoadWorkflowFromFile';
import { useLoadWorkflowFromImage } from 'features/workflowLibrary/hooks/useLoadWorkflowFromImage';
import { useLoadWorkflowFromLibrary } from 'features/workflowLibrary/hooks/useLoadWorkflowFromLibrary';
@@ -13,7 +13,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type LoadWorkflowOptions = {
onSuccess?: (workflow: WorkflowV4) => void;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
};

View File

@@ -15,7 +15,7 @@ import { useStore } from '@nanostores/react';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { deepClone } from 'common/util/deepClone';
import { $workflowLibraryCategoriesOptions } from 'features/nodes/store/workflowLibrarySlice';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { isDraftWorkflow, useCreateLibraryWorkflow } from 'features/workflowLibrary/hooks/useCreateNewWorkflow';
import { t } from 'i18next';
import { atom, computed } from 'nanostores';
@@ -28,7 +28,7 @@ import { assert } from 'tsafe';
*
* This state is used to determine whether or not the modal is open.
*/
const $workflowToSave = atom<WorkflowV4 | null>(null);
const $workflowToSave = atom<WorkflowV3 | null>(null);
/**
* Whether or not the modal is open. It is open if there is a workflow to save.
@@ -40,7 +40,7 @@ const $workflowToSave = atom<WorkflowV4 | null>(null);
*/
const $isOpen = computed($workflowToSave, (val) => val !== null);
const getInitialName = (workflow: WorkflowV4): string => {
const getInitialName = (workflow: WorkflowV3): string => {
if (!workflow.id) {
// If the workflow has no ID, that means it's a new workflow that has never been saved to the server. In this case,
// we should use whatever the user has entered in the workflow name field.
@@ -60,7 +60,7 @@ const getInitialName = (workflow: WorkflowV4): string => {
* The workflow object is deep cloned to prevent any changes to the original workflow object.
* @param workflow The workflow to save as a new workflow.
*/
export const saveWorkflowAs = (workflow: WorkflowV4) => {
export const saveWorkflowAs = (workflow: WorkflowV3) => {
$workflowToSave.set(deepClone(workflow));
};
@@ -82,7 +82,7 @@ export const SaveWorkflowAsDialog = () => {
);
};
const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV4; cancelRef: RefObject<HTMLButtonElement> }) => {
const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef: RefObject<HTMLButtonElement> }) => {
const workflowCategories = useStore($workflowLibraryCategoriesOptions);
const [name, setName] = useState(() => {
if (workflow) {

View File

@@ -7,7 +7,7 @@ import {
workflowIDChanged,
workflowNameChanged,
} from 'features/nodes/store/nodesSlice';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useGetFormFieldInitialValues } from 'features/workflowLibrary/hooks/useGetFormInitialValues';
import { newWorkflowSaved } from 'features/workflowLibrary/store/actions';
import { useCallback, useRef } from 'react';
@@ -19,12 +19,12 @@ import type { SetFieldType } from 'type-fest';
* A draft workflow is a workflow that is has not been saved yet. It does not have an id and is not in the default category.
*/
type DraftWorkflow = SetFieldType<
SetFieldType<WorkflowV4, 'id', undefined>,
SetFieldType<WorkflowV3, 'id', undefined>,
'meta',
SetFieldType<WorkflowV4['meta'], 'category', Exclude<WorkflowV4['meta']['category'], 'default'>>
SetFieldType<WorkflowV3['meta'], 'category', Exclude<WorkflowV3['meta']['category'], 'default'>>
>;
export const isDraftWorkflow = (workflow: WorkflowV4): workflow is DraftWorkflow =>
export const isDraftWorkflow = (workflow: WorkflowV3): workflow is DraftWorkflow =>
!workflow.id && workflow.meta.category !== 'default';
type CreateLibraryWorkflowArg = {

View File

@@ -1,5 +1,5 @@
import { useAppDispatch } from 'app/store/storeHooks';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { workflowLoadedFromFile } from 'features/workflowLibrary/store/actions';
import { useCallback } from 'react';
@@ -17,12 +17,12 @@ export const useLoadWorkflowFromFile = () => {
(
file: File,
options: {
onSuccess?: (workflow: WorkflowV4) => void;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}
) => {
return new Promise<WorkflowV4 | void>((resolve, reject) => {
return new Promise<WorkflowV3 | void>((resolve, reject) => {
const reader = new FileReader();
reader.onload = async () => {
const rawJSON = reader.result;

View File

@@ -1,4 +1,4 @@
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { toast } from 'features/toast/toast';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
@@ -21,7 +21,7 @@ export const useLoadWorkflowFromImage = () => {
async (
imageName: string,
options: {
onSuccess?: (workflow: WorkflowV4) => void;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}

View File

@@ -1,5 +1,5 @@
import { useToast } from '@invoke-ai/ui-library';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -21,7 +21,7 @@ export const useLoadWorkflowFromLibrary = () => {
async (
workflowId: string,
options: {
onSuccess?: (workflow: WorkflowV4) => void;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}

View File

@@ -1,4 +1,4 @@
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
import { useCallback } from 'react';
@@ -14,7 +14,7 @@ export const useLoadWorkflowFromObject = () => {
async (
unvalidatedWorkflow: unknown,
options: {
onSuccess?: (workflow: WorkflowV4) => void;
onSuccess?: (workflow: WorkflowV3) => void;
onError?: () => void;
onCompleted?: () => void;
} = {}

View File

@@ -2,7 +2,7 @@ import type { ToastId } from '@invoke-ai/ui-library';
import { useToast } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { formFieldInitialValuesChanged } from 'features/nodes/store/nodesSlice';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { useGetFormFieldInitialValues } from 'features/workflowLibrary/hooks/useGetFormInitialValues';
import { workflowUpdated } from 'features/workflowLibrary/store/actions';
import { useCallback, useRef } from 'react';
@@ -14,12 +14,12 @@ import type { SetFieldType, SetRequired } from 'type-fest';
* A library workflow is a workflow that is already saved in the library. It has an id and is not in the default category.
*/
type LibraryWorkflow = SetFieldType<
SetRequired<WorkflowV4, 'id'>,
SetRequired<WorkflowV3, 'id'>,
'meta',
SetFieldType<WorkflowV4['meta'], 'category', Exclude<WorkflowV4['meta']['category'], 'default'>>
SetFieldType<WorkflowV3['meta'], 'category', Exclude<WorkflowV3['meta']['category'], 'default'>>
>;
export const isLibraryWorkflow = (workflow: WorkflowV4): workflow is LibraryWorkflow =>
export const isLibraryWorkflow = (workflow: WorkflowV3): workflow is LibraryWorkflow =>
!!workflow.id && workflow.meta.category !== 'default';
type UseSaveLibraryWorkflowReturn = {

View File

@@ -6,7 +6,7 @@ import { $templates, workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { WorkflowV4 } from 'features/nodes/types/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
@@ -48,7 +48,7 @@ export const useValidateAndLoadWorkflow = () => {
async (
unvalidatedWorkflow: unknown,
origin: 'file' | 'image' | 'object' | 'library'
): Promise<WorkflowV4 | null> => {
): Promise<WorkflowV3 | null> => {
try {
const templates = $templates.get();
const { workflow, warnings } = await validateWorkflow({

View File

@@ -1,18 +0,0 @@
import { atom } from 'nanostores';
import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
type WorkflowLibraryIntent =
| { mode: 'browse' }
| { mode: 'trigger-workflow'; onSelect: (workflow: WorkflowRecordListItemWithThumbnailDTO) => void };
export const $workflowLibraryIntent = atom<WorkflowLibraryIntent>({ mode: 'browse' });
export const setWorkflowLibraryBrowseIntent = () => {
$workflowLibraryIntent.set({ mode: 'browse' });
};
export const setWorkflowLibraryTriggerIntent = (
onSelect: (workflow: WorkflowRecordListItemWithThumbnailDTO) => void
) => {
$workflowLibraryIntent.set({ mode: 'trigger-workflow', onSelect });
};

View File

@@ -12,25 +12,34 @@ import {
isChatGPT4oModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isCogView4MainModelModelConfig,
isControlLayerModelConfig,
isControlLoRAModelConfig,
isControlNetModelConfig,
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxMainModelModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isGemini2_5ModelConfig,
isImagen3ModelConfig,
isImagen4ModelConfig,
isIPAdapterModelConfig,
isLLaVAModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isRunwayModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSigLipModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
isTIModelConfig,
isVAEModelConfig,
isVeo3ModelConfig,
isVideoModelConfig,
} from 'services/api/types';
@@ -57,7 +66,12 @@ const buildModelsHook =
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
export const useControlLayerModels = buildModelsHook(isControlLayerModelConfig);
@@ -89,6 +103,12 @@ export const useRegionalReferenceImageModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
);
export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);
export const useImagen3Models = buildModelsHook(isImagen3ModelConfig);
export const useImagen4Models = buildModelsHook(isImagen4ModelConfig);
export const useChatGPT4oModels = buildModelsHook(isChatGPT4oModelConfig);
export const useFluxKontextModels = buildModelsHook(isFluxKontextApiModelConfig);
export const useVeo3Models = buildModelsHook(isVeo3ModelConfig);
export const useRunwayModels = buildModelsHook(isRunwayModelConfig);
export const useVideoModels = buildModelsHook(isVideoModelConfig);
const buildModelsSelector =

View File

@@ -5580,7 +5580,7 @@ export type components = {
repo_variant?: components["schemas"]["ModelRepoVariant"] | null;
};
/**
* ControlNet - SD1.5, SD2, SDXL
* ControlNet - SD1.5, SDXL
* @description Collects ControlNet info to pass to other nodes
*/
ControlNetInvocation: {
@@ -11805,26 +11805,6 @@ export type components = {
ui_choice_labels: {
[key: string]: string;
} | null;
/**
* Ui Model Base
* @default null
*/
ui_model_base: components["schemas"]["BaseModelType"][] | null;
/**
* Ui Model Type
* @default null
*/
ui_model_type: components["schemas"]["ModelType"][] | null;
/**
* Ui Model Variant
* @default null
*/
ui_model_variant: (components["schemas"]["ClipVariantType"] | components["schemas"]["ModelVariantType"])[] | null;
/**
* Ui Model Format
* @default null
*/
ui_model_format: components["schemas"]["ModelFormat"][] | null;
};
/**
* InstallStatus
@@ -15183,7 +15163,7 @@ export type components = {
guidance?: number | null;
};
/**
* Main Model - SD1.5, SD2
* Main Model - SD1.5
* @description Loads a main model, outputting its submodels.
*/
MainModelLoaderInvocation: {
@@ -17739,18 +17719,11 @@ export type components = {
*/
OutputFieldJSONSchemaExtra: {
field_kind: components["schemas"]["FieldKind"];
/**
* Ui Hidden
* @default false
*/
/** Ui Hidden */
ui_hidden: boolean;
/**
* Ui Order
* @default null
*/
ui_order: number | null;
/** @default null */
ui_type: components["schemas"]["UIType"] | null;
/** Ui Order */
ui_order: number | null;
};
/** PaginatedResults[WorkflowRecordListItemWithThumbnailDTO] */
PaginatedResults_WorkflowRecordListItemWithThumbnailDTO_: {
@@ -21857,7 +21830,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict" | "DEPRECATED_MainModelField" | "DEPRECATED_CogView4MainModelField" | "DEPRECATED_FluxMainModelField" | "DEPRECATED_SD3MainModelField" | "DEPRECATED_SDXLMainModelField" | "DEPRECATED_SDXLRefinerModelField" | "DEPRECATED_ONNXModelField" | "DEPRECATED_VAEModelField" | "DEPRECATED_FluxVAEModelField" | "DEPRECATED_LoRAModelField" | "DEPRECATED_ControlNetModelField" | "DEPRECATED_IPAdapterModelField" | "DEPRECATED_T2IAdapterModelField" | "DEPRECATED_T5EncoderModelField" | "DEPRECATED_CLIPEmbedModelField" | "DEPRECATED_CLIPLEmbedModelField" | "DEPRECATED_CLIPGEmbedModelField" | "DEPRECATED_SpandrelImageToImageModelField" | "DEPRECATED_ControlLoRAModelField" | "DEPRECATED_SigLipModelField" | "DEPRECATED_FluxReduxModelField" | "DEPRECATED_LLaVAModelField" | "DEPRECATED_Imagen3ModelField" | "DEPRECATED_Imagen4ModelField" | "DEPRECATED_ChatGPT4oModelField" | "DEPRECATED_Gemini2_5ModelField" | "DEPRECATED_FluxKontextModelField" | "DEPRECATED_Veo3ModelField" | "DEPRECATED_RunwayModelField";
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "Imagen3ModelField" | "Imagen4ModelField" | "ChatGPT4oModelField" | "Gemini2_5ModelField" | "FluxKontextModelField" | "Veo3ModelField" | "RunwayModelField" | "SchedulerField" | "AnyField" | "VideoField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */
@@ -22203,7 +22176,7 @@ export type components = {
seamless_axes?: string[];
};
/**
* VAE Model - SD1.5, SD2, SDXL, SD3, FLUX
* VAE Model - SD1.5, SDXL, SD3, FLUX
* @description Loads a VAE model, outputting a VaeLoaderOutput
*/
VAELoaderInvocation: {
@@ -22572,11 +22545,6 @@ export type components = {
* @description The exposed fields of the workflow.
*/
exposedFields: components["schemas"]["ExposedField"][];
/**
* Output Fields
* @description The fields designated as output fields for the workflow.
*/
output_fields?: components["schemas"]["FieldIdentifier"][] | null;
/** @description The meta of the workflow. */
meta: components["schemas"]["WorkflowMeta"];
/**
@@ -22718,12 +22686,6 @@ export type components = {
* @description The tags of the workflow.
*/
tags: string;
/**
* Has Valid Image Output Field
* @description True when the workflow exposes exactly one output field and it is an image output.
* @default false
*/
has_valid_image_output_field?: boolean;
/**
* Thumbnail Url
* @description The URL of the workflow thumbnail.
@@ -22818,11 +22780,6 @@ export type components = {
* @description The exposed fields of the workflow.
*/
exposedFields: components["schemas"]["ExposedField"][];
/**
* Output Fields
* @description The fields designated as output fields for the workflow.
*/
output_fields?: components["schemas"]["FieldIdentifier"][] | null;
/** @description The meta of the workflow. */
meta: components["schemas"]["WorkflowMeta"];
/**

Some files were not shown because too many files have changed in this diff Show More