mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 10:48:07 -05:00
Compare commits
1 Commits
maryhipp/w
...
controlnet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21a05f4287 |
@@ -5,7 +5,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import (
|
||||
GlmEncoderField,
|
||||
ModelIdentifierField,
|
||||
@@ -14,7 +14,6 @@ from invokeai.app.invocations.model import (
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("cogview4_model_loader_output")
|
||||
@@ -39,9 +38,8 @@ class CogView4ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.cogview4_model,
|
||||
ui_type=UIType.CogView4MainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.CogView4,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:
|
||||
|
||||
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
@@ -27,7 +28,6 @@ from invokeai.app.util.controlnet_utils import (
|
||||
heuristic_resize_fast,
|
||||
)
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -63,17 +63,13 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation(
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
|
||||
)
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -7,13 +7,6 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
@@ -46,9 +39,42 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
CogView4MainModel = "CogView4MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
Imagen3Model = "Imagen3ModelField"
|
||||
Imagen4Model = "Imagen4ModelField"
|
||||
ChatGPT4oModel = "ChatGPT4oModelField"
|
||||
Gemini2_5Model = "Gemini2_5ModelField"
|
||||
FluxKontextModel = "FluxKontextModelField"
|
||||
Veo3Model = "Veo3ModelField"
|
||||
RunwayModel = "RunwayModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
Video = "VideoField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
@@ -98,38 +124,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
|
||||
# Deprecated Model Field Types - use ui_model_[base|type|variant|format] instead
|
||||
MainModel = "DEPRECATED_MainModelField"
|
||||
CogView4MainModel = "DEPRECATED_CogView4MainModelField"
|
||||
FluxMainModel = "DEPRECATED_FluxMainModelField"
|
||||
SD3MainModel = "DEPRECATED_SD3MainModelField"
|
||||
SDXLMainModel = "DEPRECATED_SDXLMainModelField"
|
||||
SDXLRefinerModel = "DEPRECATED_SDXLRefinerModelField"
|
||||
ONNXModel = "DEPRECATED_ONNXModelField"
|
||||
VAEModel = "DEPRECATED_VAEModelField"
|
||||
FluxVAEModel = "DEPRECATED_FluxVAEModelField"
|
||||
LoRAModel = "DEPRECATED_LoRAModelField"
|
||||
ControlNetModel = "DEPRECATED_ControlNetModelField"
|
||||
IPAdapterModel = "DEPRECATED_IPAdapterModelField"
|
||||
T2IAdapterModel = "DEPRECATED_T2IAdapterModelField"
|
||||
T5EncoderModel = "DEPRECATED_T5EncoderModelField"
|
||||
CLIPEmbedModel = "DEPRECATED_CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "DEPRECATED_CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "DEPRECATED_CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "DEPRECATED_SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "DEPRECATED_ControlLoRAModelField"
|
||||
SigLipModel = "DEPRECATED_SigLipModelField"
|
||||
FluxReduxModel = "DEPRECATED_FluxReduxModelField"
|
||||
LlavaOnevisionModel = "DEPRECATED_LLaVAModelField"
|
||||
Imagen3Model = "DEPRECATED_Imagen3ModelField"
|
||||
Imagen4Model = "DEPRECATED_Imagen4ModelField"
|
||||
ChatGPT4oModel = "DEPRECATED_ChatGPT4oModelField"
|
||||
Gemini2_5Model = "DEPRECATED_Gemini2_5ModelField"
|
||||
FluxKontextModel = "DEPRECATED_FluxKontextModelField"
|
||||
Veo3Model = "DEPRECATED_Veo3ModelField"
|
||||
RunwayModel = "DEPRECATED_RunwayModelField"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
@@ -415,10 +409,6 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
ui_model_base: Optional[list[BaseModelType]] = None
|
||||
ui_model_type: Optional[list[ModelType]] = None
|
||||
ui_model_variant: Optional[list[ClipVariantType | ModelVariantType]] = None
|
||||
ui_model_format: Optional[list[ModelFormat]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
@@ -475,9 +465,9 @@ class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool = False
|
||||
ui_order: Optional[int] = None
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
@@ -511,63 +501,35 @@ def InputField(
|
||||
ui_hidden: Optional[bool] = None,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
ui_model_base: Optional[BaseModelType | list[BaseModelType]] = None,
|
||||
ui_model_type: Optional[ModelType | list[ModelType]] = None,
|
||||
ui_model_variant: Optional[ClipVariantType | ModelVariantType | list[ClipVariantType | ModelVariantType]] = None,
|
||||
ui_model_format: Optional[ModelFormat | list[ModelFormat]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field)
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
If the field is a `ModelIdentifierField`, use the `ui_model_[base|type|variant|format]` args to filter the model list
|
||||
in the Workflow Editor. Otherwise, use `ui_type` to provide extra type hints for the UI.
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
Don't use both `ui_type` and `ui_model_[base|type|variant|format]` - if both are provided, a warning will be
|
||||
logged and `ui_type` will be ignored.
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
Args:
|
||||
input: The kind of input this field requires.
|
||||
- `Input.Direct` means a value must be provided on instantiation.
|
||||
- `Input.Connection` means the value must be provided by a connection.
|
||||
- `Input.Any` means either will do.
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
ui_component: Optionally specifies a specific component to use in the UI. The UI will always render a suitable
|
||||
component, but sometimes you want something different than the default. For example, a `string` field will
|
||||
default to a single-line input, but you may want a multi-line textarea instead. In this case, you could use
|
||||
`UIComponent.Textarea`.
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
|
||||
ui_model_base: Specifies the base model architectures to filter the model list by in the Workflow Editor. For
|
||||
example, `ui_model_base=BaseModelType.StableDiffusionXL` will show only SDXL architecture models. This arg is
|
||||
only valid if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_type: Specifies the model type(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_type=ModelType.VAE` will show only VAE models. This arg is only valid if this Input field is
|
||||
annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_variant: Specifies the model variant(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_variant=ModelVariantType.Inpainting` will show only inpainting models. This arg is only valid if this
|
||||
Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_model_format: Specifies the model format(s) to filter the model list by in the Workflow Editor. For example,
|
||||
`ui_model_format=ModelFormat.Diffusers` will show only models in the diffusers format. This arg is only valid
|
||||
if this Input field is annotated as a `ModelIdentifierField`.
|
||||
|
||||
ui_choice_labels: Specifies the labels to use for the choices in an enum field. If omitted, the enum values
|
||||
will be used. This arg is only valid if the field is annotated with as a `Literal`. For example,
|
||||
`Literal["choice1", "choice2", "choice3"]` with `ui_choice_labels={"choice1": "Choice 1", "choice2": "Choice 2",
|
||||
"choice3": "Choice 3"}` will render a dropdown with the labels "Choice 1", "Choice 2" and "Choice 3".
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
@@ -576,92 +538,7 @@ def InputField(
|
||||
)
|
||||
|
||||
if ui_type is not None:
|
||||
if (
|
||||
ui_model_base is not None
|
||||
or ui_model_type is not None
|
||||
or ui_model_variant is not None
|
||||
or ui_model_format is not None
|
||||
):
|
||||
logger.warning("InputField: Use either ui_type or ui_model_[base|type|variant|format]. Ignoring ui_type.")
|
||||
# Map old-style UIType to new-style ui_model_[base|type|variant|format]
|
||||
elif ui_type is UIType.MainModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.CogView4MainModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.CogView4]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.FluxMainModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.SD3MainModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusion3]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.SDXLMainModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXL]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.SDXLRefinerModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
# Think this UIType is unused...?
|
||||
# elif ui_type is UIType.ONNXModel:
|
||||
# json_schema_extra_.ui_model_base =
|
||||
# json_schema_extra_.ui_model_type =
|
||||
elif ui_type is UIType.VAEModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.VAE]
|
||||
elif ui_type is UIType.FluxVAEModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
|
||||
json_schema_extra_.ui_model_type = [ModelType.VAE]
|
||||
elif ui_type is UIType.LoRAModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.LoRA]
|
||||
elif ui_type is UIType.ControlNetModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.ControlNet]
|
||||
elif ui_type is UIType.IPAdapterModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.IPAdapter]
|
||||
elif ui_type is UIType.T2IAdapterModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.T2IAdapter]
|
||||
elif ui_type is UIType.T5EncoderModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.T5Encoder]
|
||||
elif ui_type is UIType.CLIPEmbedModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
|
||||
elif ui_type is UIType.CLIPLEmbedModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
|
||||
json_schema_extra_.ui_model_variant = [ClipVariantType.L]
|
||||
elif ui_type is UIType.CLIPGEmbedModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
|
||||
json_schema_extra_.ui_model_variant = [ClipVariantType.G]
|
||||
elif ui_type is UIType.SpandrelImageToImageModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.SpandrelImageToImage]
|
||||
elif ui_type is UIType.ControlLoRAModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.ControlLoRa]
|
||||
elif ui_type is UIType.SigLipModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.SigLIP]
|
||||
elif ui_type is UIType.FluxReduxModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.FluxRedux]
|
||||
elif ui_type is UIType.LlavaOnevisionModel:
|
||||
json_schema_extra_.ui_model_type = [ModelType.LlavaOnevision]
|
||||
elif ui_type is UIType.Imagen3Model:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Imagen3]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.Imagen4Model:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Imagen4]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.ChatGPT4oModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.ChatGPT4o]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.Gemini2_5Model:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Gemini2_5]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.FluxKontextModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.FluxKontext]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Main]
|
||||
elif ui_type is UIType.Veo3Model:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Veo3]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Video]
|
||||
elif ui_type is UIType.RunwayModel:
|
||||
json_schema_extra_.ui_model_base = [BaseModelType.Runway]
|
||||
json_schema_extra_.ui_model_type = [ModelType.Video]
|
||||
else:
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
if ui_component is not None:
|
||||
json_schema_extra_.ui_component = ui_component
|
||||
if ui_hidden is not None:
|
||||
@@ -670,26 +547,6 @@ def InputField(
|
||||
json_schema_extra_.ui_order = ui_order
|
||||
if ui_choice_labels is not None:
|
||||
json_schema_extra_.ui_choice_labels = ui_choice_labels
|
||||
if ui_model_base is not None:
|
||||
if isinstance(ui_model_base, list):
|
||||
json_schema_extra_.ui_model_base = ui_model_base
|
||||
else:
|
||||
json_schema_extra_.ui_model_base = [ui_model_base]
|
||||
if ui_model_type is not None:
|
||||
if isinstance(ui_model_type, list):
|
||||
json_schema_extra_.ui_model_type = ui_model_type
|
||||
else:
|
||||
json_schema_extra_.ui_model_type = [ui_model_type]
|
||||
if ui_model_variant is not None:
|
||||
if isinstance(ui_model_variant, list):
|
||||
json_schema_extra_.ui_model_variant = ui_model_variant
|
||||
else:
|
||||
json_schema_extra_.ui_model_variant = [ui_model_variant]
|
||||
if ui_model_format is not None:
|
||||
if isinstance(ui_model_format, list):
|
||||
json_schema_extra_.ui_model_format = ui_model_format
|
||||
else:
|
||||
json_schema_extra_.ui_model_format = [ui_model_format]
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
@@ -791,20 +648,20 @@ def OutputField(
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
Args:
|
||||
ui_type: Optionally provides an extra type hint for the UI. In some situations, the field's type is not enough
|
||||
to infer the correct UI type. For example, Scheduler fields are enums, but we want to render a special scheduler
|
||||
dropdown in the UI. Use `UIType.Scheduler` to indicate this.
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
ui_hidden: Specifies whether or not this field should be hidden in the UI.
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
ui_order: Specifies the order in which this field should be rendered in the UI. If omitted, the field will be
|
||||
rendered after all fields with an explicit order, in the order they are defined in the Invocation class.
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
@@ -822,9 +679,9 @@ def OutputField(
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_type=ui_type,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
@@ -4,10 +4,9 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("flux_control_lora_loader_output")
|
||||
@@ -30,10 +29,7 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.control_lora_model,
|
||||
title="Control LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlLoRa,
|
||||
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
|
||||
)
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
|
||||
|
||||
@@ -6,12 +6,11 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class FluxControlNetField(BaseModel):
|
||||
@@ -58,9 +57,7 @@ class FluxControlNetInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.ControlNet,
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float | list[float] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
from invokeai.app.invocations.fields import InputField, UIType
|
||||
from invokeai.app.invocations.ip_adapter import (
|
||||
CLIP_VISION_MODEL_MAP,
|
||||
IPAdapterField,
|
||||
@@ -20,7 +20,6 @@ from invokeai.backend.model_manager.config import (
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -37,10 +36,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.IPAdapter,
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", ui_type=UIType.IPAdapterModel
|
||||
)
|
||||
# Currently, the only known ViT model used by FLUX IP-Adapters is ViT-L.
|
||||
clip_vision_model: Literal["ViT-L"] = InputField(description="CLIP Vision model to use.", default="ViT-L")
|
||||
|
||||
@@ -6,10 +6,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
@@ -36,10 +36,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField | None = InputField(
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
@@ -17,7 +17,7 @@ from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
@@ -46,30 +46,23 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.VAE,
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
|
||||
@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
@@ -63,8 +64,7 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
redux_model: ModelIdentifierField = InputField(
|
||||
description="The FLUX Redux model to use.",
|
||||
title="FLUX Redux Model",
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.FluxRedux,
|
||||
ui_type=UIType.FluxReduxModel,
|
||||
)
|
||||
downsampling_factor: int = InputField(
|
||||
ge=1,
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -85,8 +85,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
ui_order=-1,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.IPAdapter,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G", "ViT-L"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
|
||||
@@ -6,12 +6,11 @@ from pydantic import field_validator
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -35,7 +34,7 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
vllm_model: ModelIdentifierField = InputField(
|
||||
title="LLaVA Model Type",
|
||||
description=FieldDescriptions.vllm_model,
|
||||
ui_model_type=ModelType.LlavaOnevision,
|
||||
ui_type=UIType.LlavaOnevisionModel,
|
||||
)
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
|
||||
@@ -53,7 +53,7 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.version import __version__
|
||||
|
||||
@@ -473,6 +473,7 @@ class MetadataToModelOutput(BaseInvocationOutput):
|
||||
model: ModelIdentifierField = OutputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
title="Model",
|
||||
ui_type=UIType.MainModel,
|
||||
)
|
||||
name: str = OutputField(description="Model Name", title="Name")
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
@@ -487,6 +488,7 @@ class MetadataToSDXLModelOutput(BaseInvocationOutput):
|
||||
model: ModelIdentifierField = OutputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
title="Model",
|
||||
ui_type=UIType.SDXLMainModel,
|
||||
)
|
||||
name: str = OutputField(description="Model Name", title="Name")
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
@@ -517,7 +519,8 @@ class MetadataToModelInvocation(BaseInvocation, WithMetadata):
|
||||
input=Input.Direct,
|
||||
)
|
||||
default_value: ModelIdentifierField = InputField(
|
||||
description="The default model to use if not found in the metadata", ui_model_type=ModelType.Main
|
||||
description="The default model to use if not found in the metadata",
|
||||
ui_type=UIType.MainModel,
|
||||
)
|
||||
|
||||
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
|
||||
@@ -572,8 +575,7 @@ class MetadataToSDXLModelInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
default_value: ModelIdentifierField = InputField(
|
||||
description="The default SDXL Model to use if not found in the metadata",
|
||||
ui_model_type=ModelType.Main,
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
ui_type=UIType.SDXLMainModel,
|
||||
)
|
||||
|
||||
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -145,7 +145,7 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model - SD1.5, SD2",
|
||||
title="Main Model - SD1.5",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
@@ -153,11 +153,7 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
@@ -191,10 +187,7 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.StableDiffusion1,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -257,9 +250,7 @@ class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_type=ModelType.LoRA,
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
@@ -341,10 +332,7 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -485,26 +473,13 @@ class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"vae_loader",
|
||||
title="VAE Model - SD1.5, SD2, SDXL, SD3, FLUX",
|
||||
tags=["vae", "model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
|
||||
)
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
ui_model_base=[
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.StableDiffusion3,
|
||||
BaseModelType.Flux,
|
||||
],
|
||||
ui_model_type=ModelType.VAE,
|
||||
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
|
||||
@@ -6,14 +6,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ClipVariantType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
@@ -39,43 +39,36 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sd3_model,
|
||||
ui_type=UIType.SD3MainModel,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.StableDiffusion3,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
clip_l_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPLEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP L Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
ui_model_variant=ClipVariantType.L,
|
||||
)
|
||||
|
||||
clip_g_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_g_model,
|
||||
ui_type=UIType.CLIPGEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP G Encoder",
|
||||
default=None,
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
ui_model_variant=ClipVariantType.G,
|
||||
)
|
||||
|
||||
vae_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
title="VAE",
|
||||
default=None,
|
||||
ui_model_base=BaseModelType.StableDiffusion3,
|
||||
ui_model_type=ModelType.VAE,
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@@ -29,9 +29,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model,
|
||||
ui_model_base=BaseModelType.StableDiffusionXL,
|
||||
ui_model_type=ModelType.Main,
|
||||
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@@ -69,9 +67,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
ui_model_base=BaseModelType.StableDiffusionXLRefiner,
|
||||
ui_model_type=ModelType.Main,
|
||||
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
@@ -18,7 +19,6 @@ from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
@@ -33,7 +33,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_to_image_model: ModelIdentifierField = InputField(
|
||||
title="Image-to-Image Model",
|
||||
description=FieldDescriptions.spandrel_image_to_image_model,
|
||||
ui_model_type=ModelType.SpandrelImageToImage,
|
||||
ui_type=UIType.SpandrelImageToImageModel,
|
||||
)
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
|
||||
@@ -8,12 +8,11 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
@@ -61,8 +60,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
ui_order=-1,
|
||||
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusionXL],
|
||||
ui_model_type=ModelType.T2IAdapter,
|
||||
ui_type=UIType.T2IAdapterModel,
|
||||
)
|
||||
weight: Union[float, list[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
||||
|
||||
@@ -17,7 +17,6 @@ from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.field_identifier import FieldIdentifier
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
@@ -210,6 +209,13 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
return None
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
user_label: str | None = Field(description="The user label of the field, if any")
|
||||
|
||||
|
||||
class SessionQueueItem(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
user_label: str | None = Field(description="The user label of the field, if any")
|
||||
|
||||
|
||||
__all__ = ["FieldIdentifier"]
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Optional, Union
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
|
||||
|
||||
from invokeai.app.services.shared.field_identifier import FieldIdentifier
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
__workflow_meta_version__ = semver.Version.parse("1.0.0")
|
||||
@@ -60,10 +59,6 @@ class WorkflowWithoutID(BaseModel):
|
||||
tags: str = Field(description="The tags of the workflow.")
|
||||
notes: str = Field(description="The notes of the workflow.")
|
||||
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
|
||||
output_fields: list[FieldIdentifier] | None = Field(
|
||||
default=None,
|
||||
description="The fields designated as output fields for the workflow.",
|
||||
)
|
||||
meta: WorkflowMeta = Field(description="The meta of the workflow.")
|
||||
# TODO(psyche): nodes, edges and form are very loosely typed - they are strictly modeled and checked on the frontend.
|
||||
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
|
||||
@@ -126,29 +121,6 @@ class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
|
||||
description: str = Field(description="The description of the workflow.")
|
||||
category: WorkflowCategory = Field(description="The description of the workflow.")
|
||||
tags: str = Field(description="The tags of the workflow.")
|
||||
has_valid_image_output_field: bool = Field(
|
||||
default=False,
|
||||
description="True when the workflow exposes exactly one output field and it is an image output.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordListItemDTO":
|
||||
workflow_data = data.pop("workflow", None)
|
||||
has_valid_output = False
|
||||
|
||||
if workflow_data:
|
||||
if isinstance(workflow_data, (str, bytes, bytearray)):
|
||||
workflow = WorkflowValidator.validate_json(workflow_data)
|
||||
else:
|
||||
workflow = WorkflowValidator.validate_python(workflow_data)
|
||||
output_fields = workflow.output_fields or []
|
||||
image_outputs = [f for f in output_fields if f.kind == "output" and f.field_name.startswith("image")]
|
||||
if len(image_outputs) == 1:
|
||||
has_valid_output = True
|
||||
|
||||
data = dict(data)
|
||||
data["has_valid_image_output_field"] = has_valid_output
|
||||
return WorkflowRecordListItemDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
|
||||
|
||||
@@ -12,6 +12,7 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemDTOValidator,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowValidator,
|
||||
WorkflowWithoutID,
|
||||
@@ -122,8 +123,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags,
|
||||
workflow
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
@@ -204,7 +204,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTO.from_dict(dict(row)) for row in rows]
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
@@ -207,24 +207,15 @@ class IPAdapterPlusXL(IPAdapterPlus):
|
||||
|
||||
|
||||
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
|
||||
state_dict: IPAdapterStateDict = {
|
||||
"ip_adapter": {},
|
||||
"image_proj": {},
|
||||
"adapter_modules": {}, # added for noobai-mark-ipa
|
||||
"image_proj_model": {}, # added for noobai-mark-ipa
|
||||
}
|
||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||
|
||||
if ip_adapter_ckpt_path.suffix == ".safetensors":
|
||||
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
|
||||
for key in model.keys():
|
||||
if key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
|
||||
elif key.startswith("image_proj_model."):
|
||||
state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = model[key]
|
||||
elif key.startswith("image_proj."):
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
|
||||
elif key.startswith("adapter_modules."):
|
||||
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = model[key]
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
|
||||
else:
|
||||
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
|
||||
else:
|
||||
|
||||
@@ -5,7 +5,11 @@ import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.controlnets.controlnet import (
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
from diffusers.models.embeddings import (
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
@@ -775,7 +779,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
|
||||
diffusers.ControlNetModel = ControlNetModel
|
||||
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
||||
# Patch both the new and legacy module paths for compatibility
|
||||
try:
|
||||
diffusers.models.controlnets.controlnet.ControlNetModel = ControlNetModel
|
||||
except Exception:
|
||||
# Fallback for environments still exposing the legacy path
|
||||
try:
|
||||
diffusers.models.controlnet.ControlNetModel = ControlNetModel
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# patch LoRACompatibleConv to use original Conv2D forward function
|
||||
|
||||
@@ -1952,7 +1952,6 @@
|
||||
"private": "Private",
|
||||
"shared": "Shared",
|
||||
"published": "Published",
|
||||
"validImageOutput": "Valid Image Output",
|
||||
"browseWorkflows": "Browse Workflows",
|
||||
"deselectAll": "Deselect All",
|
||||
"recommended": "Recommended For You",
|
||||
@@ -2463,18 +2462,6 @@
|
||||
"cr": "Cr (YCbCr)"
|
||||
}
|
||||
},
|
||||
"triggerWorkflow": {
|
||||
"menuItem": "Trigger Workflow",
|
||||
"heading": "Trigger Workflow",
|
||||
"openLibrary": "Select Workflow",
|
||||
"noWorkflowSelected": "No workflow selected",
|
||||
"loading": "Loading workflow...",
|
||||
"apply": "Apply",
|
||||
"applying": "Applying",
|
||||
"cancel": "Cancel",
|
||||
"enqueued": "Workflow enqueued",
|
||||
"enqueueFailed": "Unable to enqueue workflow"
|
||||
},
|
||||
"transform": {
|
||||
"transform": "Transform",
|
||||
"fitToBbox": "Fit to Bbox",
|
||||
|
||||
@@ -131,8 +131,7 @@
|
||||
"notInstalled": "Non $t(common.installed)",
|
||||
"prevPage": "Pagina precedente",
|
||||
"nextPage": "Pagina successiva",
|
||||
"resetToDefaults": "Ripristina impostazioni predefinite",
|
||||
"crop": "Ritaglia"
|
||||
"resetToDefaults": "Ripristina impostazioni predefinite"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -279,14 +278,6 @@
|
||||
"selectVideoTab": {
|
||||
"title": "Seleziona la scheda Video",
|
||||
"desc": "Seleziona la scheda Video."
|
||||
},
|
||||
"promptHistoryPrev": {
|
||||
"title": "Prompt precedente nella cronologia",
|
||||
"desc": "Quando il prompt è attivo, passa al prompt precedente (più vecchio) nella cronologia."
|
||||
},
|
||||
"promptHistoryNext": {
|
||||
"title": "Prossimo prompt nella cronologia",
|
||||
"desc": "Quando il prompt è attivo, passa al prompt successivo (più recente) nella cronologia."
|
||||
}
|
||||
},
|
||||
"hotkeys": "Tasti di scelta rapida",
|
||||
@@ -884,8 +875,7 @@
|
||||
"video": "Video",
|
||||
"resolution": "Risoluzione",
|
||||
"downloadImage": "Scarica l'immagine",
|
||||
"showOptionsPanel": "Mostra pannello laterale (O o T)",
|
||||
"startingFrameImageAspectRatioWarning": "Le proporzioni dell'immagine non corrispondono alle proporzioni del video ({{videoAspectRatio}}). Ciò potrebbe causare ritagli imprevisti durante la generazione del video."
|
||||
"showOptionsPanel": "Mostra pannello laterale (O o T)"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@@ -2105,10 +2095,7 @@
|
||||
"generateFromImage": "Genera prompt dall'immagine",
|
||||
"resultTitle": "Espansione del prompt completata",
|
||||
"resultSubtitle": "Scegli come gestire il prompt espanso:",
|
||||
"insert": "Inserisci",
|
||||
"noPromptHistory": "Nessuna cronologia di prompt registrata.",
|
||||
"noMatchingPrompts": "Nessun prompt corrispondente nella cronologia.",
|
||||
"toSwitchBetweenPrompts": "per passare da un prompt all'altro."
|
||||
"insert": "Inserisci"
|
||||
},
|
||||
"controlLayers": {
|
||||
"addLayer": "Aggiungi Livello",
|
||||
@@ -2804,8 +2791,7 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"items": [
|
||||
"Seleziona oggetto v2: selezione degli oggetti migliorata con input di punti e riquadri o prompt di testo.",
|
||||
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello.",
|
||||
"Cronologia prompt: rivedi e richiama rapidamente i tuoi ultimi 100 prompt."
|
||||
"Regolazioni del livello raster: regola facilmente la luminosità, il contrasto, la saturazione, le curve e altro ancora del livello."
|
||||
],
|
||||
"watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente"
|
||||
},
|
||||
|
||||
@@ -9,7 +9,6 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component
|
||||
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
|
||||
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
|
||||
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
|
||||
import { CanvasEntityMenuItemsTriggerWorkflow } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTriggerWorkflow';
|
||||
import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments';
|
||||
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
|
||||
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
|
||||
@@ -25,7 +24,6 @@ export const RasterLayerMenuItems = memo(() => {
|
||||
</IconMenuItemGroup>
|
||||
<CanvasEntityMenuItemsTransform />
|
||||
<CanvasEntityMenuItemsFilter />
|
||||
<CanvasEntityMenuItemsTriggerWorkflow />
|
||||
<CanvasEntityMenuItemsSelectObject />
|
||||
<RasterLayerMenuItemsAdjustments />
|
||||
<MenuDivider />
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useWorkflowTriggerApply } from 'features/controlLayers/hooks/useWorkflowTriggerApply';
|
||||
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
|
||||
import { parseAndMigrateWorkflow } from 'features/nodes/util/workflow/migrations';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import {
|
||||
setWorkflowLibraryBrowseIntent,
|
||||
setWorkflowLibraryTriggerIntent,
|
||||
} from 'features/workflowLibrary/store/workflowLibraryIntent';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBooksDuotone, PiPlayBold, PiXBold } from 'react-icons/pi';
|
||||
import { useLazyGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
const TriggerWorkflowContent = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const canvasManager = useCanvasManager();
|
||||
const workflowLibraryModal = useWorkflowLibraryModal();
|
||||
const [fetchWorkflow, { isFetching: isFetchingWorkflow }] = useLazyGetWorkflowQuery();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const { apply, isApplying, selectedWorkflow, selectedWorkflowName } = useWorkflowTriggerApply();
|
||||
|
||||
const onSelectWorkflow = useCallback(
|
||||
(workflow: S['WorkflowRecordListItemWithThumbnailDTO']) => {
|
||||
const handleSelection = async () => {
|
||||
try {
|
||||
const res = await fetchWorkflow(workflow.workflow_id).unwrap();
|
||||
const migratedWorkflow = parseAndMigrateWorkflow(res.workflow);
|
||||
canvasManager.stateApi.setWorkflowTriggerSelection({
|
||||
workflow: migratedWorkflow,
|
||||
workflowId: workflow.workflow_id,
|
||||
workflowName: migratedWorkflow.name ?? workflow.name,
|
||||
});
|
||||
} catch {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('workflows.problemRetrievingWorkflow'),
|
||||
});
|
||||
} finally {
|
||||
setWorkflowLibraryBrowseIntent();
|
||||
}
|
||||
};
|
||||
|
||||
void handleSelection();
|
||||
},
|
||||
[canvasManager.stateApi, fetchWorkflow, t]
|
||||
);
|
||||
|
||||
const openWorkflowLibrary = useCallback(() => {
|
||||
setWorkflowLibraryTriggerIntent((workflow) => {
|
||||
onSelectWorkflow(workflow);
|
||||
workflowLibraryModal.close();
|
||||
});
|
||||
workflowLibraryModal.open();
|
||||
}, [onSelectWorkflow, workflowLibraryModal]);
|
||||
|
||||
const isApplyDisabled = useMemo(() => {
|
||||
return !selectedWorkflow || isApplying || isFetchingWorkflow || isBusy;
|
||||
}, [selectedWorkflow, isApplying, isFetchingWorkflow, isBusy]);
|
||||
|
||||
const cancel = useCallback(() => {
|
||||
setWorkflowLibraryBrowseIntent();
|
||||
canvasManager.stateApi.cancelWorkflowTrigger();
|
||||
}, [canvasManager.stateApi]);
|
||||
|
||||
return (
|
||||
<Flex bg="base.800" borderRadius="base" p={4} flexDir="column" gap={4} w={420} shadow="dark-lg">
|
||||
<Flex w="full" gap={4} alignItems="center">
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.triggerWorkflow.heading')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
</Flex>
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Text color="base.400">{selectedWorkflowName ?? t('controlLayers.triggerWorkflow.noWorkflowSelected')}</Text>
|
||||
{isFetchingWorkflow && (
|
||||
<Flex alignItems="center" gap={2} color="base.500">
|
||||
<Spinner size="sm" />
|
||||
<Text>{t('controlLayers.triggerWorkflow.loading')}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button
|
||||
variant="ghost"
|
||||
leftIcon={<PiBooksDuotone />}
|
||||
onClick={openWorkflowLibrary}
|
||||
isDisabled={isApplying || isFetchingWorkflow || isBusy}
|
||||
>
|
||||
{t('controlLayers.triggerWorkflow.openLibrary')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button variant="ghost" leftIcon={<PiPlayBold />} onClick={apply} isDisabled={isApplyDisabled}>
|
||||
{isApplying ? t('controlLayers.triggerWorkflow.applying') : t('controlLayers.triggerWorkflow.apply')}
|
||||
{isApplying && <Spinner size="sm" ml={2} />}
|
||||
</Button>
|
||||
<Button variant="ghost" leftIcon={<PiXBold />} onClick={cancel} isDisabled={isApplying}>
|
||||
{t('controlLayers.triggerWorkflow.cancel')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TriggerWorkflowContent.displayName = 'TriggerWorkflowContent';
|
||||
|
||||
export const TriggerWorkflow = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const state = useStore(canvasManager.stateApi.$workflowTrigger);
|
||||
if (!state) {
|
||||
return null;
|
||||
}
|
||||
return <TriggerWorkflowContent />;
|
||||
};
|
||||
|
||||
TriggerWorkflow.displayName = 'TriggerWorkflow';
|
||||
@@ -1,20 +0,0 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityWorkflowTrigger } from 'features/controlLayers/hooks/useEntityWorkflowTrigger';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlayCircleBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityMenuItemsTriggerWorkflow = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const trigger = useEntityWorkflowTrigger(entityIdentifier);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={trigger.start} icon={<PiPlayCircleBold />} isDisabled={trigger.isDisabled}>
|
||||
{t('controlLayers.triggerWorkflow.menuItem')}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityMenuItemsTriggerWorkflow.displayName = 'CanvasEntityMenuItemsTriggerWorkflow';
|
||||
@@ -1,62 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
|
||||
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
|
||||
import { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { isRasterLayerEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
export const useEntityWorkflowTrigger = (entityIdentifier: CanvasEntityIdentifier | null) => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const adapter = useEntityAdapterSafe(entityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const isLocked = useEntityIsLocked(entityIdentifier);
|
||||
const isEmpty = useEntityIsEmpty(entityIdentifier);
|
||||
const workflowTrigger = useStore(canvasManager.stateApi.$workflowTrigger);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
if (!entityIdentifier) {
|
||||
return true;
|
||||
}
|
||||
if (!isRasterLayerEntityIdentifier(entityIdentifier)) {
|
||||
return true;
|
||||
}
|
||||
if (!adapter) {
|
||||
return true;
|
||||
}
|
||||
if (!(adapter instanceof CanvasEntityAdapterRasterLayer)) {
|
||||
return true;
|
||||
}
|
||||
if (isBusy) {
|
||||
return true;
|
||||
}
|
||||
if (isLocked) {
|
||||
return true;
|
||||
}
|
||||
if (isEmpty) {
|
||||
return true;
|
||||
}
|
||||
if (workflowTrigger && workflowTrigger.adapter !== adapter) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}, [adapter, entityIdentifier, isBusy, isEmpty, isLocked, workflowTrigger]);
|
||||
|
||||
const start = useCallback(() => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
if (!entityIdentifier || !isRasterLayerEntityIdentifier(entityIdentifier) || !adapter) {
|
||||
return;
|
||||
}
|
||||
if (!(adapter instanceof CanvasEntityAdapterRasterLayer)) {
|
||||
return;
|
||||
}
|
||||
canvasManager.stateApi.startWorkflowTrigger(adapter);
|
||||
}, [adapter, canvasManager.stateApi, entityIdentifier, isDisabled]);
|
||||
|
||||
return { isDisabled, start } as const;
|
||||
};
|
||||
@@ -1,142 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { buildInvocationGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { setWorkflowLibraryBrowseIntent } from 'features/workflowLibrary/store/workflowLibraryIntent';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { EnqueueBatchArg } from 'services/api/types';
|
||||
|
||||
export const useWorkflowTriggerApply = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const workflowState = useStore(canvasManager.stateApi.$workflowTrigger);
|
||||
const selection = workflowState?.selection ?? null;
|
||||
const selectedWorkflow = selection?.workflow ?? null;
|
||||
const selectedWorkflowName = selection?.workflowName ?? null;
|
||||
const canvasSessionId = useAppSelector(selectCanvasSessionId);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [isApplying, setIsApplying] = useState(false);
|
||||
const templates = useStore($templates);
|
||||
|
||||
const apply = useCallback(async () => {
|
||||
if (!selection || !selectedWorkflow || !canvasSessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
setIsApplying(true);
|
||||
try {
|
||||
const workflow = deepClone(selectedWorkflow);
|
||||
const outputFields = workflow.output_fields ?? [];
|
||||
if (outputFields.length === 0) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('controlLayers.triggerWorkflow.noWorkflowSelected'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const outputField = outputFields[0]!;
|
||||
const originalNodeId = outputField.nodeId;
|
||||
const fieldName = outputField.fieldName ?? outputField.field_name;
|
||||
if (!originalNodeId || !fieldName) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('controlLayers.triggerWorkflow.noWorkflowSelected'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
const outputNodeId = `${CANVAS_OUTPUT_PREFIX}:${selection.workflowId}`;
|
||||
|
||||
workflow.output_fields = [
|
||||
{
|
||||
...outputField,
|
||||
kind: 'output',
|
||||
nodeId: outputNodeId,
|
||||
node_id: outputNodeId,
|
||||
fieldName,
|
||||
field_name: fieldName,
|
||||
userLabel: outputField.userLabel ?? outputField.user_label ?? null,
|
||||
user_label: outputField.userLabel ?? outputField.user_label ?? null,
|
||||
},
|
||||
];
|
||||
|
||||
workflow.nodes = workflow.nodes.map((node) => {
|
||||
if (node.id !== originalNodeId || node.type !== 'invocation') {
|
||||
return node;
|
||||
}
|
||||
return {
|
||||
...node,
|
||||
id: outputNodeId,
|
||||
data: { ...node.data, id: outputNodeId },
|
||||
};
|
||||
});
|
||||
|
||||
workflow.edges = workflow.edges.map((edge) => {
|
||||
if (edge.type !== 'default') {
|
||||
return edge;
|
||||
}
|
||||
return {
|
||||
...edge,
|
||||
source: edge.source === originalNodeId ? outputNodeId : edge.source,
|
||||
target: edge.target === originalNodeId ? outputNodeId : edge.target,
|
||||
};
|
||||
});
|
||||
|
||||
const graph = buildInvocationGraph({
|
||||
nodes: workflow.nodes,
|
||||
edges: workflow.edges,
|
||||
templates,
|
||||
graphId: workflow.id,
|
||||
});
|
||||
|
||||
const workflowPayload = deepClone(workflow);
|
||||
delete workflowPayload.id;
|
||||
|
||||
const enqueueArg: EnqueueBatchArg = {
|
||||
batch: {
|
||||
graph,
|
||||
workflow: workflowPayload,
|
||||
runs: 1,
|
||||
origin: 'canvas',
|
||||
destination: canvasSessionId,
|
||||
},
|
||||
};
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(enqueueArg, {
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
})
|
||||
);
|
||||
|
||||
await req.unwrap();
|
||||
|
||||
toast({
|
||||
status: 'success',
|
||||
title: t('controlLayers.triggerWorkflow.enqueued'),
|
||||
});
|
||||
|
||||
setWorkflowLibraryBrowseIntent();
|
||||
canvasManager.stateApi.cancelWorkflowTrigger();
|
||||
} catch {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('controlLayers.triggerWorkflow.enqueueFailed'),
|
||||
});
|
||||
} finally {
|
||||
setIsApplying(false);
|
||||
}
|
||||
}, [canvasManager.stateApi, canvasSessionId, dispatch, selection, selectedWorkflow, t, templates]);
|
||||
|
||||
return useMemo(
|
||||
() => ({ apply, isApplying, selection, selectedWorkflow, selectedWorkflowName }) as const,
|
||||
[apply, isApplying, selection, selectedWorkflow, selectedWorkflowName]
|
||||
);
|
||||
};
|
||||
@@ -50,7 +50,6 @@ import type {
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { RGBA_BLACK } from 'features/controlLayers/store/types';
|
||||
import { zImageOutput } from 'features/nodes/types/common';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
@@ -62,17 +61,6 @@ import type { Param0 } from 'tsafe';
|
||||
|
||||
import type { CanvasEntityAdapter } from './CanvasEntity/types';
|
||||
|
||||
export type WorkflowTriggerSelection = {
|
||||
workflow: WorkflowV4;
|
||||
workflowId: string;
|
||||
workflowName: string;
|
||||
};
|
||||
|
||||
type WorkflowTriggerState = {
|
||||
adapter: CanvasEntityAdapterRasterLayer;
|
||||
selection: WorkflowTriggerSelection | null;
|
||||
};
|
||||
|
||||
export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
readonly type = 'state_api';
|
||||
readonly id: string;
|
||||
@@ -469,26 +457,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
}
|
||||
};
|
||||
|
||||
startWorkflowTrigger = (adapter: CanvasEntityAdapterRasterLayer) => {
|
||||
const current = this.$workflowTrigger.get();
|
||||
if (current && current.adapter !== adapter) {
|
||||
this.$workflowTrigger.set(null);
|
||||
}
|
||||
this.$workflowTrigger.set({ adapter, selection: null });
|
||||
};
|
||||
|
||||
cancelWorkflowTrigger = () => {
|
||||
this.$workflowTrigger.set(null);
|
||||
};
|
||||
|
||||
setWorkflowTriggerSelection = (selection: WorkflowTriggerSelection) => {
|
||||
const current = this.$workflowTrigger.get();
|
||||
if (!current) {
|
||||
return;
|
||||
}
|
||||
this.$workflowTrigger.set({ adapter: current.adapter, selection });
|
||||
};
|
||||
|
||||
/**
|
||||
* The entity adapter being filtered, if any.
|
||||
*/
|
||||
@@ -529,16 +497,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
*/
|
||||
$isSegmenting = computed(this.$segmentingAdapter, (segmentingAdapter) => Boolean(segmentingAdapter));
|
||||
|
||||
/**
|
||||
* The entity adapter currently triggering a workflow, if any.
|
||||
*/
|
||||
$workflowTrigger = atom<WorkflowTriggerState | null>(null);
|
||||
|
||||
/**
|
||||
* Whether a workflow trigger is active.
|
||||
*/
|
||||
$isWorkflowTriggering = computed(this.$workflowTrigger, (trigger) => Boolean(trigger));
|
||||
|
||||
/**
|
||||
* Whether the space key is currently pressed.
|
||||
*/
|
||||
|
||||
@@ -18,7 +18,6 @@ const modelPaneSx: SystemStyleObject = {
|
||||
},
|
||||
h: 'full',
|
||||
minWidth: '300px',
|
||||
overflow: 'auto',
|
||||
};
|
||||
|
||||
export const ModelPane = memo(() => {
|
||||
|
||||
@@ -40,7 +40,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
}, [modelConfig.base, modelConfig.type]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4} h="full">
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<ModelHeader modelConfig={modelConfig}>
|
||||
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||
<ModelConvertButton modelConfig={modelConfig} />
|
||||
@@ -48,7 +48,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
<ModelEditButton />
|
||||
</ModelHeader>
|
||||
<Divider />
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import { FloatFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInput';
|
||||
import { FloatFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldInputAndSlider';
|
||||
import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/FloatField/FloatFieldSlider';
|
||||
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
|
||||
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
|
||||
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
|
||||
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
|
||||
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
|
||||
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
|
||||
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
@@ -23,8 +27,22 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isChatGPT4oModelFieldInputInstance,
|
||||
isChatGPT4oModelFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isCLIPGEmbedModelFieldInputInstance,
|
||||
isCLIPGEmbedModelFieldInputTemplate,
|
||||
isCLIPLEmbedModelFieldInputInstance,
|
||||
isCLIPLEmbedModelFieldInputTemplate,
|
||||
isCogView4MainModelFieldInputInstance,
|
||||
isCogView4MainModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlLoRAModelFieldInputInstance,
|
||||
isControlLoRAModelFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
@@ -33,28 +51,68 @@ import {
|
||||
isFloatFieldInputTemplate,
|
||||
isFloatGeneratorFieldInputInstance,
|
||||
isFloatGeneratorFieldInputTemplate,
|
||||
isFluxKontextModelFieldInputInstance,
|
||||
isFluxKontextModelFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxReduxModelFieldInputInstance,
|
||||
isFluxReduxModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isImageGeneratorFieldInputInstance,
|
||||
isImageGeneratorFieldInputTemplate,
|
||||
isImagen3ModelFieldInputInstance,
|
||||
isImagen3ModelFieldInputTemplate,
|
||||
isImagen4ModelFieldInputInstance,
|
||||
isImagen4ModelFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isIntegerGeneratorFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isLLaVAModelFieldInputInstance,
|
||||
isLLaVAModelFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
isMainModelFieldInputTemplate,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
isModelIdentifierFieldInputTemplate,
|
||||
isRunwayModelFieldInputInstance,
|
||||
isRunwayModelFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isSD3MainModelFieldInputInstance,
|
||||
isSD3MainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSigLipModelFieldInputInstance,
|
||||
isSigLipModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isStringGeneratorFieldInputInstance,
|
||||
isStringGeneratorFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isT5EncoderModelFieldInputInstance,
|
||||
isT5EncoderModelFieldInputTemplate,
|
||||
isVAEModelFieldInputInstance,
|
||||
isVAEModelFieldInputTemplate,
|
||||
isVeo3ModelFieldInputInstance,
|
||||
isVeo3ModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo } from 'react';
|
||||
@@ -63,10 +121,33 @@ import { assert } from 'tsafe';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
import CogView4MainModelFieldInputComponent from './inputs/CogView4MainModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import RunwayModelFieldInputComponent from './inputs/RunwayModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
import Veo3ModelFieldInputComponent from './inputs/Veo3ModelFieldInputComponent';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
@@ -206,6 +287,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isMainModelFieldInputTemplate(template)) {
|
||||
if (!isMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputTemplate(template)) {
|
||||
if (!isModelIdentifierFieldInputInstance(field)) {
|
||||
return null;
|
||||
@@ -213,6 +301,159 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSDXLRefinerModelFieldInputTemplate(template)) {
|
||||
if (!isSDXLRefinerModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isVAEModelFieldInputTemplate(template)) {
|
||||
if (!isVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <VAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isT5EncoderModelFieldInputTemplate(template)) {
|
||||
if (!isT5EncoderModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCLIPLEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPLEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCLIPGEmbedModelFieldInputTemplate(template)) {
|
||||
if (!isCLIPGEmbedModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isControlLoRAModelFieldInputTemplate(template)) {
|
||||
if (!isControlLoRAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLLaVAModelFieldInputTemplate(template)) {
|
||||
if (!isLLaVAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputTemplate(template)) {
|
||||
if (!isFluxVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputTemplate(template)) {
|
||||
if (!isLoRAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isControlNetModelFieldInputTemplate(template)) {
|
||||
if (!isControlNetModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isIPAdapterModelFieldInputTemplate(template)) {
|
||||
if (!isIPAdapterModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isT2IAdapterModelFieldInputTemplate(template)) {
|
||||
if (!isT2IAdapterModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSpandrelImageToImageModelFieldInputTemplate(template)) {
|
||||
if (!isSpandrelImageToImageModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSigLipModelFieldInputTemplate(template)) {
|
||||
if (!isSigLipModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxReduxModelFieldInputTemplate(template)) {
|
||||
if (!isFluxReduxModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen3ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen3ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isImagen4ModelFieldInputTemplate(template)) {
|
||||
if (!isImagen4ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxKontextModelFieldInputTemplate(template)) {
|
||||
if (!isFluxKontextModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isChatGPT4oModelFieldInputTemplate(template)) {
|
||||
if (!isChatGPT4oModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <ChatGPT4oModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isVeo3ModelFieldInputTemplate(template)) {
|
||||
if (!isVeo3ModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <Veo3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isRunwayModelFieldInputTemplate(template)) {
|
||||
if (!isRunwayModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <RunwayModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isColorFieldInputTemplate(template)) {
|
||||
if (!isColorFieldInputInstance(field)) {
|
||||
return null;
|
||||
@@ -220,6 +461,34 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxMainModelFieldInputTemplate(template)) {
|
||||
if (!isFluxMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputTemplate(template)) {
|
||||
if (!isSD3MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isCogView4MainModelFieldInputTemplate(template)) {
|
||||
if (!isCogView4MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <CogView4MainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputTemplate(template)) {
|
||||
if (!isSDXLMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSchedulerFieldInputTemplate(template)) {
|
||||
if (!isSchedulerFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const onChange = useCallback(
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,45 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: CLIPGEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPGEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config))}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPGEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,45 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: CLIPLEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPLEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config))}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPLEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldChatGPT4oModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useChatGPT4oModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const ChatGPT4oModelFieldInputComponent = (
|
||||
props: FieldComponentProps<ChatGPT4oModelFieldInputInstance, ChatGPT4oModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useChatGPT4oModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldChatGPT4oModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ChatGPT4oModelFieldInputComponent);
|
||||
@@ -0,0 +1,63 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
CogView4MainModelFieldInputInstance,
|
||||
CogView4MainModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCogView4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CogView4MainModelFieldInputInstance, CogView4MainModelFieldInputTemplate>;
|
||||
|
||||
const CogView4MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCogView4Models();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
isDisabled={!options.length}
|
||||
isInvalid={!value && props.fieldTemplate.required}
|
||||
>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CogView4MainModelFieldInputComponent);
|
||||
@@ -0,0 +1,48 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
ControlLoRAModelFieldInputInstance,
|
||||
ControlLoRAModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlLoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;
|
||||
|
||||
const ControlLoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useControlLoRAModel();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ControlLoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldControlLoRAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlLoRAModelFieldInputComponent);
|
||||
@@ -0,0 +1,45 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate>;
|
||||
|
||||
const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ControlNetModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldControlNetModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelFieldInputComponent);
|
||||
@@ -0,0 +1,49 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FluxKontextModelFieldInputInstance,
|
||||
FluxKontextModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const FluxKontextModelFieldInputComponent = (
|
||||
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxKontextModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxKontextModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxKontextModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
|
||||
|
||||
const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxMainModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { FLUXReduxModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const FluxReduxModelFieldInputComponent = (
|
||||
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxReduxModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: FLUXReduxModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxReduxModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxReduxModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxVAEModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxVAEModelFieldInputComponent);
|
||||
@@ -0,0 +1,45 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const IPAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: IPAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldIPAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IPAdapterModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldImagen3ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useImagen3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Imagen3ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Imagen3ModelFieldInputInstance, Imagen3ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useImagen3Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldImagen3ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Imagen3ModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useImagen4Models } from 'services/api/hooks/modelsByType';
|
||||
import type { ApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Imagen4ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useImagen4Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: ApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldImagen4ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Imagen4ModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LlavaOnevisionConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
|
||||
|
||||
const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLLaVAModels();
|
||||
const onChange = useCallback(
|
||||
(value: LlavaOnevisionConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLLaVAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LLaVAModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate>;
|
||||
|
||||
const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||
const onChange = useCallback(
|
||||
(value: LoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLoRAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInputTemplate>;
|
||||
|
||||
const MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MainModelFieldInputComponent);
|
||||
@@ -12,7 +12,7 @@ import type { FieldComponentProps } from './types';
|
||||
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
|
||||
|
||||
const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetModelConfigsQuery();
|
||||
const onChange = useCallback(
|
||||
@@ -36,31 +36,8 @@ const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
if (!fieldTemplate.ui_model_base && !fieldTemplate.ui_model_type) {
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(data).filter((config) => {
|
||||
if (fieldTemplate.ui_model_base && !fieldTemplate.ui_model_base.includes(config.base)) {
|
||||
return false;
|
||||
}
|
||||
if (fieldTemplate.ui_model_type && !fieldTemplate.ui_model_type.includes(config.type)) {
|
||||
return false;
|
||||
}
|
||||
if (
|
||||
fieldTemplate.ui_model_variant &&
|
||||
'variant' in config &&
|
||||
config.variant &&
|
||||
!fieldTemplate.ui_model_variant.includes(config.variant)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
if (fieldTemplate.ui_model_format && !fieldTemplate.ui_model_format.includes(config.format)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}, [data, fieldTemplate]);
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}, [data]);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefinerModelFieldInputTemplate>;
|
||||
|
||||
const RefinerModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldRefinerModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldRunwayModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useRunwayModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VideoApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const RunwayModelFieldInputComponent = (
|
||||
props: FieldComponentProps<RunwayModelFieldInputInstance, RunwayModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useRunwayModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: VideoApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldRunwayModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RunwayModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
|
||||
|
||||
const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SD3MainModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate>;
|
||||
|
||||
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SDXLMainModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSigLipModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SigLipModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SigLipModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSigLipModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: SigLipModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSigLipModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SigLipModelFieldInputComponent);
|
||||
@@ -0,0 +1,49 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SpandrelImageToImageModelFieldInputInstance,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SpandrelImageToImageModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: SpandrelImageToImageModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSpandrelImageToImageModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SpandrelImageToImageModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const T2IAdapterModelFieldInputComponent = (
|
||||
props: FieldComponentProps<T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: T2IAdapterModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldT2IAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterModelFieldInputComponent);
|
||||
@@ -0,0 +1,43 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
|
||||
|
||||
const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
const onChange = useCallback(
|
||||
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldT5EncoderValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T5EncoderModelFieldInputComponent);
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputTemplate>;
|
||||
|
||||
const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldVaeModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VAEModelFieldInputComponent);
|
||||
@@ -0,0 +1,46 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldVeo3ModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { Veo3ModelFieldInputInstance, Veo3ModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVeo3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { VideoApiModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const Veo3ModelFieldInputComponent = (
|
||||
props: FieldComponentProps<Veo3ModelFieldInputInstance, Veo3ModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useVeo3Models();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: VideoApiModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldVeo3ModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(Veo3ModelFieldInputComponent);
|
||||
@@ -7,14 +7,14 @@ import { debounce } from 'es-toolkit/compat';
|
||||
import { getInitialWorkflow } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice, selectWorkflowId } from 'features/nodes/store/selectors';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
import stableHash from 'stable-hash';
|
||||
|
||||
const $maybePreviewWorkflow = atom<WorkflowV4 | null>(null);
|
||||
const $maybePreviewWorkflow = atom<WorkflowV3 | null>(null);
|
||||
export const $previewWorkflow = computed(
|
||||
$maybePreviewWorkflow,
|
||||
(maybePreviewWorkflow) => maybePreviewWorkflow ?? EMPTY_OBJECT
|
||||
|
||||
@@ -17,8 +17,7 @@ import {
|
||||
selectWorkflowLibraryView,
|
||||
workflowLibraryViewChanged,
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { setWorkflowLibraryBrowseIntent } from 'features/workflowLibrary/store/workflowLibraryIntent';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetCountsByCategoryQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
@@ -30,12 +29,8 @@ export const WorkflowLibraryModal = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const workflowLibraryModal = useWorkflowLibraryModal();
|
||||
const didSync = useSyncInitialWorkflowLibraryCategories();
|
||||
const handleClose = useCallback(() => {
|
||||
setWorkflowLibraryBrowseIntent();
|
||||
workflowLibraryModal.close();
|
||||
}, [workflowLibraryModal]);
|
||||
return (
|
||||
<Modal isOpen={workflowLibraryModal.isOpen} onClose={handleClose} isCentered>
|
||||
<Modal isOpen={workflowLibraryModal.isOpen} onClose={workflowLibraryModal.close} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent
|
||||
w="calc(100% - var(--invoke-sizes-40))"
|
||||
@@ -44,7 +39,7 @@ export const WorkflowLibraryModal = memo(() => {
|
||||
maxH="calc(100% - var(--invoke-sizes-40))"
|
||||
>
|
||||
<ModalHeader>{t('workflows.workflowLibrary')}</ModalHeader>
|
||||
<ModalCloseButton onClick={handleClose} />
|
||||
<ModalCloseButton />
|
||||
<ModalBody pb={6}>
|
||||
{didSync && (
|
||||
<Flex gap={4} h="100%">
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
|
||||
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
|
||||
import { selectWorkflowId } from 'features/nodes/store/selectors';
|
||||
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
|
||||
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import {
|
||||
$workflowLibraryIntent,
|
||||
setWorkflowLibraryBrowseIntent,
|
||||
} from 'features/workflowLibrary/store/workflowLibraryIntent';
|
||||
import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -42,27 +36,12 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
const dispatch = useAppDispatch();
|
||||
const workflowId = useAppSelector(selectWorkflowId);
|
||||
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
|
||||
const workflowLibraryModal = useWorkflowLibraryModal();
|
||||
const workflowLibraryIntent = useStore($workflowLibraryIntent);
|
||||
|
||||
const isActive = useMemo(() => {
|
||||
return workflowId === workflow.workflow_id;
|
||||
}, [workflowId, workflow.workflow_id]);
|
||||
|
||||
const isTriggerMode = workflowLibraryIntent.mode === 'trigger-workflow';
|
||||
const isDisabled = isTriggerMode && !workflow.has_valid_image_output_field;
|
||||
|
||||
const handleClickLoad = useCallback(() => {
|
||||
if (isTriggerMode) {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
workflowLibraryIntent.onSelect(workflow);
|
||||
setWorkflowLibraryBrowseIntent();
|
||||
workflowLibraryModal.close();
|
||||
return;
|
||||
}
|
||||
|
||||
loadWorkflowWithDialog({
|
||||
type: 'library',
|
||||
data: workflow.workflow_id,
|
||||
@@ -70,32 +49,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
dispatch(workflowModeChanged('view'));
|
||||
},
|
||||
});
|
||||
}, [
|
||||
dispatch,
|
||||
isDisabled,
|
||||
isTriggerMode,
|
||||
loadWorkflowWithDialog,
|
||||
workflow,
|
||||
workflowLibraryIntent,
|
||||
workflowLibraryModal,
|
||||
]);
|
||||
}, [dispatch, loadWorkflowWithDialog, workflow.workflow_id]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
role="button"
|
||||
onClick={handleClickLoad}
|
||||
aria-disabled={isDisabled}
|
||||
bg="base.750"
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
alignItems="stretch"
|
||||
sx={{
|
||||
...sx,
|
||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||
opacity: isDisabled ? 0.5 : 1,
|
||||
pointerEvents: isDisabled ? 'none' : 'auto',
|
||||
}}
|
||||
sx={sx}
|
||||
gap={2}
|
||||
>
|
||||
<Flex p={2} pr={0}>
|
||||
@@ -141,18 +106,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
{t('workflows.builder.published')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.has_valid_image_output_field && (
|
||||
<Badge
|
||||
color="invokeGreen.400"
|
||||
borderColor="invokeGreen.700"
|
||||
borderWidth={1}
|
||||
bg="transparent"
|
||||
flexShrink={0}
|
||||
variant="subtle"
|
||||
>
|
||||
{t('workflows.validImageOutput')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
|
||||
{workflow.category === 'default' && (
|
||||
<Image
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
import { Button, ButtonGroup, Divider, Flex, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { InvocationNodeContextProvider } from 'features/nodes/components/flow/nodes/Invocation/context';
|
||||
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isSelectingOutputNode,
|
||||
$outputNodeId,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
|
||||
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
import { $templates, workflowOutputFieldsChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodes, selectNodesSlice, selectWorkflowOutputFields } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowOutputField } from 'features/nodes/types/workflow';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useEffect, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowLineRightBold, PiTrashBold } from 'react-icons/pi';
|
||||
|
||||
type OutputDetail = { nodeId: string; fieldName: string; userLabel: string | null };
|
||||
|
||||
const buildOutputDetails = (outputs: WorkflowOutputField[], templates: Templates, nodes: AnyNode[]): OutputDetail[] => {
|
||||
const details: OutputDetail[] = [];
|
||||
|
||||
for (const output of outputs) {
|
||||
if (details.length >= 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const node = nodes.find((n) => n.id === output.nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
const fieldTemplate = template?.outputs?.[output.fieldName];
|
||||
|
||||
if (!fieldTemplate || fieldTemplate.type.name !== 'ImageField' || fieldTemplate.type.cardinality === 'COLLECTION') {
|
||||
continue;
|
||||
}
|
||||
|
||||
details.push({
|
||||
nodeId: output.nodeId,
|
||||
fieldName: output.fieldName,
|
||||
userLabel: output.userLabel ?? fieldTemplate.title ?? null,
|
||||
});
|
||||
}
|
||||
|
||||
return details;
|
||||
};
|
||||
|
||||
const WorkflowOutputFieldsTab = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const store = useAppStore();
|
||||
const outputFields = useAppSelector(selectWorkflowOutputFields);
|
||||
const nodes = useAppSelector(selectNodes);
|
||||
const templates = useStore($templates);
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const selectionSourceRef = useRef<'workflow-output-tab' | null>(null);
|
||||
const previousOutputsRef = useRef<WorkflowOutputField[]>(outputFields);
|
||||
|
||||
const outputDetails = useMemo(
|
||||
() => buildOutputDetails(outputFields, templates, nodes),
|
||||
[outputFields, templates, nodes]
|
||||
);
|
||||
|
||||
const revertSelection = useCallback(
|
||||
(outputsToRestore: WorkflowOutputField[]) => {
|
||||
dispatch(
|
||||
workflowOutputFieldsChanged(
|
||||
outputsToRestore.map(({ nodeId, fieldName, userLabel }) => ({
|
||||
nodeId,
|
||||
fieldName,
|
||||
userLabel,
|
||||
}))
|
||||
)
|
||||
);
|
||||
const restoredNodeId = outputsToRestore[0]?.nodeId ?? null;
|
||||
$outputNodeId.set(restoredNodeId);
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
previousOutputsRef.current = outputFields;
|
||||
}, [outputFields]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isInPublishFlow || isSelectingOutputNode) {
|
||||
return;
|
||||
}
|
||||
const currentNodeId = outputFields[0]?.nodeId ?? null;
|
||||
if ($outputNodeId.get() !== currentNodeId) {
|
||||
$outputNodeId.set(currentNodeId);
|
||||
}
|
||||
}, [isInPublishFlow, isSelectingOutputNode, outputFields]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectionSourceRef.current !== 'workflow-output-tab') {
|
||||
return;
|
||||
}
|
||||
if (isSelectingOutputNode) {
|
||||
return;
|
||||
}
|
||||
const selectedNodeId = outputNodeId;
|
||||
selectionSourceRef.current = null;
|
||||
|
||||
if (!selectedNodeId) {
|
||||
return;
|
||||
}
|
||||
|
||||
const state = store.getState();
|
||||
const nodesState = selectNodesSlice(state);
|
||||
const node = nodesState.nodes.find((n) => n.id === selectedNodeId);
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('workflows.builder.outputFieldSelectInvalid', { defaultValue: 'Please select an invocation node.' }),
|
||||
});
|
||||
revertSelection(previousOutputsRef.current);
|
||||
return;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('workflows.builder.outputFieldMissingTemplate', {
|
||||
defaultValue: 'Unable to select outputs for this node.',
|
||||
}),
|
||||
});
|
||||
revertSelection(previousOutputsRef.current);
|
||||
return;
|
||||
}
|
||||
|
||||
const imageOutputEntry = Object.entries(template.outputs).find(([, output]) => {
|
||||
return output.type.name === 'ImageField' && output.type.cardinality !== 'COLLECTION';
|
||||
});
|
||||
|
||||
if (!imageOutputEntry) {
|
||||
toast({
|
||||
status: 'error',
|
||||
title: t('workflows.builder.outputFieldMustBeImage', {
|
||||
defaultValue: 'Selected node must have an image output.',
|
||||
}),
|
||||
});
|
||||
revertSelection(previousOutputsRef.current);
|
||||
return;
|
||||
}
|
||||
|
||||
const [fieldName, fieldTemplate] = imageOutputEntry;
|
||||
dispatch(
|
||||
workflowOutputFieldsChanged([
|
||||
{
|
||||
nodeId: selectedNodeId,
|
||||
fieldName,
|
||||
userLabel: fieldTemplate.title ?? null,
|
||||
},
|
||||
])
|
||||
);
|
||||
}, [dispatch, isSelectingOutputNode, outputNodeId, revertSelection, store, templates, t]);
|
||||
|
||||
const handleSelectNodeClick = useCallback(() => {
|
||||
selectionSourceRef.current = 'workflow-output-tab';
|
||||
previousOutputsRef.current = outputFields;
|
||||
$outputNodeId.set(null);
|
||||
$isSelectingOutputNode.set(true);
|
||||
}, [outputFields]);
|
||||
|
||||
const handleClear = useCallback(() => {
|
||||
previousOutputsRef.current = [];
|
||||
dispatch(workflowOutputFieldsChanged([]));
|
||||
$outputNodeId.set(null);
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} h="full">
|
||||
<Flex alignItems="center">
|
||||
<Text fontWeight="semibold">{t('workflows.builder.outputFieldsTab', 'Output Fields')}</Text>
|
||||
<Spacer />
|
||||
<ButtonGroup size="sm" variant="ghost" isAttached={false}>
|
||||
{outputDetails.length > 0 && (
|
||||
<Button leftIcon={<PiTrashBold />} onClick={handleClear}>
|
||||
{t('common.clear')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
leftIcon={<PiArrowLineRightBold />}
|
||||
onClick={handleSelectNodeClick}
|
||||
isDisabled={isSelectingOutputNode}
|
||||
>
|
||||
{isSelectingOutputNode
|
||||
? t('workflows.builder.selectingOutputNode')
|
||||
: outputDetails.length > 0
|
||||
? t('workflows.builder.changeOutputNode')
|
||||
: t('workflows.builder.selectOutputNode')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<Divider />
|
||||
{outputDetails.length === 0 ? (
|
||||
<Text color="warning.300" fontWeight="semibold">
|
||||
{outputFields.length === 0
|
||||
? t('workflows.builder.noOutputNodeSelected')
|
||||
: t('workflows.builder.outputFieldPending', {
|
||||
defaultValue: 'Selected output is unavailable. Check that the node still exists.',
|
||||
})}
|
||||
</Text>
|
||||
) : (
|
||||
outputDetails.map((detail) => (
|
||||
<InvocationNodeContextProvider nodeId={detail.nodeId} key={`${detail.nodeId}-${detail.fieldName}`}>
|
||||
<SelectedOutputPreview
|
||||
nodeId={detail.nodeId}
|
||||
fieldName={detail.fieldName}
|
||||
fallbackLabel={detail.userLabel ?? detail.fieldName}
|
||||
/>
|
||||
</InvocationNodeContextProvider>
|
||||
))
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default WorkflowOutputFieldsTab;
|
||||
|
||||
const SelectedOutputPreview = ({
|
||||
nodeId,
|
||||
fieldName,
|
||||
fallbackLabel,
|
||||
}: {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fallbackLabel: string;
|
||||
}) => {
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const nodeUserTitle = useNodeUserTitleOrThrow();
|
||||
const nodeTemplateTitle = useNodeTemplateTitleOrThrow();
|
||||
const fieldTemplate = useOutputFieldTemplate(fieldName);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
position="relative"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
borderWidth={1}
|
||||
gap={1}
|
||||
onMouseOver={mouseOverFormField.handleMouseOver}
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
onClick={zoomToNode}
|
||||
>
|
||||
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate?.title ?? fallbackLabel}`}</Text>
|
||||
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
|
||||
<NodeFieldElementOverlay nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -8,7 +8,6 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import WorkflowGeneralTab from './WorkflowGeneralTab';
|
||||
import WorkflowJSONTab from './WorkflowJSONTab';
|
||||
import WorkflowOutputFieldsTab from './WorkflowOutputFieldsTab';
|
||||
|
||||
const WorkflowFieldsLinearViewPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
@@ -18,7 +17,6 @@ const WorkflowFieldsLinearViewPanel = () => {
|
||||
<TabList>
|
||||
<Tab>{t('workflows.builder.builder')}</Tab>
|
||||
<Tab>{t('common.details')}</Tab>
|
||||
<Tab>{t('workflows.builder.outputFieldsTab', 'Output Fields')}</Tab>
|
||||
<Tab>JSON</Tab>
|
||||
<Spacer />
|
||||
{allowPublishWorkflows && <StartPublishFlowButton />}
|
||||
@@ -31,9 +29,6 @@ const WorkflowFieldsLinearViewPanel = () => {
|
||||
<TabPanel h="full" p={0}>
|
||||
<WorkflowGeneralTab />
|
||||
</TabPanel>
|
||||
<TabPanel h="full" p={0}>
|
||||
<WorkflowOutputFieldsTab />
|
||||
</TabPanel>
|
||||
<TabPanel h="full" p={0}>
|
||||
<WorkflowJSONTab />
|
||||
</TabPanel>
|
||||
|
||||
@@ -2,16 +2,20 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
const isConnectionInputField = (field: FieldInputTemplate) => {
|
||||
return (field.input === 'connection' && !isSingleOrCollection(field.type)) || !isStatefulFieldType(field.type);
|
||||
return (
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
|
||||
);
|
||||
};
|
||||
|
||||
const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && isStatefulFieldType(field.type)
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
|
||||
field.type.name in TEMPLATE_BUILDER_MAP
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -24,44 +24,90 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
ChatGPT4oModelFieldValue,
|
||||
CLIPEmbedModelFieldValue,
|
||||
CLIPGEmbedModelFieldValue,
|
||||
CLIPLEmbedModelFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlLoRAModelFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FloatGeneratorFieldValue,
|
||||
FluxKontextModelFieldValue,
|
||||
FluxReduxModelFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldCollectionValue,
|
||||
ImageFieldValue,
|
||||
ImageGeneratorFieldValue,
|
||||
Imagen3ModelFieldValue,
|
||||
Imagen4ModelFieldValue,
|
||||
IntegerFieldCollectionValue,
|
||||
IntegerFieldValue,
|
||||
IntegerGeneratorFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LLaVAModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
ModelIdentifierFieldValue,
|
||||
RunwayModelFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
SigLipModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldCollectionValue,
|
||||
StringFieldValue,
|
||||
StringGeneratorFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
T5EncoderModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
Veo3ModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zChatGPT4oModelFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zCLIPGEmbedModelFieldValue,
|
||||
zCLIPLEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zControlLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFloatGeneratorFieldValue,
|
||||
zFluxKontextModelFieldValue,
|
||||
zFluxReduxModelFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zImageFieldValue,
|
||||
zImageGeneratorFieldValue,
|
||||
zImagen3ModelFieldValue,
|
||||
zImagen4ModelFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zIntegerFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zRunwayModelFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSigLipModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zStringGeneratorFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zVeo3ModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
@@ -74,8 +120,7 @@ import type {
|
||||
NodeFieldElement,
|
||||
TextElement,
|
||||
WorkflowCategory,
|
||||
WorkflowOutputField,
|
||||
WorkflowV4,
|
||||
WorkflowV3,
|
||||
} from 'features/nodes/types/workflow';
|
||||
import {
|
||||
getDefaultForm,
|
||||
@@ -102,8 +147,7 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
output_fields: [] as WorkflowOutputField[],
|
||||
meta: { version: '4.0.0', category: 'user' },
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
form: getDefaultForm(),
|
||||
nodes: [],
|
||||
edges: [],
|
||||
@@ -195,9 +239,6 @@ const slice = createSlice({
|
||||
);
|
||||
|
||||
if (didNodesChange) {
|
||||
const remainingNodeIds = new Set(state.nodes.map((node) => node.id));
|
||||
state.output_fields = state.output_fields.filter((field) => remainingNodeIds.has(field.nodeId));
|
||||
|
||||
const edgeChanges: EdgeChange<AnyEdge>[] = [];
|
||||
for (const e of state.edges) {
|
||||
const sourceExists = state.nodes.some((n) => n.id === e.source);
|
||||
@@ -452,9 +493,81 @@ const slice = createSlice({
|
||||
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zColorFieldValue);
|
||||
},
|
||||
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zMainModelFieldValue);
|
||||
},
|
||||
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
|
||||
fieldValueReducer(state, action, zModelIdentifierFieldValue);
|
||||
},
|
||||
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
|
||||
},
|
||||
fieldVaeModelValueChanged: (state, action: FieldValueAction<VAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zVAEModelFieldValue);
|
||||
},
|
||||
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLoRAModelFieldValue);
|
||||
},
|
||||
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLLaVAModelFieldValue);
|
||||
},
|
||||
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zControlNetModelFieldValue);
|
||||
},
|
||||
fieldIPAdapterModelValueChanged: (state, action: FieldValueAction<IPAdapterModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIPAdapterModelFieldValue);
|
||||
},
|
||||
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
|
||||
},
|
||||
fieldSpandrelImageToImageModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SpandrelImageToImageModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
|
||||
},
|
||||
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
|
||||
},
|
||||
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
|
||||
},
|
||||
fieldCLIPLEmbedValueChanged: (state, action: FieldValueAction<CLIPLEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPLEmbedModelFieldValue);
|
||||
},
|
||||
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
|
||||
},
|
||||
fieldControlLoRAModelValueChanged: (state, action: FieldValueAction<ControlLoRAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zControlLoRAModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
fieldSigLipModelValueChanged: (state, action: FieldValueAction<SigLipModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zSigLipModelFieldValue);
|
||||
},
|
||||
fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxReduxModelFieldValue);
|
||||
},
|
||||
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImagen3ModelFieldValue);
|
||||
},
|
||||
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImagen4ModelFieldValue);
|
||||
},
|
||||
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
|
||||
},
|
||||
fieldVeo3ModelValueChanged: (state, action: FieldValueAction<Veo3ModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zVeo3ModelFieldValue);
|
||||
},
|
||||
fieldRunwayModelValueChanged: (state, action: FieldValueAction<RunwayModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zRunwayModelFieldValue);
|
||||
},
|
||||
fieldFluxKontextModelValueChanged: (state, action: FieldValueAction<FluxKontextModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxKontextModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@@ -517,43 +630,6 @@ const slice = createSlice({
|
||||
workflowContactChanged: (state, action: PayloadAction<string>) => {
|
||||
state.contact = action.payload;
|
||||
},
|
||||
workflowOutputFieldsChanged: (
|
||||
state,
|
||||
action: PayloadAction<Array<{ nodeId: string; fieldName: string; userLabel?: string | null }>>
|
||||
) => {
|
||||
const validOutputs: WorkflowOutputField[] = [];
|
||||
|
||||
for (const identifier of action.payload) {
|
||||
if (validOutputs.length >= 1) {
|
||||
break; // only allow one output field
|
||||
}
|
||||
|
||||
const nodeId = identifier.nodeId;
|
||||
const fieldName = identifier.fieldName;
|
||||
if (!nodeId || !fieldName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const label = identifier.userLabel ?? null;
|
||||
|
||||
validOutputs.push({
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind: 'output',
|
||||
userLabel: label,
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
user_label: label,
|
||||
});
|
||||
}
|
||||
|
||||
state.output_fields = validOutputs;
|
||||
},
|
||||
workflowIDChanged: (state, action: PayloadAction<string>) => {
|
||||
state.id = action.payload;
|
||||
},
|
||||
@@ -606,7 +682,7 @@ const slice = createSlice({
|
||||
const { formFieldInitialValues } = action.payload;
|
||||
state.formFieldInitialValues = formFieldInitialValues;
|
||||
},
|
||||
workflowLoaded: (state, action: PayloadAction<WorkflowV4>) => {
|
||||
workflowLoaded: (state, action: PayloadAction<WorkflowV3>) => {
|
||||
const { nodes, edges, is_published: _is_published, ...workflowExtra } = action.payload;
|
||||
|
||||
const formFieldInitialValues = getFormFieldInitialValues(workflowExtra.form, nodes);
|
||||
@@ -630,22 +706,45 @@ export const {
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
fieldColorValueChanged,
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldImageCollectionValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldLLaVAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldIntegerValueChanged,
|
||||
fieldFloatValueChanged,
|
||||
fieldFloatCollectionValueChanged,
|
||||
fieldIntegerCollectionValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldCLIPLEmbedValueChanged,
|
||||
fieldCLIPGEmbedValueChanged,
|
||||
fieldControlLoRAModelValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
fieldSigLipModelValueChanged,
|
||||
fieldFluxReduxModelValueChanged,
|
||||
fieldImagen3ModelValueChanged,
|
||||
fieldImagen4ModelValueChanged,
|
||||
fieldChatGPT4oModelValueChanged,
|
||||
fieldFluxKontextModelValueChanged,
|
||||
fieldFloatGeneratorValueChanged,
|
||||
fieldIntegerGeneratorValueChanged,
|
||||
fieldStringGeneratorValueChanged,
|
||||
fieldImageGeneratorValueChanged,
|
||||
fieldVeo3ModelValueChanged,
|
||||
fieldRunwayModelValueChanged,
|
||||
fieldDescriptionChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
@@ -663,7 +762,6 @@ export const {
|
||||
workflowNotesChanged,
|
||||
workflowVersionChanged,
|
||||
workflowContactChanged,
|
||||
workflowOutputFieldsChanged,
|
||||
workflowIDChanged,
|
||||
formReset,
|
||||
formElementAdded,
|
||||
|
||||
@@ -76,7 +76,6 @@ export const selectWorkflowContact = createNodesSelector((workflow) => workflow.
|
||||
export const selectWorkflowTags = createNodesSelector((workflow) => workflow.tags);
|
||||
export const selectWorkflowVersion = createNodesSelector((workflow) => workflow.version);
|
||||
export const selectWorkflowForm = createNodesSelector((workflow) => workflow.form);
|
||||
export const selectWorkflowOutputFields = createNodesSelector((workflow) => workflow.output_fields);
|
||||
|
||||
export const selectFormRootElementId = createNodesSelector((workflow) => {
|
||||
return workflow.form.rootElementId;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { HandleType } from '@xyflow/react';
|
||||
import { type FieldInputTemplate, type FieldOutputTemplate, zStatefulFieldValue } from 'features/nodes/types/field';
|
||||
import { type InvocationTemplate, type NodeExecutionState, zAnyEdge, zAnyNode } from 'features/nodes/types/invocation';
|
||||
import { zWorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import z from 'zod';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
@@ -21,6 +21,6 @@ export const zNodesState = z.object({
|
||||
nodes: z.array(zAnyNode),
|
||||
edges: z.array(zAnyEdge),
|
||||
formFieldInitialValues: z.record(z.string(), zStatefulFieldValue),
|
||||
...zWorkflowV4.omit({ nodes: true, edges: true, is_published: true }).shape,
|
||||
...zWorkflowV3.omit({ nodes: true, edges: true, is_published: true }).shape,
|
||||
});
|
||||
export type NodesState = z.infer<typeof zNodesState>;
|
||||
|
||||
@@ -40,7 +40,7 @@ describe(areTypesEqual.name, () => {
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'StringField',
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
@@ -54,7 +54,7 @@ describe(areTypesEqual.name, () => {
|
||||
|
||||
it('should handle equal original source type and target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'FloatField',
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
@@ -78,7 +78,7 @@ describe(areTypesEqual.name, () => {
|
||||
|
||||
it('should handle equal original source type and original target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
@@ -88,7 +88,7 @@ describe(areTypesEqual.name, () => {
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'StringField',
|
||||
name: 'LoRAModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
|
||||
@@ -247,12 +247,17 @@ export const main_model_loader: InvocationTemplate = {
|
||||
fieldKind: 'input',
|
||||
input: 'direct',
|
||||
ui_hidden: false,
|
||||
ui_model_base: ['sd-1'],
|
||||
ui_model_type: ['main'],
|
||||
ui_type: 'MainModelField',
|
||||
type: {
|
||||
name: 'ModelIdentifierField',
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
|
||||
originalType: {
|
||||
name: 'ModelIdentifierField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -791,8 +796,7 @@ export const schema = {
|
||||
input: 'direct',
|
||||
orig_required: true,
|
||||
ui_hidden: false,
|
||||
ui_model_base: ['sd-1'],
|
||||
ui_model_type: ['main'],
|
||||
ui_type: 'MainModelField',
|
||||
},
|
||||
type: {
|
||||
type: 'string',
|
||||
|
||||
@@ -12,15 +12,11 @@ import type {
|
||||
SchedulerField,
|
||||
SubModelType,
|
||||
T2IAdapterField,
|
||||
zClipVariantType,
|
||||
zModelFormat,
|
||||
zModelVariantType,
|
||||
} from 'features/nodes/types/common';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
import type { Invocation, ModelType, S } from 'services/api/types';
|
||||
import type { Equals, Extends } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
import type z from 'zod';
|
||||
|
||||
/**
|
||||
* These types originate from the server and are recreated as zod schemas manually, for use at runtime.
|
||||
@@ -42,9 +38,7 @@ describe('Common types', () => {
|
||||
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
|
||||
test('ModelIdentifier', () => assert<Equals<BaseModelType, S['BaseModelType']>>());
|
||||
test('ModelIdentifier', () => assert<Equals<SubModelType, S['SubModelType']>>());
|
||||
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
|
||||
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
|
||||
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
|
||||
test('ModelIdentifier', () => assert<Equals<ModelType, S['ModelType']>>());
|
||||
|
||||
// Misc types
|
||||
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
|
||||
|
||||
@@ -73,7 +73,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
export const zBaseModelType = z.enum([
|
||||
const zBaseModel = z.enum([
|
||||
'any',
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
@@ -90,7 +90,7 @@ export const zBaseModelType = z.enum([
|
||||
'veo3',
|
||||
'runway',
|
||||
]);
|
||||
export type BaseModelType = z.infer<typeof zBaseModelType>;
|
||||
export type BaseModelType = z.infer<typeof zBaseModel>;
|
||||
export const zMainModelBase = z.enum([
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
@@ -143,31 +143,11 @@ const zSubModelType = z.enum([
|
||||
'safety_checker',
|
||||
]);
|
||||
export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zClipVariantType = z.enum(['large', 'gigantic']);
|
||||
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
|
||||
export const zModelFormat = z.enum([
|
||||
'omi',
|
||||
'diffusers',
|
||||
'checkpoint',
|
||||
'lycoris',
|
||||
'onnx',
|
||||
'olive',
|
||||
'embedding_file',
|
||||
'embedding_folder',
|
||||
'invokeai',
|
||||
't5_encoder',
|
||||
'bnb_quantized_int8b',
|
||||
'bnb_quantized_nf4b',
|
||||
'gguf_quantized',
|
||||
'api',
|
||||
]);
|
||||
|
||||
export const zModelIdentifierField = z.object({
|
||||
key: z.string().min(1),
|
||||
hash: z.string().min(1),
|
||||
name: z.string().min(1),
|
||||
base: zBaseModelType,
|
||||
base: zBaseModel,
|
||||
type: zModelType,
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
|
||||
@@ -9,18 +9,7 @@ import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { ImageField } from './common';
|
||||
import {
|
||||
zBaseModelType,
|
||||
zBoardField,
|
||||
zClipVariantType,
|
||||
zColorField,
|
||||
zImageField,
|
||||
zModelFormat,
|
||||
zModelIdentifierField,
|
||||
zModelType,
|
||||
zModelVariantType,
|
||||
zSchedulerField,
|
||||
} from './common';
|
||||
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
|
||||
|
||||
/**
|
||||
* zod schemas & inferred types for fields.
|
||||
@@ -71,10 +60,6 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
|
||||
default: z.undefined(),
|
||||
ui_component: zFieldUIComponent.nullish(),
|
||||
ui_choice_labels: z.record(z.string(), z.string()).nullish(),
|
||||
ui_model_base: z.array(zBaseModelType).nullish(),
|
||||
ui_model_type: z.array(zModelType).nullish(),
|
||||
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
|
||||
ui_model_format: z.array(zModelFormat).nullish(),
|
||||
});
|
||||
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
@@ -176,10 +161,118 @@ const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zModelIdentifierFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ModelIdentifierField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SD3MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCogView4MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CogView4MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLLaVAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LLaVAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SpandrelImageToImageModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zT5EncoderModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T5EncoderModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPLEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPLEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPGEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlLoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxVAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSigLipModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SigLipModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxReduxModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxReduxModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImagen3ModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('Imagen3ModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImagen4ModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('Imagen4ModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ChatGPT4oModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zVeo3ModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('Veo3ModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zRunwayModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('RunwayModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxKontextModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxKontextModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -209,6 +302,33 @@ const zStatefulFieldType = z.union([
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zModelIdentifierFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSD3MainModelFieldType,
|
||||
zCogView4MainModelFieldType,
|
||||
zFluxMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zLLaVAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zT5EncoderModelFieldType,
|
||||
zCLIPEmbedModelFieldType,
|
||||
zCLIPLEmbedModelFieldType,
|
||||
zCLIPGEmbedModelFieldType,
|
||||
zControlLoRAModelFieldType,
|
||||
zFluxVAEModelFieldType,
|
||||
zSigLipModelFieldType,
|
||||
zFluxReduxModelFieldType,
|
||||
zImagen3ModelFieldType,
|
||||
zImagen4ModelFieldType,
|
||||
zChatGPT4oModelFieldType,
|
||||
zFluxKontextModelFieldType,
|
||||
zVeo3ModelFieldType,
|
||||
zRunwayModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
zFloatGeneratorFieldType,
|
||||
@@ -224,7 +344,36 @@ const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
|
||||
const modelFieldTypeNames = [
|
||||
// Stateful model fields
|
||||
zModelIdentifierFieldType.shape.name.value,
|
||||
zMainModelFieldType.shape.name.value,
|
||||
zSDXLMainModelFieldType.shape.name.value,
|
||||
zSD3MainModelFieldType.shape.name.value,
|
||||
zCogView4MainModelFieldType.shape.name.value,
|
||||
zFluxMainModelFieldType.shape.name.value,
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
zLoRAModelFieldType.shape.name.value,
|
||||
zLLaVAModelFieldType.shape.name.value,
|
||||
zControlNetModelFieldType.shape.name.value,
|
||||
zIPAdapterModelFieldType.shape.name.value,
|
||||
zT2IAdapterModelFieldType.shape.name.value,
|
||||
zSpandrelImageToImageModelFieldType.shape.name.value,
|
||||
zT5EncoderModelFieldType.shape.name.value,
|
||||
zCLIPEmbedModelFieldType.shape.name.value,
|
||||
zCLIPLEmbedModelFieldType.shape.name.value,
|
||||
zCLIPGEmbedModelFieldType.shape.name.value,
|
||||
zControlLoRAModelFieldType.shape.name.value,
|
||||
zFluxVAEModelFieldType.shape.name.value,
|
||||
zSigLipModelFieldType.shape.name.value,
|
||||
zFluxReduxModelFieldType.shape.name.value,
|
||||
zImagen3ModelFieldType.shape.name.value,
|
||||
zImagen4ModelFieldType.shape.name.value,
|
||||
zChatGPT4oModelFieldType.shape.name.value,
|
||||
zFluxKontextModelFieldType.shape.name.value,
|
||||
zVeo3ModelFieldType.shape.name.value,
|
||||
zRunwayModelFieldType.shape.name.value,
|
||||
// Stateless model fields
|
||||
'UNetField',
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
@@ -630,6 +779,26 @@ export const isColorFieldInputInstance = buildInstanceTypeGuard(zColorFieldInput
|
||||
export const isColorFieldInputTemplate = buildTemplateTypeGuard<ColorFieldInputTemplate>('ColorField');
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
export const zMainModelFieldValue = zModelIdentifierField.optional();
|
||||
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export const isMainModelFieldInputInstance = buildInstanceTypeGuard(zMainModelFieldInputInstance);
|
||||
export const isMainModelFieldInputTemplate = buildTemplateTypeGuard<MainModelFieldInputTemplate>('MainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ModelIdentifierField
|
||||
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
|
||||
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -651,6 +820,507 @@ export const isModelIdentifierFieldInputTemplate =
|
||||
buildTemplateTypeGuard<ModelIdentifierFieldInputTemplate>('ModelIdentifierField');
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = buildInstanceTypeGuard(zSDXLMainModelFieldInputInstance);
|
||||
export const isSDXLMainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SDXLMainModelFieldInputTemplate>('SDXLMainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SD3MainModelField
|
||||
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSD3MainModelFieldValue,
|
||||
});
|
||||
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSD3MainModelFieldValue,
|
||||
});
|
||||
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
});
|
||||
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
|
||||
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
|
||||
export const isSD3MainModelFieldInputInstance = buildInstanceTypeGuard(zSD3MainModelFieldInputInstance);
|
||||
export const isSD3MainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SD3MainModelFieldInputTemplate>('SD3MainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region CogView4MainModelField
|
||||
const zCogView4MainModelFieldValue = zMainModelFieldValue;
|
||||
const zCogView4MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCogView4MainModelFieldValue,
|
||||
});
|
||||
const zCogView4MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCogView4MainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCogView4MainModelFieldValue,
|
||||
});
|
||||
const zCogView4MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zCogView4MainModelFieldType,
|
||||
});
|
||||
export type CogView4MainModelFieldInputInstance = z.infer<typeof zCogView4MainModelFieldInputInstance>;
|
||||
export type CogView4MainModelFieldInputTemplate = z.infer<typeof zCogView4MainModelFieldInputTemplate>;
|
||||
export const isCogView4MainModelFieldInputInstance = buildInstanceTypeGuard(zCogView4MainModelFieldInputInstance);
|
||||
export const isCogView4MainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<CogView4MainModelFieldInputTemplate>('CogView4MainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region FluxMainModelField
|
||||
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxMainModelFieldValue,
|
||||
});
|
||||
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxMainModelFieldValue,
|
||||
});
|
||||
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFluxMainModelFieldType,
|
||||
});
|
||||
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
|
||||
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
|
||||
export const isFluxMainModelFieldInputInstance = buildInstanceTypeGuard(zFluxMainModelFieldInputInstance);
|
||||
export const isFluxMainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
/** @alias */ // tells knip to ignore this duplicate export
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = buildInstanceTypeGuard(zSDXLRefinerModelFieldInputInstance);
|
||||
export const isSDXLRefinerModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SDXLRefinerModelFieldInputTemplate>('SDXLRefinerModelField');
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
|
||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = buildInstanceTypeGuard(zVAEModelFieldInputInstance);
|
||||
export const isVAEModelFieldInputTemplate = buildTemplateTypeGuard<VAEModelFieldInputTemplate>('VAEModelField');
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFieldInputInstance);
|
||||
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region LLaVAModelField
|
||||
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
});
|
||||
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
|
||||
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
|
||||
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
|
||||
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
|
||||
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = buildInstanceTypeGuard(zControlNetModelFieldInputInstance);
|
||||
export const isControlNetModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<ControlNetModelFieldInputTemplate>('ControlNetModelField');
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = buildInstanceTypeGuard(zIPAdapterModelFieldInputInstance);
|
||||
export const isIPAdapterModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<IPAdapterModelFieldInputTemplate>('IPAdapterModelField');
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = buildInstanceTypeGuard(zT2IAdapterModelFieldInputInstance);
|
||||
export const isT2IAdapterModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<T2IAdapterModelFieldInputTemplate>('T2IAdapterModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SpandrelModelToModelField
|
||||
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
|
||||
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
});
|
||||
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
|
||||
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
|
||||
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
|
||||
export const isSpandrelImageToImageModelFieldInputInstance = buildInstanceTypeGuard(
|
||||
zSpandrelImageToImageModelFieldInputInstance
|
||||
);
|
||||
export const isSpandrelImageToImageModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SpandrelImageToImageModelFieldInputTemplate>('SpandrelImageToImageModelField');
|
||||
// #endregion
|
||||
|
||||
// #region T5EncoderModelField
|
||||
|
||||
export const zT5EncoderModelFieldValue = zModelIdentifierField.optional();
|
||||
const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zT5EncoderModelFieldValue,
|
||||
});
|
||||
const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT5EncoderModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zT5EncoderModelFieldValue,
|
||||
});
|
||||
export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>;
|
||||
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
|
||||
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
|
||||
export const isT5EncoderModelFieldInputInstance = buildInstanceTypeGuard(zT5EncoderModelFieldInputInstance);
|
||||
export const isT5EncoderModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<T5EncoderModelFieldInputTemplate>('T5EncoderModelField');
|
||||
// #endregion
|
||||
|
||||
// #region FluxVAEModelField
|
||||
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxVAEModelFieldValue,
|
||||
});
|
||||
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxVAEModelFieldValue,
|
||||
});
|
||||
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
|
||||
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
|
||||
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
|
||||
export const isFluxVAEModelFieldInputInstance = buildInstanceTypeGuard(zFluxVAEModelFieldInputInstance);
|
||||
export const isFluxVAEModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<FluxVAEModelFieldInputTemplate>('FluxVAEModelField');
|
||||
// #endregion
|
||||
|
||||
// #region CLIPEmbedModelField
|
||||
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
|
||||
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
|
||||
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPEmbedModelFieldInputInstance);
|
||||
export const isCLIPEmbedModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<CLIPEmbedModelFieldInputTemplate>('CLIPEmbedModelField');
|
||||
// #endregion
|
||||
|
||||
// #region CLIPLEmbedModelField
|
||||
export const zCLIPLEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPLEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPLEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPLEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPLEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPLEmbedModelFieldValue,
|
||||
});
|
||||
export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
|
||||
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
|
||||
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPLEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPLEmbedModelFieldInputInstance);
|
||||
export const isCLIPLEmbedModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<CLIPLEmbedModelFieldInputTemplate>('CLIPLEmbedModelField');
|
||||
// #endregion
|
||||
|
||||
// #region CLIPGEmbedModelField
|
||||
export const zCLIPGEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPGEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPGEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPGEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPGEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPGEmbedModelFieldValue,
|
||||
});
|
||||
export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
|
||||
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
|
||||
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPGEmbedModelFieldInputInstance = buildInstanceTypeGuard(zCLIPGEmbedModelFieldInputInstance);
|
||||
export const isCLIPGEmbedModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<CLIPGEmbedModelFieldInputTemplate>('CLIPGEmbedModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ControlLoRAModelField
|
||||
export const zControlLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zControlLoRAModelFieldValue,
|
||||
});
|
||||
const zControlLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zControlLoRAModelFieldValue,
|
||||
});
|
||||
export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
|
||||
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
|
||||
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
|
||||
export const isControlLoRAModelFieldInputInstance = buildInstanceTypeGuard(zControlLoRAModelFieldInputInstance);
|
||||
export const isControlLoRAModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<ControlLoRAModelFieldInputTemplate>('ControlLoRAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SigLipModelField
|
||||
export const zSigLipModelFieldValue = zModelIdentifierField.optional();
|
||||
const zSigLipModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSigLipModelFieldValue,
|
||||
});
|
||||
const zSigLipModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSigLipModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSigLipModelFieldValue,
|
||||
});
|
||||
export type SigLipModelFieldValue = z.infer<typeof zSigLipModelFieldValue>;
|
||||
export type SigLipModelFieldInputInstance = z.infer<typeof zSigLipModelFieldInputInstance>;
|
||||
export type SigLipModelFieldInputTemplate = z.infer<typeof zSigLipModelFieldInputTemplate>;
|
||||
export const isSigLipModelFieldInputInstance = buildInstanceTypeGuard(zSigLipModelFieldInputInstance);
|
||||
export const isSigLipModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<SigLipModelFieldInputTemplate>('SigLipModelField');
|
||||
// #endregion
|
||||
|
||||
// #region FluxReduxModelField
|
||||
export const zFluxReduxModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxReduxModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxReduxModelFieldValue,
|
||||
});
|
||||
const zFluxReduxModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxReduxModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxReduxModelFieldValue,
|
||||
});
|
||||
export type FluxReduxModelFieldValue = z.infer<typeof zFluxReduxModelFieldValue>;
|
||||
export type FluxReduxModelFieldInputInstance = z.infer<typeof zFluxReduxModelFieldInputInstance>;
|
||||
export type FluxReduxModelFieldInputTemplate = z.infer<typeof zFluxReduxModelFieldInputTemplate>;
|
||||
export const isFluxReduxModelFieldInputInstance = buildInstanceTypeGuard(zFluxReduxModelFieldInputInstance);
|
||||
export const isFluxReduxModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
|
||||
// #endregion
|
||||
|
||||
// #region Imagen3ModelField
|
||||
export const zImagen3ModelFieldValue = zModelIdentifierField.optional();
|
||||
const zImagen3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImagen3ModelFieldValue,
|
||||
});
|
||||
const zImagen3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImagen3ModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImagen3ModelFieldValue,
|
||||
});
|
||||
export type Imagen3ModelFieldValue = z.infer<typeof zImagen3ModelFieldValue>;
|
||||
export type Imagen3ModelFieldInputInstance = z.infer<typeof zImagen3ModelFieldInputInstance>;
|
||||
export type Imagen3ModelFieldInputTemplate = z.infer<typeof zImagen3ModelFieldInputTemplate>;
|
||||
export const isImagen3ModelFieldInputInstance = buildInstanceTypeGuard(zImagen3ModelFieldInputInstance);
|
||||
export const isImagen3ModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
|
||||
// #endregion
|
||||
|
||||
// #region Imagen4ModelField
|
||||
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
|
||||
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImagen4ModelFieldValue,
|
||||
});
|
||||
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImagen4ModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImagen4ModelFieldValue,
|
||||
});
|
||||
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
|
||||
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
|
||||
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
|
||||
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
|
||||
export const isImagen4ModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
|
||||
// #endregion
|
||||
|
||||
// #region FluxKontextModelField
|
||||
export const zFluxKontextModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxKontextModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxKontextModelFieldValue,
|
||||
});
|
||||
const zFluxKontextModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxKontextModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxKontextModelFieldValue,
|
||||
});
|
||||
export type FluxKontextModelFieldValue = z.infer<typeof zFluxKontextModelFieldValue>;
|
||||
export type FluxKontextModelFieldInputInstance = z.infer<typeof zFluxKontextModelFieldInputInstance>;
|
||||
export type FluxKontextModelFieldInputTemplate = z.infer<typeof zFluxKontextModelFieldInputTemplate>;
|
||||
export const isFluxKontextModelFieldInputInstance = buildInstanceTypeGuard(zFluxKontextModelFieldInputInstance);
|
||||
export const isFluxKontextModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<FluxKontextModelFieldInputTemplate>('FluxKontextModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ChatGPT4oModelField
|
||||
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
|
||||
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zChatGPT4oModelFieldValue,
|
||||
});
|
||||
const zChatGPT4oModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zChatGPT4oModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zChatGPT4oModelFieldValue,
|
||||
});
|
||||
export type ChatGPT4oModelFieldValue = z.infer<typeof zChatGPT4oModelFieldValue>;
|
||||
export type ChatGPT4oModelFieldInputInstance = z.infer<typeof zChatGPT4oModelFieldInputInstance>;
|
||||
export type ChatGPT4oModelFieldInputTemplate = z.infer<typeof zChatGPT4oModelFieldInputTemplate>;
|
||||
export const isChatGPT4oModelFieldInputInstance = buildInstanceTypeGuard(zChatGPT4oModelFieldInputInstance);
|
||||
export const isChatGPT4oModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<ChatGPT4oModelFieldInputTemplate>('ChatGPT4oModelField');
|
||||
// #endregion
|
||||
|
||||
// #region Veo3ModelField
|
||||
export const zVeo3ModelFieldValue = zModelIdentifierField.optional();
|
||||
const zVeo3ModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zVeo3ModelFieldValue,
|
||||
});
|
||||
const zVeo3ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVeo3ModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zVeo3ModelFieldValue,
|
||||
});
|
||||
export type Veo3ModelFieldValue = z.infer<typeof zVeo3ModelFieldValue>;
|
||||
export type Veo3ModelFieldInputInstance = z.infer<typeof zVeo3ModelFieldInputInstance>;
|
||||
export type Veo3ModelFieldInputTemplate = z.infer<typeof zVeo3ModelFieldInputTemplate>;
|
||||
export const isVeo3ModelFieldInputInstance = buildInstanceTypeGuard(zVeo3ModelFieldInputInstance);
|
||||
export const isVeo3ModelFieldInputTemplate = buildTemplateTypeGuard<Veo3ModelFieldInputTemplate>('Veo3ModelField');
|
||||
// #endregion
|
||||
|
||||
// #region RunwayModelField
|
||||
export const zRunwayModelFieldValue = zModelIdentifierField.optional();
|
||||
const zRunwayModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zRunwayModelFieldValue,
|
||||
});
|
||||
const zRunwayModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zRunwayModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zRunwayModelFieldValue,
|
||||
});
|
||||
export type RunwayModelFieldValue = z.infer<typeof zRunwayModelFieldValue>;
|
||||
export type RunwayModelFieldInputInstance = z.infer<typeof zRunwayModelFieldInputInstance>;
|
||||
export type RunwayModelFieldInputTemplate = z.infer<typeof zRunwayModelFieldInputTemplate>;
|
||||
export const isRunwayModelFieldInputInstance = buildInstanceTypeGuard(zRunwayModelFieldInputInstance);
|
||||
export const isRunwayModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<RunwayModelFieldInputTemplate>('RunwayModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -1261,6 +1931,31 @@ export const zStatefulFieldValue = z.union([
|
||||
zImageFieldCollectionValue,
|
||||
zBoardFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zFluxMainModelFieldValue,
|
||||
zSD3MainModelFieldValue,
|
||||
zCogView4MainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zCLIPLEmbedModelFieldValue,
|
||||
zCLIPGEmbedModelFieldValue,
|
||||
zControlLoRAModelFieldValue,
|
||||
zSigLipModelFieldValue,
|
||||
zFluxReduxModelFieldValue,
|
||||
zImagen3ModelFieldValue,
|
||||
zImagen4ModelFieldValue,
|
||||
zFluxKontextModelFieldValue,
|
||||
zChatGPT4oModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zFloatGeneratorFieldValue,
|
||||
@@ -1288,6 +1983,22 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zImageFieldCollectionInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zFluxMainModelFieldInputInstance,
|
||||
zSD3MainModelFieldInputInstance,
|
||||
zCogView4MainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zLLaVAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zT5EncoderModelFieldInputInstance,
|
||||
zFluxVAEModelFieldInputInstance,
|
||||
zCLIPEmbedModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
zFloatGeneratorFieldInputInstance,
|
||||
@@ -1314,6 +2025,33 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zImageFieldCollectionInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zFluxMainModelFieldInputTemplate,
|
||||
zSD3MainModelFieldInputTemplate,
|
||||
zCogView4MainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zLLaVAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zT5EncoderModelFieldInputTemplate,
|
||||
zFluxVAEModelFieldInputTemplate,
|
||||
zCLIPEmbedModelFieldInputTemplate,
|
||||
zCLIPLEmbedModelFieldInputTemplate,
|
||||
zCLIPGEmbedModelFieldInputTemplate,
|
||||
zControlLoRAModelFieldInputTemplate,
|
||||
zSigLipModelFieldInputTemplate,
|
||||
zFluxReduxModelFieldInputTemplate,
|
||||
zImagen3ModelFieldInputTemplate,
|
||||
zImagen4ModelFieldInputTemplate,
|
||||
zChatGPT4oModelFieldInputTemplate,
|
||||
zFluxKontextModelFieldInputTemplate,
|
||||
zVeo3ModelFieldInputTemplate,
|
||||
zRunwayModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
@@ -1341,6 +2079,19 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zImageFieldCollectionOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zFluxMainModelFieldOutputTemplate,
|
||||
zSD3MainModelFieldOutputTemplate,
|
||||
zCogView4MainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zLLaVAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
zSpandrelImageToImageModelFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
zFloatGeneratorFieldOutputTemplate,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { XYPosition as ReactFlowXYPosition } from '@xyflow/react';
|
||||
import type { WorkflowCategory, WorkflowV4, XYPosition } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowCategory, WorkflowV3, XYPosition } from 'features/nodes/types/workflow';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Equals, Extends } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -14,5 +14,5 @@ import { describe, test } from 'vitest';
|
||||
describe('Workflow types', () => {
|
||||
test('XYPosition', () => assert<Equals<XYPosition, ReactFlowXYPosition>>());
|
||||
test('WorkflowCategory', () => assert<Equals<WorkflowCategory, S['WorkflowCategory']>>());
|
||||
test('WorkflowV4', () => assert<Extends<SetRequired<WorkflowV4, 'id'>, S['Workflow']>>());
|
||||
test('WorkflowV3', () => assert<Extends<SetRequired<WorkflowV3, 'id'>, S['Workflow']>>());
|
||||
});
|
||||
|
||||
@@ -57,41 +57,6 @@ const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
|
||||
const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]);
|
||||
// #endregion
|
||||
|
||||
export type WorkflowOutputField = {
|
||||
kind: 'output';
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
userLabel: string | null;
|
||||
node_id: string;
|
||||
field_name: string;
|
||||
user_label: string | null;
|
||||
};
|
||||
|
||||
const zWorkflowOutputField = z
|
||||
.object({
|
||||
kind: z.literal('output').default('output'),
|
||||
nodeId: z.string().trim().min(1).optional(),
|
||||
node_id: z.string().trim().min(1),
|
||||
fieldName: z.string().trim().min(1).optional(),
|
||||
field_name: z.string().trim().min(1),
|
||||
userLabel: z.string().nullable().optional(),
|
||||
user_label: z.string().nullable().optional(),
|
||||
})
|
||||
.transform((field) => {
|
||||
const nodeId = field.nodeId ?? field.node_id;
|
||||
const fieldName = field.fieldName ?? field.field_name;
|
||||
const userLabel = field.userLabel ?? field.user_label ?? null;
|
||||
return {
|
||||
kind: 'output',
|
||||
nodeId,
|
||||
fieldName,
|
||||
userLabel,
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
user_label: userLabel,
|
||||
} as WorkflowOutputField;
|
||||
});
|
||||
|
||||
// #region Workflow Builder
|
||||
const zElementId = z.string().trim().min(1);
|
||||
export type ElementId = z.infer<typeof zElementId>;
|
||||
@@ -398,7 +363,7 @@ const zValidatedBuilderForm = zBuilderForm
|
||||
//# endregion
|
||||
|
||||
// #region Workflow
|
||||
const workflowBaseShape = {
|
||||
export const zWorkflowV3 = z.object({
|
||||
id: z.string().min(1).optional(),
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
@@ -410,27 +375,13 @@ const workflowBaseShape = {
|
||||
nodes: z.array(zWorkflowNode),
|
||||
edges: z.array(zWorkflowEdge),
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
// Use the validated form schema!
|
||||
form: zValidatedBuilderForm,
|
||||
is_published: z.boolean().nullish(),
|
||||
} as const;
|
||||
|
||||
export const zWorkflowV3 = z.object({
|
||||
...workflowBaseShape,
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('3.0.0'),
|
||||
}),
|
||||
// Use the validated form schema!
|
||||
form: zValidatedBuilderForm,
|
||||
is_published: z.boolean().nullish(),
|
||||
});
|
||||
export type WorkflowV3 = z.infer<typeof zWorkflowV3>;
|
||||
|
||||
export const zWorkflowV4 = z.object({
|
||||
...workflowBaseShape,
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('4.0.0'),
|
||||
}),
|
||||
output_fields: z.array(zWorkflowOutputField).max(1).default([]),
|
||||
});
|
||||
export type WorkflowV4 = z.infer<typeof zWorkflowV4>;
|
||||
// #endregion
|
||||
|
||||
@@ -7,15 +7,12 @@ import type { Templates } from 'features/nodes/store/types';
|
||||
import type { BoardField } from 'features/nodes/types/common';
|
||||
import type { BoardFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isBatchNodeType, isGeneratorNodeType } from 'features/nodes/types/invocation';
|
||||
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
type BoardFieldResolver = (field: BoardFieldInputInstance) => BoardField | undefined;
|
||||
|
||||
const getBoardField = (field: BoardFieldInputInstance, state: RootState): BoardField | undefined => {
|
||||
// Translate the UI value to the graph value. See note in BoardFieldInputComponent for more info.
|
||||
const { value } = field;
|
||||
@@ -37,70 +34,19 @@ const getBoardField = (field: BoardFieldInputInstance, state: RootState): BoardF
|
||||
return value;
|
||||
};
|
||||
|
||||
const defaultBoardFieldResolver: BoardFieldResolver = (field) => {
|
||||
const { value } = field;
|
||||
if (!value || value === 'none' || value === 'auto') {
|
||||
return undefined;
|
||||
}
|
||||
return value;
|
||||
};
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
export const buildNodesGraph = (state: RootState, templates: Templates): Required<Graph> => {
|
||||
const { nodes, edges } = selectNodesSlice(state);
|
||||
|
||||
type NodeLike = {
|
||||
id: string;
|
||||
type?: string;
|
||||
data?: unknown;
|
||||
};
|
||||
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter(isExecutableNode);
|
||||
|
||||
type InvocationNodeLike = NodeLike & {
|
||||
data: InvocationNodeData;
|
||||
};
|
||||
|
||||
const isInvocationNodeLike = (node: NodeLike): node is InvocationNodeLike => {
|
||||
if (node.type !== 'invocation') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!node.data || typeof node.data !== 'object') {
|
||||
return false;
|
||||
}
|
||||
|
||||
const data = node.data as Partial<InvocationNodeData>;
|
||||
return Boolean(data.inputs && data.type && data.useCache !== undefined && data.isIntermediate !== undefined);
|
||||
};
|
||||
|
||||
type EdgeLike = {
|
||||
type?: string;
|
||||
source: string;
|
||||
target: string;
|
||||
sourceHandle?: string | null;
|
||||
targetHandle?: string | null;
|
||||
};
|
||||
|
||||
type BuildInvocationGraphArgs = {
|
||||
nodes: NodeLike[];
|
||||
edges: EdgeLike[];
|
||||
templates: Templates;
|
||||
resolveBoardField?: BoardFieldResolver;
|
||||
graphId?: string;
|
||||
};
|
||||
|
||||
export const buildInvocationGraph = ({
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
resolveBoardField = defaultBoardFieldResolver,
|
||||
graphId,
|
||||
}: BuildInvocationGraphArgs): Required<Graph> => {
|
||||
const invocationNodes = nodes.filter(isInvocationNodeLike);
|
||||
|
||||
const executableNodes = invocationNodes.filter((node) => {
|
||||
const nodeType = node.data.type;
|
||||
return !isBatchNodeType(nodeType) && !isGeneratorNodeType(nodeType);
|
||||
});
|
||||
|
||||
const parsedNodes = executableNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
|
||||
const { id, data } = node;
|
||||
const { type } = data;
|
||||
const { type, inputs, isIntermediate } = data;
|
||||
|
||||
const nodeTemplate = templates[type];
|
||||
if (!nodeTemplate) {
|
||||
@@ -108,8 +54,9 @@ export const buildInvocationGraph = ({
|
||||
return nodesAccumulator;
|
||||
}
|
||||
|
||||
// Transform each node's inputs to simple key-value pairs
|
||||
const transformedInputs = reduce(
|
||||
data.inputs,
|
||||
inputs,
|
||||
(inputsAccumulator, input, name) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[name];
|
||||
if (!fieldTemplate) {
|
||||
@@ -117,7 +64,7 @@ export const buildInvocationGraph = ({
|
||||
return inputsAccumulator;
|
||||
}
|
||||
if (isBoardFieldInputTemplate(fieldTemplate) && isBoardFieldInputInstance(input)) {
|
||||
inputsAccumulator[name] = resolveBoardField(input);
|
||||
inputsAccumulator[name] = getBoardField(input, state);
|
||||
} else {
|
||||
inputsAccumulator[name] = input.value;
|
||||
}
|
||||
@@ -127,15 +74,18 @@ export const buildInvocationGraph = ({
|
||||
{} as Record<Exclude<string, 'id' | 'type'>, unknown>
|
||||
);
|
||||
|
||||
transformedInputs['use_cache'] = data.useCache;
|
||||
// add reserved use_cache
|
||||
transformedInputs['use_cache'] = node.data.useCache;
|
||||
|
||||
// Build this specific node
|
||||
const graphNode = {
|
||||
type,
|
||||
id,
|
||||
...transformedInputs,
|
||||
is_intermediate: data.isIntermediate,
|
||||
} as AnyInvocation;
|
||||
is_intermediate: isIntermediate,
|
||||
};
|
||||
|
||||
// Add it to the nodes object
|
||||
Object.assign(nodesAccumulator, {
|
||||
[id]: graphNode,
|
||||
});
|
||||
@@ -143,59 +93,58 @@ export const buildInvocationGraph = ({
|
||||
return nodesAccumulator;
|
||||
}, {});
|
||||
|
||||
const executableNodeIds = executableNodes.map(({ id }) => id);
|
||||
const filteredNodeIds = filteredNodes.map(({ id }) => id);
|
||||
|
||||
const parsedEdges = edges
|
||||
// skip out the "dummy" edges between collapsed nodes
|
||||
const filteredEdges = edges
|
||||
.filter((edge) => edge.type !== 'collapsed')
|
||||
.filter((edge) => executableNodeIds.includes(edge.source) && executableNodeIds.includes(edge.target))
|
||||
.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
|
||||
return edgesAccumulator;
|
||||
}
|
||||
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle,
|
||||
},
|
||||
});
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
|
||||
return edgesAccumulator;
|
||||
}, []);
|
||||
|
||||
parsedEdges.forEach((edge) => {
|
||||
const destinationNode = parsedNodes[edge.destination.node_id];
|
||||
if (!destinationNode) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Format the edges and add to the edges array
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle,
|
||||
},
|
||||
});
|
||||
|
||||
return edgesAccumulator;
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Omit all inputs that have edges connected.
|
||||
*
|
||||
* Fixes edge case where the user has connected an input, but also provided an invalid explicit,
|
||||
* value.
|
||||
*
|
||||
* In this edge case, pydantic will invalidate the node based on the invalid explicit value,
|
||||
* even though the actual value that will be used comes from the connection.
|
||||
*/
|
||||
parsedEdges.forEach((edge) => {
|
||||
const destination_node = parsedNodes[edge.destination.node_id];
|
||||
const field = edge.destination.field;
|
||||
parsedNodes[edge.destination.node_id] = omit(destinationNode, field) as AnyInvocation;
|
||||
parsedNodes[edge.destination.node_id] = omit(destination_node, field) as AnyInvocation;
|
||||
});
|
||||
|
||||
return {
|
||||
id: graphId ?? uuidv4(),
|
||||
// Assemble!
|
||||
const graph = {
|
||||
id: uuidv4(),
|
||||
nodes: parsedNodes,
|
||||
edges: parsedEdges,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
export const buildNodesGraph = (state: RootState, templates: Templates): Required<Graph> => {
|
||||
const { nodes, edges } = selectNodesSlice(state);
|
||||
|
||||
return buildInvocationGraph({
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
resolveBoardField: (field) => getBoardField(field, state),
|
||||
});
|
||||
return graph;
|
||||
};
|
||||
|
||||
@@ -205,7 +205,7 @@ export const getInfill = (
|
||||
assert(false, 'Unknown infill method');
|
||||
};
|
||||
|
||||
export const CANVAS_OUTPUT_PREFIX = 'canvas_output';
|
||||
const CANVAS_OUTPUT_PREFIX = 'canvas_output';
|
||||
|
||||
export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNodes>) => {
|
||||
return (
|
||||
|
||||
@@ -9,9 +9,36 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
FloatField: 0,
|
||||
ImageField: undefined,
|
||||
IntegerField: 0,
|
||||
IPAdapterModelField: undefined,
|
||||
LoRAModelField: undefined,
|
||||
LLaVAModelField: undefined,
|
||||
ModelIdentifierField: undefined,
|
||||
MainModelField: undefined,
|
||||
SchedulerField: 'dpmpp_3m_k',
|
||||
SDXLMainModelField: undefined,
|
||||
FluxMainModelField: undefined,
|
||||
SD3MainModelField: undefined,
|
||||
CogView4MainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
SpandrelImageToImageModelField: undefined,
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
T5EncoderModelField: undefined,
|
||||
FluxVAEModelField: undefined,
|
||||
CLIPEmbedModelField: undefined,
|
||||
CLIPLEmbedModelField: undefined,
|
||||
CLIPGEmbedModelField: undefined,
|
||||
ControlLoRAModelField: undefined,
|
||||
SigLipModelField: undefined,
|
||||
FluxReduxModelField: undefined,
|
||||
Imagen3ModelField: undefined,
|
||||
Imagen4ModelField: undefined,
|
||||
ChatGPT4oModelField: undefined,
|
||||
FluxKontextModelField: undefined,
|
||||
Veo3ModelField: undefined,
|
||||
RunwayModelField: undefined,
|
||||
FloatGeneratorField: undefined,
|
||||
IntegerGeneratorField: undefined,
|
||||
StringGeneratorField: undefined,
|
||||
|
||||
@@ -3,26 +3,53 @@ import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
ChatGPT4oModelFieldInputTemplate,
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
CLIPGEmbedModelFieldInputTemplate,
|
||||
CLIPLEmbedModelFieldInputTemplate,
|
||||
CogView4MainModelFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlLoRAModelFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
EnumFieldInputTemplate,
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldInputTemplate,
|
||||
FloatGeneratorFieldInputTemplate,
|
||||
FluxKontextModelFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxReduxModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
ImageGeneratorFieldInputTemplate,
|
||||
Imagen3ModelFieldInputTemplate,
|
||||
Imagen4ModelFieldInputTemplate,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
LLaVAModelFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
MainModelFieldInputTemplate,
|
||||
ModelIdentifierFieldInputTemplate,
|
||||
RunwayModelFieldInputTemplate,
|
||||
SchedulerFieldInputTemplate,
|
||||
SD3MainModelFieldInputTemplate,
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SigLipModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
StringGeneratorFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
T5EncoderModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
Veo3ModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
getFloatGeneratorArithmeticSequenceDefaults,
|
||||
@@ -275,6 +302,373 @@ const buildModelIdentifierFieldInputTemplate: FieldInputTemplateBuilder<ModelIde
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SDXLMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SD3MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCogView4MainModelFieldInputTemplate: FieldInputTemplateBuilder<CogView4MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CogView4MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SDXLRefinerModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: VAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5EncoderModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: T5EncoderModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPLEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPLEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPLEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPGEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<ControlLoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ControlLoRAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLLaVAModelFieldInputTemplate: FieldInputTemplateBuilder<LLaVAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: LLaVAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxVAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: LoRAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<ControlNetModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ControlNetModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<IPAdapterModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IPAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapterModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: T2IAdapterModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SpandrelImageToImageModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, fieldType }) => {
|
||||
const template: SpandrelImageToImageModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSigLipModelFieldInputTemplate: FieldInputTemplateBuilder<SigLipModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SigLipModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxReduxModelFieldInputTemplate: FieldInputTemplateBuilder<FluxReduxModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxReduxModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagen3ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen3ModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: Imagen3ModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4ModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: Imagen4ModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxKontextModelFieldInputTemplate: FieldInputTemplateBuilder<FluxKontextModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxKontextModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVeo3ModelFieldInputTemplate: FieldInputTemplateBuilder<Veo3ModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: Veo3ModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRunwayModelFieldInputTemplate: FieldInputTemplateBuilder<RunwayModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: RunwayModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: ChatGPT4oModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -449,41 +843,56 @@ const buildImageGeneratorFieldInputTemplate: FieldInputTemplateBuilder<ImageGene
|
||||
return template;
|
||||
};
|
||||
|
||||
const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
BoardField: buildBoardFieldInputTemplate,
|
||||
BooleanField: buildBooleanFieldInputTemplate,
|
||||
ColorField: buildColorFieldInputTemplate,
|
||||
ControlNetModelField: buildControlNetModelFieldInputTemplate,
|
||||
EnumField: buildEnumFieldInputTemplate,
|
||||
FloatField: buildFloatFieldInputTemplate,
|
||||
ImageField: buildImageFieldInputTemplate,
|
||||
IntegerField: buildIntegerFieldInputTemplate,
|
||||
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
|
||||
LoRAModelField: buildLoRAModelFieldInputTemplate,
|
||||
LLaVAModelField: buildLLaVAModelFieldInputTemplate,
|
||||
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
|
||||
MainModelField: buildMainModelFieldInputTemplate,
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
|
||||
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
|
||||
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
|
||||
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
|
||||
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
|
||||
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
|
||||
SigLipModelField: buildSigLipModelFieldInputTemplate,
|
||||
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
|
||||
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
|
||||
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
|
||||
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
|
||||
FluxKontextModelField: buildFluxKontextModelFieldInputTemplate,
|
||||
Veo3ModelField: buildVeo3ModelFieldInputTemplate,
|
||||
RunwayModelField: buildRunwayModelFieldInputTemplate,
|
||||
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
|
||||
ImageGeneratorField: buildImageGeneratorFieldInputTemplate,
|
||||
};
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
fieldName: string,
|
||||
fieldType: FieldType
|
||||
): FieldInputTemplate => {
|
||||
const {
|
||||
input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
ui_order,
|
||||
ui_choice_labels,
|
||||
orig_required: required,
|
||||
ui_model_base,
|
||||
ui_model_type,
|
||||
ui_model_variant,
|
||||
ui_model_format,
|
||||
} = fieldSchema;
|
||||
const { input, ui_hidden, ui_component, ui_type, ui_order, ui_choice_labels, orig_required: required } = fieldSchema;
|
||||
|
||||
// This is the base field template that is common to all fields. The builder function will add all other
|
||||
// properties to this template.
|
||||
@@ -499,10 +908,6 @@ export const buildFieldInputTemplate = (
|
||||
ui_type,
|
||||
ui_order,
|
||||
ui_choice_labels,
|
||||
ui_model_base,
|
||||
ui_model_type,
|
||||
ui_model_variant,
|
||||
ui_model_format,
|
||||
};
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
|
||||
@@ -6,8 +6,8 @@ import { pick } from 'es-toolkit/compat';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import i18n from 'i18n';
|
||||
import { useCallback } from 'react';
|
||||
import { fromZodError } from 'zod-validation-error/v4';
|
||||
@@ -23,27 +23,21 @@ const workflowKeys = [
|
||||
'tags',
|
||||
'notes',
|
||||
'exposedFields',
|
||||
'output_fields',
|
||||
'meta',
|
||||
'id',
|
||||
'form',
|
||||
] satisfies (keyof WorkflowV4)[];
|
||||
] satisfies (keyof WorkflowV3)[];
|
||||
|
||||
export const buildWorkflowFast = (nodesState: NodesState): WorkflowV4 => {
|
||||
export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => {
|
||||
const { nodes, edges, ...rest } = nodesState;
|
||||
const clonedWorkflow = pick(rest, workflowKeys);
|
||||
|
||||
const newWorkflow: WorkflowV4 = {
|
||||
const newWorkflow: WorkflowV3 = {
|
||||
...clonedWorkflow,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
};
|
||||
|
||||
newWorkflow.meta = {
|
||||
...newWorkflow.meta,
|
||||
version: '4.0.0',
|
||||
};
|
||||
|
||||
for (const node of nodes) {
|
||||
if (isInvocationNode(node) && node.type) {
|
||||
const { id, type, data, position } = node;
|
||||
@@ -67,12 +61,12 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV4 => {
|
||||
return deepClone(newWorkflow);
|
||||
};
|
||||
|
||||
export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV4 | null => {
|
||||
export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV3 | null => {
|
||||
// builds what really, really should be a valid workflow
|
||||
const workflowToValidate = buildWorkflowFast(nodesState);
|
||||
|
||||
// but bc we are storing this in the DB, let's be extra sure
|
||||
const result = zWorkflowV4.safeParse(workflowToValidate);
|
||||
const result = zWorkflowV3.safeParse(workflowToValidate);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
@@ -86,7 +80,7 @@ export const buildWorkflowWithValidation = (nodesState: NodesState): WorkflowV4
|
||||
return result.data;
|
||||
};
|
||||
|
||||
export const useBuildWorkflowFast = (): (() => WorkflowV4) => {
|
||||
export const useBuildWorkflowFast = (): (() => WorkflowV3) => {
|
||||
const store = useAppStore();
|
||||
const buildWorkflow = useCallback(() => {
|
||||
const nodesState = selectNodesSlice(store.getState());
|
||||
|
||||
@@ -4,7 +4,7 @@ import { forEach } from 'es-toolkit/compat';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { getDefaultForm } from 'features/nodes/types/workflow';
|
||||
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
@@ -19,24 +19,23 @@ const log = logger('workflows');
|
||||
* @param autoLayout Whether to auto-layout the nodes using `dagre`. If false, nodes will be simply stacked on top of one another with an offset.
|
||||
* @returns The workflow.
|
||||
*/
|
||||
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV4 => {
|
||||
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
|
||||
const templates = $templates.get();
|
||||
|
||||
// Initialize the workflow
|
||||
const workflow: WorkflowV4 = {
|
||||
const workflow: WorkflowV3 = {
|
||||
name: '',
|
||||
author: '',
|
||||
contact: '',
|
||||
description: '',
|
||||
meta: {
|
||||
category: 'user',
|
||||
version: '4.0.0',
|
||||
version: '3.0.0',
|
||||
},
|
||||
notes: '',
|
||||
tags: '',
|
||||
version: '',
|
||||
exposedFields: [],
|
||||
output_fields: [],
|
||||
edges: [],
|
||||
nodes: [],
|
||||
form: getDefaultForm(),
|
||||
|
||||
@@ -10,8 +10,8 @@ import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1';
|
||||
import type { StatelessFieldType } from 'features/nodes/types/v2/field';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import type { WorkflowOutputField, WorkflowV3, WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3, zWorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { t } from 'i18next';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -76,72 +76,12 @@ const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
|
||||
return zWorkflowV3.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
const normalizeOutputField = (field: unknown): WorkflowOutputField | null => {
|
||||
if (!field || typeof field !== 'object') {
|
||||
return null;
|
||||
}
|
||||
const maybeField = field as Record<string, unknown>;
|
||||
const nodeId =
|
||||
typeof maybeField.nodeId === 'string'
|
||||
? maybeField.nodeId
|
||||
: typeof maybeField.node_id === 'string'
|
||||
? maybeField.node_id
|
||||
: null;
|
||||
const fieldName =
|
||||
typeof maybeField.fieldName === 'string'
|
||||
? maybeField.fieldName
|
||||
: typeof maybeField.field_name === 'string'
|
||||
? maybeField.field_name
|
||||
: null;
|
||||
if (!nodeId || !fieldName) {
|
||||
return null;
|
||||
}
|
||||
const userLabel =
|
||||
typeof maybeField.userLabel === 'string'
|
||||
? maybeField.userLabel
|
||||
: typeof maybeField.userLabel === 'object'
|
||||
? null
|
||||
: typeof maybeField.user_label === 'string'
|
||||
? maybeField.user_label
|
||||
: null;
|
||||
|
||||
return {
|
||||
kind: 'output',
|
||||
nodeId,
|
||||
fieldName,
|
||||
userLabel,
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
user_label: userLabel,
|
||||
} satisfies WorkflowOutputField;
|
||||
};
|
||||
|
||||
const migrateV3toV4 = (workflowToMigrate: WorkflowV3 | Record<string, unknown>): WorkflowV4 => {
|
||||
const rawOutputs = Array.isArray((workflowToMigrate as Record<string, unknown>).output_fields)
|
||||
? ((workflowToMigrate as Record<string, unknown>).output_fields as unknown[])
|
||||
: [];
|
||||
|
||||
const normalizedOutputs = rawOutputs
|
||||
.map(normalizeOutputField)
|
||||
.filter((field): field is WorkflowOutputField => field !== null)
|
||||
.slice(0, 1);
|
||||
|
||||
const migrated = {
|
||||
...(workflowToMigrate as Record<string, unknown>),
|
||||
output_fields: normalizedOutputs,
|
||||
} as WorkflowV4;
|
||||
|
||||
migrated.meta.version = '4.0.0';
|
||||
|
||||
return zWorkflowV4.parse(migrated);
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses a workflow and migrates it to the latest version if necessary.
|
||||
*
|
||||
* This function will return a new workflow object, so the original workflow is not modified.
|
||||
*/
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV4 => {
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
|
||||
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
|
||||
|
||||
if (!workflowVersionResult.success) {
|
||||
@@ -160,13 +100,8 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV4 => {
|
||||
workflow = migrateV2toV3(v2);
|
||||
}
|
||||
|
||||
if (get(workflow, 'meta.version') === '3.0.0') {
|
||||
const v3 = zWorkflowV3.parse(workflow);
|
||||
workflow = migrateV3toV4(v3);
|
||||
}
|
||||
|
||||
// We should now have a V4 workflow
|
||||
const migratedWorkflow = zWorkflowV4.parse(workflow);
|
||||
// We should now have a V3 workflow
|
||||
const migratedWorkflow = zWorkflowV3.parse(workflow);
|
||||
|
||||
return migratedWorkflow;
|
||||
};
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { get } from 'es-toolkit/compat';
|
||||
import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { getDefaultForm } from 'features/nodes/types/workflow';
|
||||
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
//TODO(psyche): Test workflow validation for form builder fields
|
||||
describe('validateWorkflow', () => {
|
||||
const getWorkflow = (): WorkflowV4 => ({
|
||||
const getWorkflow = (): WorkflowV3 => ({
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
@@ -16,9 +16,8 @@ describe('validateWorkflow', () => {
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
output_fields: [],
|
||||
form: getDefaultForm(),
|
||||
meta: { version: '4.0.0', category: 'user' },
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
nodes: [
|
||||
{
|
||||
id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad',
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
isModelFieldType,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import {
|
||||
buildNodeFieldElement,
|
||||
getDefaultForm,
|
||||
@@ -36,7 +36,7 @@ type ValidateWorkflowArgs = {
|
||||
};
|
||||
|
||||
type ValidateWorkflowResult = {
|
||||
workflow: WorkflowV4;
|
||||
workflow: WorkflowV3;
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
@@ -233,65 +233,6 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
// Remove invalid edges
|
||||
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
|
||||
|
||||
if (_workflow.output_fields.length > 0) {
|
||||
const validOutputs: typeof _workflow.output_fields = [];
|
||||
for (const output of _workflow.output_fields) {
|
||||
const node = nodes.find(({ id }) => id === output.nodeId);
|
||||
if (!node) {
|
||||
warnings.push({
|
||||
message: t('workflows.builder.outputFieldMissingNode', {
|
||||
defaultValue: 'Removed workflow output referencing a missing node.',
|
||||
}),
|
||||
data: parseify(output),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (!isWorkflowInvocationNode(node)) {
|
||||
warnings.push({
|
||||
message: t('workflows.builder.outputFieldInvalidNode', {
|
||||
defaultValue: 'Removed workflow output because the selected node is not an invocation node.',
|
||||
}),
|
||||
data: parseify({ output, node }),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
warnings.push({
|
||||
message: t('workflows.builder.outputFieldMissingTemplate', {
|
||||
defaultValue: 'Removed workflow output because the node template is missing.',
|
||||
}),
|
||||
data: parseify({ output, node }),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
const field = template.outputs[output.fieldName];
|
||||
if (!field) {
|
||||
warnings.push({
|
||||
message: t('workflows.builder.outputFieldMissingHandle', {
|
||||
defaultValue: 'Removed workflow output because the node no longer exposes that field.',
|
||||
}),
|
||||
data: parseify({ output, node, template }),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (!(field.type.name === 'ImageField' && field.type.cardinality !== 'COLLECTION')) {
|
||||
warnings.push({
|
||||
message: t('workflows.builder.outputFieldMustBeImage', {
|
||||
defaultValue: 'Removed workflow output because it is not an image field.',
|
||||
}),
|
||||
data: parseify({ output, node, field }),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
validOutputs.push(output);
|
||||
if (validOutputs.length >= 1) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_workflow.output_fields = validOutputs.slice(0, 1);
|
||||
}
|
||||
|
||||
// Migrated exposed fields to form elements if they exist and the form does not
|
||||
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!
|
||||
if (_workflow.exposedFields.length > 0 && getIsFormEmpty(_workflow.form)) {
|
||||
|
||||
@@ -16,7 +16,6 @@ import { SelectObject } from 'features/controlLayers/components/SelectObject/Sel
|
||||
import { StagingAreaContextProvider } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
|
||||
import { Transform } from 'features/controlLayers/components/Transform/Transform';
|
||||
import { TriggerWorkflow } from 'features/controlLayers/components/TriggerWorkflow/TriggerWorkflow';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
@@ -113,7 +112,6 @@ export const CanvasWorkspacePanel = memo(() => {
|
||||
<Flex position="absolute" bottom={4}>
|
||||
<CanvasManagerProviderGate>
|
||||
<Filter />
|
||||
<TriggerWorkflow />
|
||||
<Transform />
|
||||
<SelectObject />
|
||||
</CanvasManagerProviderGate>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
|
||||
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useLoadWorkflowFromFile } from 'features/workflowLibrary/hooks/useLoadWorkflowFromFile';
|
||||
import { useLoadWorkflowFromImage } from 'features/workflowLibrary/hooks/useLoadWorkflowFromImage';
|
||||
import { useLoadWorkflowFromLibrary } from 'features/workflowLibrary/hooks/useLoadWorkflowFromLibrary';
|
||||
@@ -13,7 +13,7 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type LoadWorkflowOptions = {
|
||||
onSuccess?: (workflow: WorkflowV4) => void;
|
||||
onSuccess?: (workflow: WorkflowV3) => void;
|
||||
onError?: () => void;
|
||||
onCompleted?: () => void;
|
||||
};
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $workflowLibraryCategoriesOptions } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { isDraftWorkflow, useCreateLibraryWorkflow } from 'features/workflowLibrary/hooks/useCreateNewWorkflow';
|
||||
import { t } from 'i18next';
|
||||
import { atom, computed } from 'nanostores';
|
||||
@@ -28,7 +28,7 @@ import { assert } from 'tsafe';
|
||||
*
|
||||
* This state is used to determine whether or not the modal is open.
|
||||
*/
|
||||
const $workflowToSave = atom<WorkflowV4 | null>(null);
|
||||
const $workflowToSave = atom<WorkflowV3 | null>(null);
|
||||
|
||||
/**
|
||||
* Whether or not the modal is open. It is open if there is a workflow to save.
|
||||
@@ -40,7 +40,7 @@ const $workflowToSave = atom<WorkflowV4 | null>(null);
|
||||
*/
|
||||
const $isOpen = computed($workflowToSave, (val) => val !== null);
|
||||
|
||||
const getInitialName = (workflow: WorkflowV4): string => {
|
||||
const getInitialName = (workflow: WorkflowV3): string => {
|
||||
if (!workflow.id) {
|
||||
// If the workflow has no ID, that means it's a new workflow that has never been saved to the server. In this case,
|
||||
// we should use whatever the user has entered in the workflow name field.
|
||||
@@ -60,7 +60,7 @@ const getInitialName = (workflow: WorkflowV4): string => {
|
||||
* The workflow object is deep cloned to prevent any changes to the original workflow object.
|
||||
* @param workflow The workflow to save as a new workflow.
|
||||
*/
|
||||
export const saveWorkflowAs = (workflow: WorkflowV4) => {
|
||||
export const saveWorkflowAs = (workflow: WorkflowV3) => {
|
||||
$workflowToSave.set(deepClone(workflow));
|
||||
};
|
||||
|
||||
@@ -82,7 +82,7 @@ export const SaveWorkflowAsDialog = () => {
|
||||
);
|
||||
};
|
||||
|
||||
const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV4; cancelRef: RefObject<HTMLButtonElement> }) => {
|
||||
const Content = memo(({ workflow, cancelRef }: { workflow: WorkflowV3; cancelRef: RefObject<HTMLButtonElement> }) => {
|
||||
const workflowCategories = useStore($workflowLibraryCategoriesOptions);
|
||||
const [name, setName] = useState(() => {
|
||||
if (workflow) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
workflowIDChanged,
|
||||
workflowNameChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useGetFormFieldInitialValues } from 'features/workflowLibrary/hooks/useGetFormInitialValues';
|
||||
import { newWorkflowSaved } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback, useRef } from 'react';
|
||||
@@ -19,12 +19,12 @@ import type { SetFieldType } from 'type-fest';
|
||||
* A draft workflow is a workflow that is has not been saved yet. It does not have an id and is not in the default category.
|
||||
*/
|
||||
type DraftWorkflow = SetFieldType<
|
||||
SetFieldType<WorkflowV4, 'id', undefined>,
|
||||
SetFieldType<WorkflowV3, 'id', undefined>,
|
||||
'meta',
|
||||
SetFieldType<WorkflowV4['meta'], 'category', Exclude<WorkflowV4['meta']['category'], 'default'>>
|
||||
SetFieldType<WorkflowV3['meta'], 'category', Exclude<WorkflowV3['meta']['category'], 'default'>>
|
||||
>;
|
||||
|
||||
export const isDraftWorkflow = (workflow: WorkflowV4): workflow is DraftWorkflow =>
|
||||
export const isDraftWorkflow = (workflow: WorkflowV3): workflow is DraftWorkflow =>
|
||||
!workflow.id && workflow.meta.category !== 'default';
|
||||
|
||||
type CreateLibraryWorkflowArg = {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
|
||||
import { workflowLoadedFromFile } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
@@ -17,12 +17,12 @@ export const useLoadWorkflowFromFile = () => {
|
||||
(
|
||||
file: File,
|
||||
options: {
|
||||
onSuccess?: (workflow: WorkflowV4) => void;
|
||||
onSuccess?: (workflow: WorkflowV3) => void;
|
||||
onError?: () => void;
|
||||
onCompleted?: () => void;
|
||||
} = {}
|
||||
) => {
|
||||
return new Promise<WorkflowV4 | void>((resolve, reject) => {
|
||||
return new Promise<WorkflowV3 | void>((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const rawJSON = reader.result;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
|
||||
@@ -21,7 +21,7 @@ export const useLoadWorkflowFromImage = () => {
|
||||
async (
|
||||
imageName: string,
|
||||
options: {
|
||||
onSuccess?: (workflow: WorkflowV4) => void;
|
||||
onSuccess?: (workflow: WorkflowV3) => void;
|
||||
onError?: () => void;
|
||||
onCompleted?: () => void;
|
||||
} = {}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useToast } from '@invoke-ai/ui-library';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -21,7 +21,7 @@ export const useLoadWorkflowFromLibrary = () => {
|
||||
async (
|
||||
workflowId: string,
|
||||
options: {
|
||||
onSuccess?: (workflow: WorkflowV4) => void;
|
||||
onSuccess?: (workflow: WorkflowV3) => void;
|
||||
onError?: () => void;
|
||||
onCompleted?: () => void;
|
||||
} = {}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useValidateAndLoadWorkflow } from 'features/workflowLibrary/hooks/useValidateAndLoadWorkflow';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
@@ -14,7 +14,7 @@ export const useLoadWorkflowFromObject = () => {
|
||||
async (
|
||||
unvalidatedWorkflow: unknown,
|
||||
options: {
|
||||
onSuccess?: (workflow: WorkflowV4) => void;
|
||||
onSuccess?: (workflow: WorkflowV3) => void;
|
||||
onError?: () => void;
|
||||
onCompleted?: () => void;
|
||||
} = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { ToastId } from '@invoke-ai/ui-library';
|
||||
import { useToast } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { formFieldInitialValuesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { useGetFormFieldInitialValues } from 'features/workflowLibrary/hooks/useGetFormInitialValues';
|
||||
import { workflowUpdated } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback, useRef } from 'react';
|
||||
@@ -14,12 +14,12 @@ import type { SetFieldType, SetRequired } from 'type-fest';
|
||||
* A library workflow is a workflow that is already saved in the library. It has an id and is not in the default category.
|
||||
*/
|
||||
type LibraryWorkflow = SetFieldType<
|
||||
SetRequired<WorkflowV4, 'id'>,
|
||||
SetRequired<WorkflowV3, 'id'>,
|
||||
'meta',
|
||||
SetFieldType<WorkflowV4['meta'], 'category', Exclude<WorkflowV4['meta']['category'], 'default'>>
|
||||
SetFieldType<WorkflowV3['meta'], 'category', Exclude<WorkflowV3['meta']['category'], 'default'>>
|
||||
>;
|
||||
|
||||
export const isLibraryWorkflow = (workflow: WorkflowV4): workflow is LibraryWorkflow =>
|
||||
export const isLibraryWorkflow = (workflow: WorkflowV3): workflow is LibraryWorkflow =>
|
||||
!!workflow.id && workflow.meta.category !== 'default';
|
||||
|
||||
type UseSaveLibraryWorkflowReturn = {
|
||||
|
||||
@@ -6,7 +6,7 @@ import { $templates, workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import type { WorkflowV4 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
@@ -48,7 +48,7 @@ export const useValidateAndLoadWorkflow = () => {
|
||||
async (
|
||||
unvalidatedWorkflow: unknown,
|
||||
origin: 'file' | 'image' | 'object' | 'library'
|
||||
): Promise<WorkflowV4 | null> => {
|
||||
): Promise<WorkflowV3 | null> => {
|
||||
try {
|
||||
const templates = $templates.get();
|
||||
const { workflow, warnings } = await validateWorkflow({
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import { atom } from 'nanostores';
|
||||
import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
|
||||
|
||||
type WorkflowLibraryIntent =
|
||||
| { mode: 'browse' }
|
||||
| { mode: 'trigger-workflow'; onSelect: (workflow: WorkflowRecordListItemWithThumbnailDTO) => void };
|
||||
|
||||
export const $workflowLibraryIntent = atom<WorkflowLibraryIntent>({ mode: 'browse' });
|
||||
|
||||
export const setWorkflowLibraryBrowseIntent = () => {
|
||||
$workflowLibraryIntent.set({ mode: 'browse' });
|
||||
};
|
||||
|
||||
export const setWorkflowLibraryTriggerIntent = (
|
||||
onSelect: (workflow: WorkflowRecordListItemWithThumbnailDTO) => void
|
||||
) => {
|
||||
$workflowLibraryIntent.set({ mode: 'trigger-workflow', onSelect });
|
||||
};
|
||||
@@ -12,25 +12,34 @@ import {
|
||||
isChatGPT4oModelConfig,
|
||||
isCLIPEmbedModelConfig,
|
||||
isCLIPVisionModelConfig,
|
||||
isCogView4MainModelModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isControlLoRAModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isFluxKontextApiModelConfig,
|
||||
isFluxKontextModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isGemini2_5ModelConfig,
|
||||
isImagen3ModelConfig,
|
||||
isImagen4ModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLLaVAModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isRunwayModelConfig,
|
||||
isSD3MainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isSigLipModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
isTIModelConfig,
|
||||
isVAEModelConfig,
|
||||
isVeo3ModelConfig,
|
||||
isVideoModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
@@ -57,7 +66,12 @@ const buildModelsHook =
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
|
||||
export const useControlLayerModels = buildModelsHook(isControlLayerModelConfig);
|
||||
@@ -89,6 +103,12 @@ export const useRegionalReferenceImageModels = buildModelsHook(
|
||||
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
|
||||
);
|
||||
export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);
|
||||
export const useImagen3Models = buildModelsHook(isImagen3ModelConfig);
|
||||
export const useImagen4Models = buildModelsHook(isImagen4ModelConfig);
|
||||
export const useChatGPT4oModels = buildModelsHook(isChatGPT4oModelConfig);
|
||||
export const useFluxKontextModels = buildModelsHook(isFluxKontextApiModelConfig);
|
||||
export const useVeo3Models = buildModelsHook(isVeo3ModelConfig);
|
||||
export const useRunwayModels = buildModelsHook(isRunwayModelConfig);
|
||||
export const useVideoModels = buildModelsHook(isVideoModelConfig);
|
||||
|
||||
const buildModelsSelector =
|
||||
|
||||
@@ -5580,7 +5580,7 @@ export type components = {
|
||||
repo_variant?: components["schemas"]["ModelRepoVariant"] | null;
|
||||
};
|
||||
/**
|
||||
* ControlNet - SD1.5, SD2, SDXL
|
||||
* ControlNet - SD1.5, SDXL
|
||||
* @description Collects ControlNet info to pass to other nodes
|
||||
*/
|
||||
ControlNetInvocation: {
|
||||
@@ -11805,26 +11805,6 @@ export type components = {
|
||||
ui_choice_labels: {
|
||||
[key: string]: string;
|
||||
} | null;
|
||||
/**
|
||||
* Ui Model Base
|
||||
* @default null
|
||||
*/
|
||||
ui_model_base: components["schemas"]["BaseModelType"][] | null;
|
||||
/**
|
||||
* Ui Model Type
|
||||
* @default null
|
||||
*/
|
||||
ui_model_type: components["schemas"]["ModelType"][] | null;
|
||||
/**
|
||||
* Ui Model Variant
|
||||
* @default null
|
||||
*/
|
||||
ui_model_variant: (components["schemas"]["ClipVariantType"] | components["schemas"]["ModelVariantType"])[] | null;
|
||||
/**
|
||||
* Ui Model Format
|
||||
* @default null
|
||||
*/
|
||||
ui_model_format: components["schemas"]["ModelFormat"][] | null;
|
||||
};
|
||||
/**
|
||||
* InstallStatus
|
||||
@@ -15183,7 +15163,7 @@ export type components = {
|
||||
guidance?: number | null;
|
||||
};
|
||||
/**
|
||||
* Main Model - SD1.5, SD2
|
||||
* Main Model - SD1.5
|
||||
* @description Loads a main model, outputting its submodels.
|
||||
*/
|
||||
MainModelLoaderInvocation: {
|
||||
@@ -17739,18 +17719,11 @@ export type components = {
|
||||
*/
|
||||
OutputFieldJSONSchemaExtra: {
|
||||
field_kind: components["schemas"]["FieldKind"];
|
||||
/**
|
||||
* Ui Hidden
|
||||
* @default false
|
||||
*/
|
||||
/** Ui Hidden */
|
||||
ui_hidden: boolean;
|
||||
/**
|
||||
* Ui Order
|
||||
* @default null
|
||||
*/
|
||||
ui_order: number | null;
|
||||
/** @default null */
|
||||
ui_type: components["schemas"]["UIType"] | null;
|
||||
/** Ui Order */
|
||||
ui_order: number | null;
|
||||
};
|
||||
/** PaginatedResults[WorkflowRecordListItemWithThumbnailDTO] */
|
||||
PaginatedResults_WorkflowRecordListItemWithThumbnailDTO_: {
|
||||
@@ -21857,7 +21830,7 @@ export type components = {
|
||||
* used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
* @enum {string}
|
||||
*/
|
||||
UIType: "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict" | "DEPRECATED_MainModelField" | "DEPRECATED_CogView4MainModelField" | "DEPRECATED_FluxMainModelField" | "DEPRECATED_SD3MainModelField" | "DEPRECATED_SDXLMainModelField" | "DEPRECATED_SDXLRefinerModelField" | "DEPRECATED_ONNXModelField" | "DEPRECATED_VAEModelField" | "DEPRECATED_FluxVAEModelField" | "DEPRECATED_LoRAModelField" | "DEPRECATED_ControlNetModelField" | "DEPRECATED_IPAdapterModelField" | "DEPRECATED_T2IAdapterModelField" | "DEPRECATED_T5EncoderModelField" | "DEPRECATED_CLIPEmbedModelField" | "DEPRECATED_CLIPLEmbedModelField" | "DEPRECATED_CLIPGEmbedModelField" | "DEPRECATED_SpandrelImageToImageModelField" | "DEPRECATED_ControlLoRAModelField" | "DEPRECATED_SigLipModelField" | "DEPRECATED_FluxReduxModelField" | "DEPRECATED_LLaVAModelField" | "DEPRECATED_Imagen3ModelField" | "DEPRECATED_Imagen4ModelField" | "DEPRECATED_ChatGPT4oModelField" | "DEPRECATED_Gemini2_5ModelField" | "DEPRECATED_FluxKontextModelField" | "DEPRECATED_Veo3ModelField" | "DEPRECATED_RunwayModelField";
|
||||
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "Imagen3ModelField" | "Imagen4ModelField" | "ChatGPT4oModelField" | "Gemini2_5ModelField" | "FluxKontextModelField" | "Veo3ModelField" | "RunwayModelField" | "SchedulerField" | "AnyField" | "VideoField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
/** UNetField */
|
||||
UNetField: {
|
||||
/** @description Info to load unet submodel */
|
||||
@@ -22203,7 +22176,7 @@ export type components = {
|
||||
seamless_axes?: string[];
|
||||
};
|
||||
/**
|
||||
* VAE Model - SD1.5, SD2, SDXL, SD3, FLUX
|
||||
* VAE Model - SD1.5, SDXL, SD3, FLUX
|
||||
* @description Loads a VAE model, outputting a VaeLoaderOutput
|
||||
*/
|
||||
VAELoaderInvocation: {
|
||||
@@ -22572,11 +22545,6 @@ export type components = {
|
||||
* @description The exposed fields of the workflow.
|
||||
*/
|
||||
exposedFields: components["schemas"]["ExposedField"][];
|
||||
/**
|
||||
* Output Fields
|
||||
* @description The fields designated as output fields for the workflow.
|
||||
*/
|
||||
output_fields?: components["schemas"]["FieldIdentifier"][] | null;
|
||||
/** @description The meta of the workflow. */
|
||||
meta: components["schemas"]["WorkflowMeta"];
|
||||
/**
|
||||
@@ -22718,12 +22686,6 @@ export type components = {
|
||||
* @description The tags of the workflow.
|
||||
*/
|
||||
tags: string;
|
||||
/**
|
||||
* Has Valid Image Output Field
|
||||
* @description True when the workflow exposes exactly one output field and it is an image output.
|
||||
* @default false
|
||||
*/
|
||||
has_valid_image_output_field?: boolean;
|
||||
/**
|
||||
* Thumbnail Url
|
||||
* @description The URL of the workflow thumbnail.
|
||||
@@ -22818,11 +22780,6 @@ export type components = {
|
||||
* @description The exposed fields of the workflow.
|
||||
*/
|
||||
exposedFields: components["schemas"]["ExposedField"][];
|
||||
/**
|
||||
* Output Fields
|
||||
* @description The fields designated as output fields for the workflow.
|
||||
*/
|
||||
output_fields?: components["schemas"]["FieldIdentifier"][] | null;
|
||||
/** @description The meta of the workflow. */
|
||||
meta: components["schemas"]["WorkflowMeta"];
|
||||
/**
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user