diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index 6b875d37ce..20b2781ef0 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -38,7 +38,7 @@ import mimetypes
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
-from .invocations.baseinvocation import BaseInvocation
+from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
@@ -134,6 +134,11 @@ def custom_openapi():
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
+ # Add Node Editor UI helper schemas
+ ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
+ for schema_key, output_schema in ui_config_schemas["definitions"].items():
+ openapi_schema["components"]["schemas"][schema_key] = output_schema
+
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 758ab2e787..65aeef75d8 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -3,15 +3,353 @@
from __future__ import annotations
from abc import ABC, abstractmethod
+from enum import Enum
from inspect import signature
-from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ Any,
+ Callable,
+ ClassVar,
+ Mapping,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ get_args,
+ get_type_hints,
+)
-from pydantic import BaseConfig, BaseModel, Field
+from pydantic import BaseModel, Field
+from pydantic.fields import Undefined
+from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
+class FieldDescriptions:
+ denoising_start = "When to start denoising, expressed a percentage of total steps"
+ denoising_end = "When to stop denoising, expressed a percentage of total steps"
+ cfg_scale = "Classifier-Free Guidance scale"
+ scheduler = "Scheduler to use during inference"
+ positive_cond = "Positive conditioning tensor"
+ negative_cond = "Negative conditioning tensor"
+ noise = "Noise tensor"
+ clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+ unet = "UNet (scheduler, LoRAs)"
+ vae = "VAE"
+ cond = "Conditioning tensor"
+ controlnet_model = "ControlNet model to load"
+ vae_model = "VAE model to load"
+ lora_model = "LoRA model to load"
+ main_model = "Main model (UNet, VAE, CLIP) to load"
+ sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
+ sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
+ onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
+ lora_weight = "The weight at which the LoRA is applied to each model"
+ compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
+ raw_prompt = "Raw prompt text (no parsing)"
+ sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
+ skipped_layers = "Number of layers to skip in text encoder"
+ seed = "Seed for random number generation"
+ steps = "Number of steps to run"
+ width = "Width of output (px)"
+ height = "Height of output (px)"
+ control = "ControlNet(s) to apply"
+ denoised_latents = "Denoised latents tensor"
+ latents = "Latents tensor"
+ strength = "Strength of denoising (proportional to steps)"
+ core_metadata = "Optional core metadata to be written to image"
+ interp_mode = "Interpolation mode"
+ torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
+ fp32 = "Whether or not to use full float32 precision"
+ precision = "Precision to use"
+ tiled = "Processing using overlapping tiles (reduce memory consumption)"
+ detect_res = "Pixel resolution for detection"
+ image_res = "Pixel resolution for output image"
+ safe_mode = "Whether or not to use safe mode"
+ scribble_mode = "Whether or not to use scribble mode"
+ scale_factor = "The factor by which to scale"
+ num_1 = "The first number"
+ num_2 = "The second number"
+ mask = "The mask to use for the operation"
+
+
+class Input(str, Enum):
+ """
+ The type of input a field accepts.
+ - `Input.Direct`: The field must have its value provided directly, when the invocation and field \
+ are instantiated.
+ - `Input.Connection`: The field must have its value provided by a connection.
+ - `Input.Any`: The field may have its value provided either directly or by a connection.
+ """
+
+ Connection = "connection"
+ Direct = "direct"
+ Any = "any"
+
+
+class UITypeHint(str, Enum):
+ """
+ Type hints for the UI.
+ If a field should be provided a data type that does not exactly match the python type of the field, \
+ use this to provide the type that should be used instead. See the node development docs for detail \
+ on adding a new field type, which involves client-side changes.
+ """
+
+ Integer = "integer"
+ Float = "float"
+ Boolean = "boolean"
+ String = "string"
+ Enum = "enum"
+ Array = "array"
+ ImageField = "ImageField"
+ LatentsField = "LatentsField"
+ ConditioningField = "ConditioningField"
+ ControlField = "ControlField"
+ MainModelField = "MainModelField"
+ SDXLMainModelField = "SDXLMainModelField"
+ SDXLRefinerModelField = "SDXLRefinerModelField"
+ ONNXModelField = "ONNXModelField"
+ VaeModelField = "VaeModelField"
+ LoRAModelField = "LoRAModelField"
+ ControlNetModelField = "ControlNetModelField"
+ UNetField = "UNetField"
+ VaeField = "VaeField"
+ ClipField = "ClipField"
+ ColorField = "ColorField"
+ ImageCollection = "ImageCollection"
+ IntegerCollection = "IntegerCollection"
+ FloatCollection = "FloatCollection"
+ StringCollection = "StringCollection"
+ BooleanCollection = "BooleanCollection"
+ Collection = "Collection"
+ CollectionItem = "CollectionItem"
+ Seed = "Seed"
+ FilePath = "FilePath"
+
+
+class UIComponent(str, Enum):
+ """
+ The type of UI component to use for a field, used to override the default components, which are \
+ inferred from the field type.
+ """
+
+ None_ = "none"
+ Textarea = "textarea"
+ Slider = "slider"
+
+
+class _InputField(BaseModel):
+ """
+ *DO NOT USE*
+ This helper class is used to tell the client about our custom field attributes via OpenAPI
+ schema generation, and Typescript type generation from that schema. It serves no functional
+ purpose in the backend.
+ """
+
+ input: Input
+ ui_hidden: bool
+ ui_type_hint: Optional[UITypeHint]
+ ui_component: Optional[UIComponent]
+
+
+class _OutputField(BaseModel):
+ """
+ *DO NOT USE*
+ This helper class is used to tell the client about our custom field attributes via OpenAPI
+ schema generation, and Typescript type generation from that schema. It serves no functional
+ purpose in the backend.
+ """
+
+ ui_hidden: bool
+ ui_type_hint: Optional[UITypeHint]
+
+
+def InputField(
+ *args: Any,
+ default: Any = Undefined,
+ default_factory: Optional[NoArgAnyCallable] = None,
+ alias: Optional[str] = None,
+ title: Optional[str] = None,
+ description: Optional[str] = None,
+ exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
+ include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
+ const: Optional[bool] = None,
+ gt: Optional[float] = None,
+ ge: Optional[float] = None,
+ lt: Optional[float] = None,
+ le: Optional[float] = None,
+ multiple_of: Optional[float] = None,
+ allow_inf_nan: Optional[bool] = None,
+ max_digits: Optional[int] = None,
+ decimal_places: Optional[int] = None,
+ min_items: Optional[int] = None,
+ max_items: Optional[int] = None,
+ unique_items: Optional[bool] = None,
+ min_length: Optional[int] = None,
+ max_length: Optional[int] = None,
+ allow_mutation: bool = True,
+ regex: Optional[str] = None,
+ discriminator: Optional[str] = None,
+ repr: bool = True,
+ input: Input = Input.Any,
+ ui_type_hint: Optional[UITypeHint] = None,
+ ui_component: Optional[UIComponent] = None,
+ ui_hidden: bool = False,
+ **kwargs: Any,
+) -> Any:
+ """
+ Creates an input field for an invocation.
+
+ This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
+ that adds a few extra parameters to support graph execution and the node editor UI.
+
+ :param Input input: [Input.Any] The kind of input this field requires. \
+ `Input.Direct` means a value must be provided on instantiation. \
+ `Input.Connection` means the value must be provided by a connection. \
+ `Input.Any` means either will do.
+
+ :param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
+ In some situations, the field's type is not enough to infer the correct UI type. \
+ For example, model selection fields should render a dropdown UI component to select a model. \
+ Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
+ `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
+ `UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
+
+ :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
+ The UI will always render a suitable component, but sometimes you want something different than the default. \
+ For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
+ For this case, you could provide `UIComponent.Textarea`.
+
+ : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
+ """
+ return Field(
+ *args,
+ default=default,
+ default_factory=default_factory,
+ alias=alias,
+ title=title,
+ description=description,
+ exclude=exclude,
+ include=include,
+ const=const,
+ gt=gt,
+ ge=ge,
+ lt=lt,
+ le=le,
+ multiple_of=multiple_of,
+ allow_inf_nan=allow_inf_nan,
+ max_digits=max_digits,
+ decimal_places=decimal_places,
+ min_items=min_items,
+ max_items=max_items,
+ unique_items=unique_items,
+ min_length=min_length,
+ max_length=max_length,
+ allow_mutation=allow_mutation,
+ regex=regex,
+ discriminator=discriminator,
+ repr=repr,
+ input=input,
+ ui_type_hint=ui_type_hint,
+ ui_component=ui_component,
+ ui_hidden=ui_hidden,
+ **kwargs,
+ )
+
+
+def OutputField(
+ *args: Any,
+ default: Any = Undefined,
+ default_factory: Optional[NoArgAnyCallable] = None,
+ alias: Optional[str] = None,
+ title: Optional[str] = None,
+ description: Optional[str] = None,
+ exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
+ include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
+ const: Optional[bool] = None,
+ gt: Optional[float] = None,
+ ge: Optional[float] = None,
+ lt: Optional[float] = None,
+ le: Optional[float] = None,
+ multiple_of: Optional[float] = None,
+ allow_inf_nan: Optional[bool] = None,
+ max_digits: Optional[int] = None,
+ decimal_places: Optional[int] = None,
+ min_items: Optional[int] = None,
+ max_items: Optional[int] = None,
+ unique_items: Optional[bool] = None,
+ min_length: Optional[int] = None,
+ max_length: Optional[int] = None,
+ allow_mutation: bool = True,
+ regex: Optional[str] = None,
+ discriminator: Optional[str] = None,
+ repr: bool = True,
+ ui_type_hint: Optional[UITypeHint] = None,
+ ui_hidden: bool = False,
+ **kwargs: Any,
+) -> Any:
+ """
+ Creates an output field for an invocation output.
+
+ This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
+ that adds a few extra parameters to support graph execution and the node editor UI.
+
+ :param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
+ In some situations, the field's type is not enough to infer the correct UI type. \
+ For example, model selection fields should render a dropdown UI component to select a model. \
+ Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
+ `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
+ `UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
+
+ : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
+ """
+ return Field(
+ *args,
+ default=default,
+ default_factory=default_factory,
+ alias=alias,
+ title=title,
+ description=description,
+ exclude=exclude,
+ include=include,
+ const=const,
+ gt=gt,
+ ge=ge,
+ lt=lt,
+ le=le,
+ multiple_of=multiple_of,
+ allow_inf_nan=allow_inf_nan,
+ max_digits=max_digits,
+ decimal_places=decimal_places,
+ min_items=min_items,
+ max_items=max_items,
+ unique_items=unique_items,
+ min_length=min_length,
+ max_length=max_length,
+ allow_mutation=allow_mutation,
+ regex=regex,
+ discriminator=discriminator,
+ repr=repr,
+ ui_type_hint=ui_type_hint,
+ ui_hidden=ui_hidden,
+ **kwargs,
+ )
+
+
+class UIConfigBase(BaseModel):
+ """
+ Provides additional node configuration to the UI.
+ This is used internally by the @tags and @title decorator logic. You probably want to use those
+ decorators, though you may add this class to a node definition to specify the title and tags.
+ """
+
+ tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
+ title: Optional[str] = Field(default=None, description="The display name of the node")
+
+
class InvocationContext:
services: InvocationServices
graph_execution_state_id: str
@@ -39,6 +377,20 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses)
+class RequiredConnectionException(Exception):
+ """Raised when an field which requires a connection did not receive a value."""
+
+ def __init__(self, node_id: str, field_name: str):
+ super().__init__(f"Node {node_id} missing connections for field {field_name}")
+
+
+class MissingInputException(Exception):
+ """Raised when an field which requires some input, but did not receive a value."""
+
+ def __init__(self, node_id: str, field_name: str):
+ super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
+
+
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
@@ -76,70 +428,81 @@ class BaseInvocation(ABC, BaseModel):
def get_output_type(cls):
return signature(cls.invoke).return_annotation
+ class Config:
+ @staticmethod
+ def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+ uiconfig = getattr(model_class, "UIConfig", None)
+ if uiconfig and hasattr(uiconfig, "title"):
+ schema["title"] = uiconfig.title
+ if uiconfig and hasattr(uiconfig, "tags"):
+ schema["tags"] = uiconfig.tags
+
@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
- # fmt: off
- id: str = Field(description="The id of this node. Must be unique among all nodes.")
- is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
- # fmt: on
+ def __init__(self, **data):
+ # nodes may have required fields, that can accept input from connections
+ # on instantiation of the model, we need to exclude these from validation
+ restore = dict()
+ try:
+ field_names = list(self.__fields__.keys())
+ for field_name in field_names:
+ # if the field is required and may get its value from a connection, exclude it from validation
+ field = self.__fields__[field_name]
+ _input = field.field_info.extra.get("input", None)
+ if _input in [Input.Connection, Input.Any] and field.required:
+ if field_name not in data:
+ restore[field_name] = self.__fields__.pop(field_name)
+ # instantiate the node, which will validate the data
+ super().__init__(**data)
+ finally:
+ # restore the removed fields
+ for field_name, field in restore.items():
+ self.__fields__[field_name] = field
+
+ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
+ for field_name, field in self.__fields__.items():
+ _input = field.field_info.extra.get("input", None)
+ if field.required and not hasattr(self, field_name):
+ if _input == Input.Connection:
+ raise RequiredConnectionException(self.__fields__["type"].default, field_name)
+ elif _input == Input.Any:
+ raise MissingInputException(self.__fields__["type"].default, field_name)
+ return self.invoke(context)
+
+ id: str = InputField(description="The id of this node. Must be unique among all nodes.")
+ is_intermediate: bool = InputField(
+ default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
+ )
+ UIConfig: ClassVar[Type[UIConfigBase]]
-# TODO: figure out a better way to provide these hints
-# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
-class UIConfig(TypedDict, total=False):
- type_hints: Dict[
- str,
- Literal[
- "integer",
- "float",
- "boolean",
- "string",
- "enum",
- "image",
- "latents",
- "model",
- "control",
- "image_collection",
- "vae_model",
- "lora_model",
- ],
- ]
- tags: List[str]
- title: str
+T = TypeVar("T", bound=BaseInvocation)
-class CustomisedSchemaExtra(TypedDict):
- ui: UIConfig
+def title(title: str) -> Callable[[Type[T]], Type[T]]:
+ """Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
+
+ def wrapper(cls: Type[T]) -> Type[T]:
+ uiconf_name = cls.__qualname__ + ".UIConfig"
+ if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
+ cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+ cls.UIConfig.title = title
+ return cls
+
+ return wrapper
-class InvocationConfig(BaseConfig):
- """Customizes pydantic's BaseModel.Config class for use by Invocations.
+def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
+ """Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
- Provide `schema_extra` a `ui` dict to add hints for generated UIs.
+ def wrapper(cls: Type[T]) -> Type[T]:
+ uiconf_name = cls.__qualname__ + ".UIConfig"
+ if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
+ cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+ cls.UIConfig.tags = list(tags)
+ return cls
- `tags`
- - A list of strings, used to categorise invocations.
-
- `type_hints`
- - A dict of field types which override the types in the invocation definition.
- - Each key should be the name of one of the invocation's fields.
- - Each value should be one of the valid types:
- - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
-
- ```python
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["stable-diffusion", "image"],
- "type_hints": {
- "initial_image": "image",
- },
- },
- }
- ```
- """
-
- schema_extra: CustomisedSchemaExtra
+ return wrapper
diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py
index 01c003da96..0dd3b757dc 100644
--- a/invokeai/app/invocations/collections.py
+++ b/invokeai/app/invocations/collections.py
@@ -3,58 +3,78 @@
from typing import Literal
import numpy as np
-from pydantic import Field, validator
+from pydantic import validator
from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ InputField,
+ InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
+)
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
- type: Literal["int_collection"] = "int_collection"
+ type: Literal["int_collection_output"] = "int_collection_output"
# Outputs
- collection: list[int] = Field(default=[], description="The int collection")
+ collection: list[int] = OutputField(
+ default=[], description="The int collection", ui_type_hint=UITypeHint.IntegerCollection
+ )
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
- type: Literal["float_collection"] = "float_collection"
+ type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
- collection: list[float] = Field(default=[], description="The float collection")
+ collection: list[float] = OutputField(
+ default=[], description="The float collection", ui_type_hint=UITypeHint.FloatCollection
+ )
+
+
+class StringCollectionOutput(BaseInvocationOutput):
+ """A collection of strings"""
+
+ type: Literal["string_collection_output"] = "string_collection_output"
+
+ # Outputs
+ collection: list[str] = OutputField(
+ default=[], description="The output strings", ui_type_hint=UITypeHint.StringCollection
+ )
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
- type: Literal["image_collection"] = "image_collection"
+ type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
- collection: list[ImageField] = Field(default=[], description="The output images")
-
- class Config:
- schema_extra = {"required": ["type", "collection"]}
+ collection: list[ImageField] = OutputField(
+ default=[], description="The output images", ui_type_hint=UITypeHint.ImageCollection
+ )
+@title("Integer Range")
+@tags("collection", "integer", "range")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range"
# Inputs
- start: int = Field(default=0, description="The start of the range")
- stop: int = Field(default=10, description="The stop of the range")
- step: int = Field(default=1, description="The step of the range")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
- }
+ start: int = InputField(default=0, description="The start of the range")
+ stop: int = InputField(default=10, description="The stop of the range")
+ step: int = InputField(default=1, description="The step of the range")
@validator("stop")
def stop_gt_start(cls, v, values):
@@ -66,72 +86,56 @@ class RangeInvocation(BaseInvocation):
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
+@title("Integer Range of Size")
+@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
- start: int = Field(default=0, description="The start of the range")
- size: int = Field(default=1, description="The number of values")
- step: int = Field(default=1, description="The step of the range")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
- }
+ start: int = InputField(default=0, description="The start of the range")
+ size: int = InputField(default=1, description="The number of values")
+ step: int = InputField(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
+@title("Random Range")
+@tags("range", "integer", "random", "collection")
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
- low: int = Field(default=0, description="The inclusive low value")
- high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
- size: int = Field(default=1, description="The number of values to generate")
- seed: int = Field(
+ low: int = InputField(default=0, description="The inclusive low value")
+ high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
+ size: int = InputField(default=1, description="The number of values to generate")
+ seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
- }
-
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
+@title("Image Collection")
+@tags("image", "collection")
class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""
- # fmt: off
type: Literal["image_collection"] = "image_collection"
# Inputs
- images: list[ImageField] = Field(
- default=[], description="The image collection to load"
+ images: list[ImageField] = InputField(
+ default=[], description="The image collection to load", ui_type_hint=UITypeHint.ImageCollection
)
- # fmt: on
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "type_hints": {
- "title": "Image Collection",
- "images": "image_collection",
- }
- },
- }
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 86565366d9..0f7c61a6dd 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,29 +1,39 @@
-from typing import Literal, Optional, Union, List, Annotated
-from pydantic import BaseModel, Field
import re
-
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-from .model import ClipField
-
-from ...backend.util.devices import torch_dtype
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
-from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
+from dataclasses import dataclass
+from typing import List, Literal, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from ...backend.util.devices import torch_dtype
-from ...backend.model_management import ModelType
-from ...backend.model_management.models import ModelNotFoundException
+from pydantic import BaseModel, Field
+
+from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
+ BasicConditioningInfo,
+ SDXLConditioningInfo,
+)
+
+from ...backend.model_management import ModelPatcher, ModelType
from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from ...backend.model_management.models import ModelNotFoundException
+from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.util.devices import torch_dtype
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ Input,
+ InputField,
+ InvocationContext,
+ OutputField,
+ UIComponent,
+ tags,
+ title,
+)
from .model import ClipField
-from dataclasses import dataclass
class ConditioningField(BaseModel):
- conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
+ conditioning_name: str = Field(description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@@ -47,23 +57,27 @@ class CompelOutput(BaseInvocationOutput):
# fmt: off
type: Literal["compel_output"] = "compel_output"
- conditioning: ConditioningField = Field(default=None, description="Conditioning")
+ conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
# fmt: on
+@title("Compel Prompt")
+@tags("prompt", "compel")
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel"
- prompt: str = Field(default="", description="Prompt")
- clip: ClipField = Field(None, description="Clip to use")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
- }
+ prompt: str = InputField(
+ default="",
+ description=FieldDescriptions.compel_prompt,
+ ui_component=UIComponent.Textarea,
+ )
+ clip: ClipField = InputField(
+ title="CLIP",
+ description=FieldDescriptions.clip,
+ input=Input.Connection,
+ )
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@@ -270,27 +284,23 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec
+@title("SDXL Compel Prompt")
+@tags("sdxl", "compel", "prompt")
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
- prompt: str = Field(default="", description="Prompt")
- style: str = Field(default="", description="Style prompt")
- original_width: int = Field(1024, description="")
- original_height: int = Field(1024, description="")
- crop_top: int = Field(0, description="")
- crop_left: int = Field(0, description="")
- target_width: int = Field(1024, description="")
- target_height: int = Field(1024, description="")
- clip: ClipField = Field(None, description="Clip to use")
- clip2: ClipField = Field(None, description="Clip2 to use")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
- }
+ prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
+ style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
+ original_width: int = InputField(default=1024, description="")
+ original_height: int = InputField(default=1024, description="")
+ crop_top: int = InputField(default=0, description="")
+ crop_left: int = InputField(default=0, description="")
+ target_width: int = InputField(default=1024, description="")
+ target_height: int = InputField(default=1024, description="")
+ clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
+ clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@@ -333,28 +343,22 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
+@title("SDXL Refiner Compel Prompt")
+@tags("sdxl", "compel", "prompt")
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
- style: str = Field(default="", description="Style prompt") # TODO: ?
- original_width: int = Field(1024, description="")
- original_height: int = Field(1024, description="")
- crop_top: int = Field(0, description="")
- crop_left: int = Field(0, description="")
- aesthetic_score: float = Field(6.0, description="")
- clip2: ClipField = Field(None, description="Clip to use")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "SDXL Refiner Prompt (Compel)",
- "tags": ["prompt", "compel"],
- "type_hints": {"model": "model"},
- },
- }
+ style: str = InputField(
+ default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
+ ) # TODO: ?
+ original_width: int = InputField(default=1024, description="")
+ original_height: int = InputField(default=1024, description="")
+ crop_top: int = InputField(default=0, description="")
+ crop_left: int = InputField(default=0, description="")
+ aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
+ clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@@ -391,21 +395,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"
- clip: ClipField = Field(None, description="Clip with skipped layers")
+ clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
+@title("CLIP Skip")
+@tags("clipskip", "clip", "skip")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip"
- clip: ClipField = Field(None, description="Clip to use")
- skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
- }
+ clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
+ skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index d2b2d44526..de8ad00026 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -28,77 +28,27 @@ from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from ..models.image import ImageOutput, PILInvocationConfig
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ InputField,
+ Input,
+ InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
+)
+from ..models.image import ImageOutput
-CONTROLNET_DEFAULT_MODELS = [
- ###########################################
- # lllyasviel sd v1.5, ControlNet v1.0 models
- ##############################################
- "lllyasviel/sd-controlnet-canny",
- "lllyasviel/sd-controlnet-depth",
- "lllyasviel/sd-controlnet-hed",
- "lllyasviel/sd-controlnet-seg",
- "lllyasviel/sd-controlnet-openpose",
- "lllyasviel/sd-controlnet-scribble",
- "lllyasviel/sd-controlnet-normal",
- "lllyasviel/sd-controlnet-mlsd",
- #############################################
- # lllyasviel sd v1.5, ControlNet v1.1 models
- #############################################
- "lllyasviel/control_v11p_sd15_canny",
- "lllyasviel/control_v11p_sd15_openpose",
- "lllyasviel/control_v11p_sd15_seg",
- # "lllyasviel/control_v11p_sd15_depth", # broken
- "lllyasviel/control_v11f1p_sd15_depth",
- "lllyasviel/control_v11p_sd15_normalbae",
- "lllyasviel/control_v11p_sd15_scribble",
- "lllyasviel/control_v11p_sd15_mlsd",
- "lllyasviel/control_v11p_sd15_softedge",
- "lllyasviel/control_v11p_sd15s2_lineart_anime",
- "lllyasviel/control_v11p_sd15_lineart",
- "lllyasviel/control_v11p_sd15_inpaint",
- # "lllyasviel/control_v11u_sd15_tile",
- # problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
- # so for now replace "lllyasviel/control_v11f1e_sd15_tile",
- "lllyasviel/control_v11e_sd15_shuffle",
- "lllyasviel/control_v11e_sd15_ip2p",
- "lllyasviel/control_v11f1e_sd15_tile",
- #################################################
- # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
- ##################################################
- "thibaud/controlnet-sd21-openpose-diffusers",
- "thibaud/controlnet-sd21-canny-diffusers",
- "thibaud/controlnet-sd21-depth-diffusers",
- "thibaud/controlnet-sd21-scribble-diffusers",
- "thibaud/controlnet-sd21-hed-diffusers",
- "thibaud/controlnet-sd21-zoedepth-diffusers",
- "thibaud/controlnet-sd21-color-diffusers",
- "thibaud/controlnet-sd21-openposev2-diffusers",
- "thibaud/controlnet-sd21-lineart-diffusers",
- "thibaud/controlnet-sd21-normalbae-diffusers",
- "thibaud/controlnet-sd21-ade20k-diffusers",
- ##############################################
- # ControlNetMediaPipeface, ControlNet v1.1
- ##############################################
- # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
- # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
- # hacked t2l to split to model & subfolder if format is "model,subfolder"
- "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
- "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
-]
-CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
-CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
- tuple(
- [
- "just_resize",
- "crop_resize",
- "fill_resize",
- "just_resize_simple",
- ]
- )
+ "just_resize",
+ "crop_resize",
+ "fill_resize",
+ "just_resize_simple",
]
@@ -110,9 +60,8 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel):
- image: ImageField = Field(default=None, description="The control image")
- control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
- # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
+ image: ImageField = Field(description="The control image")
+ control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -135,60 +84,39 @@ class ControlField(BaseModel):
raise ValueError("Control weights must be within -1 to 2 range")
return v
- class Config:
- schema_extra = {
- "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
- "ui": {
- "type_hints": {
- "control_weight": "float",
- "control_model": "controlnet_model",
- # "control_weight": "number",
- }
- },
- }
-
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
- # fmt: off
type: Literal["control_output"] = "control_output"
- control: ControlField = Field(default=None, description="The control info")
- # fmt: on
+
+ # Outputs
+ control: ControlField = OutputField(description=FieldDescriptions.control)
+@title("ControlNet")
+@tags("controlnet")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
- # fmt: off
type: Literal["controlnet"] = "controlnet"
- # Inputs
- image: ImageField = Field(default=None, description="The control image")
- control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
- description="control model used")
- control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
- begin_step_percent: float = Field(default=0, ge=-1, le=2,
- description="When the ControlNet is first applied (% of total steps)")
- end_step_percent: float = Field(default=1, ge=0, le=1,
- description="When the ControlNet is last applied (% of total steps)")
- control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
- resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "ControlNet",
- "tags": ["controlnet", "latents"],
- "type_hints": {
- "model": "model",
- "control": "control",
- # "cfg_scale": "float",
- "cfg_scale": "number",
- "control_weight": "float",
- },
- },
- }
+ # Inputs
+ image: ImageField = InputField(description="The control image")
+ control_model: ControlNetModelField = InputField(
+ default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
+ )
+ control_weight: Union[float, List[float]] = InputField(
+ default=1.0, description="The weight given to the ControlNet", ui_type_hint=UITypeHint.Float
+ )
+ begin_step_percent: float = InputField(
+ default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
+ )
+ end_step_percent: float = InputField(
+ default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
+ )
+ control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
+ resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
@@ -204,19 +132,13 @@ class ControlNetInvocation(BaseInvocation):
)
-class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
+class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""
- # fmt: off
type: Literal["image_processor"] = "image_processor"
- # Inputs
- image: ImageField = Field(default=None, description="The image to process")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Image Processor", "tags": ["image", "processor"]},
- }
+ # Inputs
+ image: ImageField = InputField(description="The image to process")
def run_processor(self, image):
# superclass just passes through image without processing
@@ -255,20 +177,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
-class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Canny Processor")
+@tags("controlnet", "canny")
+class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
- # fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
- # Input
- low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
- high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
- }
+ # Input
+ low_threshold: int = InputField(
+ default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
+ )
+ high_threshold: int = InputField(
+ default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
+ )
def run_processor(self, image):
canny_processor = CannyDetector()
@@ -276,23 +198,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image
-class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("HED (softedge) Processor")
+@tags("controlnet", "hed", "softedge")
+class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
- # fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- # safe not supported in controlnet_aux v0.0.3
- # safe: bool = Field(default=False, description="whether to use safe mode")
- scribble: bool = Field(default=False, description="Whether to use scribble mode")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
+ # safe not supported in controlnet_aux v0.0.3
+ # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
+ scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
@@ -307,21 +225,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
return processed_image
-class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Lineart Processor")
+@tags("controlnet", "lineart")
+class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
- # fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- coarse: bool = Field(default=False, description="Whether to use coarse mode")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
+ coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
@@ -331,23 +245,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
return processed_image
-class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Lineart Anime Processor")
+@tags("controlnet", "lineart", "anime")
+class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
- # fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Lineart Anime Processor",
- "tags": ["controlnet", "lineart", "anime", "image", "processor"],
- },
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
@@ -359,21 +266,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
return processed_image
-class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Openpose Processor")
+@tags("controlnet", "openpose", "pose")
+class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
- # fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
- # Inputs
- hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
- }
+ # Inputs
+ hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
@@ -386,22 +289,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
-class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Midas (Depth) Processor")
+@tags("controlnet", "midas", "depth")
+class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
- # fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
- # Inputs
- a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
- bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
- # depth_and_normal not supported in controlnet_aux v0.0.3
- # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
- }
+ # Inputs
+ a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
+ bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
+ # depth_and_normal not supported in controlnet_aux v0.0.3
+ # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
@@ -415,20 +314,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
return processed_image
-class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Normal BAE Processor")
+@tags("controlnet", "normal", "bae")
+class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
- # fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
@@ -438,22 +333,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
return processed_image
-class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("MLSD Processor")
+@tags("controlnet", "mlsd")
+class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
- # fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
- thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
+ thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
+ thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
@@ -467,22 +358,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image
-class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("PIDI Processor")
+@tags("controlnet", "pidi")
+class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
- # fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- safe: bool = Field(default=False, description="Whether to use safe mode")
- scribble: bool = Field(default=False, description="Whether to use scribble mode")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
+ safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
+ scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
@@ -496,26 +383,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image
-class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Content Shuffle Processor")
+@tags("controlnet", "contentshuffle")
+class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
- # fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
- # Inputs
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
- w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
- f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Content Shuffle Processor",
- "tags": ["controlnet", "contentshuffle", "image", "processor"],
- },
- }
+ # Inputs
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
+ h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
+ w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
+ f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
@@ -531,17 +411,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
-class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Zoe (Depth) Processor")
+@tags("controlnet", "zoe", "depth")
+class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
- # fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
- }
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
@@ -549,20 +424,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
-class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Mediapipe Face Processor")
+@tags("controlnet", "mediapipe", "face")
+class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
- # fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
- # Inputs
- max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
- min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
- }
+ # Inputs
+ max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
+ min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
@@ -574,23 +445,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
-class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Leres (Depth) Processor")
+@tags("controlnet", "leres", "depth")
+class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
- # fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
- # Inputs
- thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
- thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
- boost: bool = Field(default=False, description="Whether to use boost mode")
- detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
- image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
- }
+ # Inputs
+ thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
+ thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
+ boost: bool = InputField(default=False, description="Whether to use boost mode")
+ detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
+ image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
@@ -605,21 +472,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image
-class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
- # fmt: off
- type: Literal["tile_image_processor"] = "tile_image_processor"
- # Inputs
- #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
- down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
- # fmt: on
+@title("Tile Resample Processor")
+@tags("controlnet", "tile")
+class TileResamplerProcessorInvocation(ImageProcessorInvocation):
+ """Tile resampler processor"""
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Tile Resample Processor",
- "tags": ["controlnet", "tile", "resample", "image", "processor"],
- },
- }
+ type: Literal["tile_image_processor"] = "tile_image_processor"
+
+ # Inputs
+ # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
+ down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(
@@ -648,20 +510,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
-class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+@title("Segment Anything Processor")
+@tags("controlnet", "segmentanything")
+class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
- # fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Segment Anything Processor",
- "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
- },
- }
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py
index bd3a4adbe4..ed2030a835 100644
--- a/invokeai/app/invocations/cv.py
+++ b/invokeai/app/invocations/cv.py
@@ -5,40 +5,22 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
-from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .image import ImageOutput
-class CvInvocationConfig(BaseModel):
- """Helper class to provide all OpenCV invocations with additional config"""
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["cv", "image"],
- },
- }
-
-
-class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
+@title("OpenCV Inpaint")
+@tags("opencv", "inpaint")
+class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
- # fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
- image: ImageField = Field(default=None, description="The image to inpaint")
- mask: ImageField = Field(default=None, description="The mask to use when inpainting")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
- }
+ image: ImageField = InputField(description="The image to inpaint")
+ mask: ImageField = InputField(description="The mask to use when inpainting")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
index 2c47020207..5c277ec30f 100644
--- a/invokeai/app/invocations/image.py
+++ b/invokeai/app/invocations/image.py
@@ -1,37 +1,30 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
-from typing import Literal, Optional, Union
+from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
-from pydantic import Field
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
-from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin
-from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
+from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, ResourceOrigin
+from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
+@title("Load Image")
+@tags("image")
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
- # fmt: off
+ # Metadata
type: Literal["load_image"] = "load_image"
# Inputs
- image: Optional[ImageField] = Field(
- default=None, description="The image to load"
- )
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Load Image", "tags": ["image", "load"]},
- }
+ image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -43,18 +36,16 @@ class LoadImageInvocation(BaseInvocation):
)
+@title("Show Image")
+@tags("image")
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""
+ # Metadata
type: Literal["show_image"] = "show_image"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to show")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Show Image", "tags": ["image", "show"]},
- }
+ image: ImageField = InputField(description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -70,24 +61,20 @@ class ShowImageInvocation(BaseInvocation):
)
-class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
+@title("Crop Image")
+@tags("image", "crop")
+class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
- # fmt: off
+ # Metadata
type: Literal["img_crop"] = "img_crop"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to crop")
- x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
- y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
- width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
- height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
- }
+ image: ImageField = InputField(description="The image to crop")
+ x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
+ y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
+ width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
+ height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -111,24 +98,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
+@title("Paste Image")
+@tags("image", "paste")
+class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image."""
- # fmt: off
+ # Metadata
type: Literal["img_paste"] = "img_paste"
# Inputs
- base_image: Optional[ImageField] = Field(default=None, description="The base image")
- image: Optional[ImageField] = Field(default=None, description="The image to paste")
- mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
- x: int = Field(default=0, description="The left x coordinate at which to paste the image")
- y: int = Field(default=0, description="The top y coordinate at which to paste the image")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
- }
+ base_image: ImageField = InputField(description="The base image")
+ image: ImageField = InputField(description="The image to paste")
+ mask: Optional[ImageField] = InputField(
+ default=None,
+ description="The mask to use when pasting",
+ )
+ x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
+ y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(self.base_image.image_name)
@@ -164,21 +150,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
-class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
+@title("Mask from Alpha")
+@tags("image", "mask")
+class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
- # fmt: off
+ # Metadata
type: Literal["tomask"] = "tomask"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
- invert: bool = Field(default=False, description="Whether or not to invert the mask")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
- }
+ image: ImageField = InputField(description="The image to create the mask from")
+ invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -203,21 +185,17 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
+@title("Multiply Images")
+@tags("image", "multiply")
+class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
- # fmt: off
+ # Metadata
type: Literal["img_mul"] = "img_mul"
# Inputs
- image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
- image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
- }
+ image1: ImageField = InputField(description="The first image to multiply")
+ image2: ImageField = InputField(description="The second image to multiply")
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(self.image1.image_name)
@@ -244,21 +222,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
-class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
+@title("Extract Image Channel")
+@tags("image", "channel")
+class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image."""
- # fmt: off
+ # Metadata
type: Literal["img_chan"] = "img_chan"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
- channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Image Channel", "tags": ["image", "channel"]},
- }
+ image: ImageField = InputField(description="The image to get the channel from")
+ channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -284,21 +258,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
-class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
+@title("Convert Image Mode")
+@tags("image", "convert")
+class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode."""
- # fmt: off
+ # Metadata
type: Literal["img_conv"] = "img_conv"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to convert")
- mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Convert Image", "tags": ["image", "convert"]},
- }
+ image: ImageField = InputField(description="The image to convert")
+ mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -321,22 +291,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
+@title("Blur Image")
+@tags("image", "blur")
+class ImageBlurInvocation(BaseInvocation):
"""Blurs an image"""
- # fmt: off
+ # Metadata
type: Literal["img_blur"] = "img_blur"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to blur")
- radius: float = Field(default=8.0, ge=0, description="The blur radius")
- blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Blur Image", "tags": ["image", "blur"]},
- }
+ image: ImageField = InputField(description="The image to blur")
+ radius: float = InputField(default=8.0, ge=0, description="The blur radius")
+ # Metadata
+ blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -382,23 +349,19 @@ PIL_RESAMPLING_MAP = {
}
-class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
+@title("Resize Image")
+@tags("image", "resize")
+class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions"""
- # fmt: off
+ # Metadata
type: Literal["img_resize"] = "img_resize"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to resize")
- width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
- height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
- resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Resize Image", "tags": ["image", "resize"]},
- }
+ image: ImageField = InputField(description="The image to resize")
+ width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
+ height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
+ resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -426,22 +389,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
+@title("Scale Image")
+@tags("image", "scale")
+class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor"""
- # fmt: off
+ # Metadata
type: Literal["img_scale"] = "img_scale"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to scale")
- scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
- resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Scale Image", "tags": ["image", "scale"]},
- }
+ image: ImageField = InputField(description="The image to scale")
+ scale_factor: float = InputField(
+ default=2.0,
+ gt=0,
+ description="The factor by which to scale the image",
+ )
+ resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -471,22 +434,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
+@title("Lerp Image")
+@tags("image", "lerp")
+class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
- # fmt: off
+ # Metadata
type: Literal["img_lerp"] = "img_lerp"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to lerp")
- min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
- max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
- }
+ image: ImageField = InputField(description="The image to lerp")
+ min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
+ max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -512,25 +471,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
+@title("Inverse Lerp Image")
+@tags("image", "ilerp")
+class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
- # fmt: off
+ # Metadata
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to lerp")
- min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
- max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Image Inverse Linear Interpolation",
- "tags": ["image", "linear", "interpolation", "inverse"],
- },
- }
+ image: ImageField = InputField(description="The image to lerp")
+ min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
+ max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -556,21 +508,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
-class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
+@title("Blur NSFW Image")
+@tags("image", "nsfw")
+class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images"""
- # fmt: off
+ # Metadata
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to check")
- metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
- }
+ image: ImageField = InputField(description="The image to check")
+ metadata: Optional[CoreMetadata] = InputField(
+ default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
+ )
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -607,22 +557,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
return caution.resize((caution.width // 2, caution.height // 2))
-class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
+@title("Add Invisible Watermark")
+@tags("image", "watermark")
+class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image"""
- # fmt: off
+ # Metadata
type: Literal["img_watermark"] = "img_watermark"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to check")
- text: str = Field(default='InvokeAI', description="Watermark text")
- metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
- }
+ image: ImageField = InputField(description="The image to check")
+ text: str = InputField(default="InvokeAI", description="Watermark text")
+ metadata: Optional[CoreMetadata] = InputField(
+ default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
+ )
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -644,19 +592,21 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
)
-class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
+@title("Mask Edge")
+@tags("image", "mask", "inpaint")
+class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image"""
- # fmt: off
type: Literal["mask_edge"] = "mask_edge"
# Inputs
- image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
- edge_size: int = Field(description="The size of the edge")
- edge_blur: int = Field(description="The amount of blur on the edge")
- low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
- high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
- # fmt: on
+ image: ImageField = InputField(description="The image to apply the mask to")
+ edge_size: int = InputField(description="The size of the edge")
+ edge_blur: int = InputField(description="The amount of blur on the edge")
+ low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
+ high_threshold: int = InputField(
+ description="Second threshold for the hysteresis procedure in Canny edge detection"
+ )
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.services.images.get_pil_image(self.image.image_name)
@@ -690,21 +640,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
)
-class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
+@title("Combine Mask")
+@tags("image", "mask", "multiply")
+class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
- # fmt: off
type: Literal["mask_combine"] = "mask_combine"
# Inputs
- mask1: ImageField = Field(default=None, description="The first mask to combine")
- mask2: ImageField = Field(default=None, description="The second image to combine")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
- }
+ mask1: ImageField = InputField(description="The first mask to combine")
+ mask2: ImageField = InputField(description="The second image to combine")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
@@ -728,7 +673,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
)
-class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
+@title("Color Correct")
+@tags("image", "color")
+class ColorCorrectInvocation(BaseInvocation):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
@@ -736,10 +683,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["color_correct"] = "color_correct"
- image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
- reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
- mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
- mask_blur_radius: float = Field(default=8, description="Mask blur radius")
+ # Inputs
+ image: ImageField = InputField(description="The image to color-correct")
+ reference: ImageField = InputField(description="Reference image for color-correction")
+ mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
+ mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None
@@ -833,16 +781,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
)
+@title("Image Hue Adjustment")
+@tags("image", "hue", "hsl")
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
- # fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
- image: ImageField = Field(default=None, description="The image to adjust")
- hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
- # fmt: on
+ image: ImageField = InputField(description="The image to adjust")
+ hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
@@ -877,16 +825,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
)
+@title("Image Luminosity Adjustment")
+@tags("image", "luminosity", "hsl")
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
- # fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
- image: ImageField = Field(default=None, description="The image to adjust")
- luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
- # fmt: on
+ image: ImageField = InputField(description="The image to adjust")
+ luminosity: float = InputField(
+ default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
+ )
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
@@ -925,16 +875,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
)
+@title("Image Saturation Adjustment")
+@tags("image", "saturation", "hsl")
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
- # fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
- image: ImageField = Field(default=None, description="The image to adjust")
- saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
- # fmt: on
+ image: ImageField = InputField(description="The image to adjust")
+ saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index cd5b2f9a11..2294f806ca 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
import numpy as np
import math
from PIL import Image, ImageOps
-from pydantic import Field
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import (
- BaseInvocation,
- InvocationConfig,
- InvocationContext,
-)
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, title, tags
def infill_methods() -> list[str]:
@@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
+@title("Solid Color Infill")
+@tags("image", "inpaint")
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
- image: Optional[ImageField] = Field(default=None, description="The image to infill")
- color: ColorField = Field(
+
+ # Inputs
+ image: ImageField = InputField(description="The image to infill")
+ color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
- }
-
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
)
+@title("Tile Infill")
+@tags("image", "inpaint")
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
- image: Optional[ImageField] = Field(default=None, description="The image to infill")
- tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
- seed: int = Field(
+ # Input
+ image: ImageField = InputField(description="The image to infill")
+ tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
+ seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
- }
-
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
)
+@title("PatchMatch Infill")
+@tags("image", "inpaint")
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
- image: Optional[ImageField] = Field(default=None, description="The image to infill")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
- }
+ # Inputs
+ image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index c66c9c6214..63cfd95394 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -13,7 +13,8 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
+from diffusers.schedulers import DPMSolverSDEScheduler
+from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize
@@ -23,6 +24,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher
+from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@@ -32,9 +34,20 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
+from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ Input,
+ InputField,
+ InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
+)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
@@ -46,8 +59,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
- latents_name: Optional[str] = Field(default=None, description="The name of the latents")
- seed: Optional[int] = Field(description="Seed used to generate this latents")
+ latents_name: str = Field(description="The name of the latents")
+ seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
class Config:
schema_extra = {"required": ["latents_name"]}
@@ -56,14 +69,14 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
- # fmt: off
type: Literal["latents_output"] = "latents_output"
# Inputs
- latents: LatentsField = Field(default=None, description="The output latents")
- width: int = Field(description="The width of the latents in pixels")
- height: int = Field(description="The height of the latents in pixels")
- # fmt: on
+ latents: LatentsField = OutputField(
+ description=FieldDescriptions.latents,
+ )
+ width: int = OutputField(description=FieldDescriptions.width)
+ height: int = OutputField(description=FieldDescriptions.height)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
@@ -111,30 +124,36 @@ def get_scheduler(
return scheduler
+@title("Denoise Latents")
+@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
type: Literal["denoise_latents"] = "denoise_latents"
# Inputs
- positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
- negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
- noise: Optional[LatentsField] = Field(description="The noise to use")
- steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
- cfg_scale: Union[float, List[float]] = Field(
- default=7.5,
- ge=1,
- description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
+ positive_conditioning: ConditioningField = InputField(
+ description=FieldDescriptions.positive_cond, input=Input.Connection
)
- denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
- denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
- scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
- unet: UNetField = Field(default=None, description="UNet submodel")
- control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
- latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
- mask: Optional[ImageField] = Field(
- None,
- description="Mask",
+ negative_conditioning: ConditioningField = InputField(
+ description=FieldDescriptions.negative_cond, input=Input.Connection
+ )
+ noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
+ steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
+ cfg_scale: Union[float, List[float]] = InputField(
+ default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type_hint=UITypeHint.Float
+ )
+ denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
+ denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
+ scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
+ unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
+ control: Union[ControlField, list[ControlField]] = InputField(
+ default=None, description=FieldDescriptions.control, input=Input.Connection
+ )
+ latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
+ mask: Optional[ImageField] = InputField(
+ default=None,
+ description=FieldDescriptions.mask,
)
@validator("cfg_scale")
@@ -149,20 +168,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Denoise Latents",
- "tags": ["denoise", "latents"],
- "type_hints": {
- "model": "model",
- "control": "control",
- "cfg_scale": "number",
- },
- },
- }
-
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
@@ -474,29 +479,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
-# Latent to image
+@title("Latents to Image")
+@tags("latents", "image", "vae")
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
- latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
- vae: VaeField = Field(default=None, description="Vae submodel")
- tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
- fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
- metadata: Optional[CoreMetadata] = Field(
- default=None, description="Optional core metadata to be written to the image"
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
+ )
+ vae: VaeField = InputField(
+ description=FieldDescriptions.vae,
+ input=Input.Connection,
+ )
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+ fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+ metadata: CoreMetadata = InputField(
+ default=None,
+ description=FieldDescriptions.core_metadata,
+ ui_hidden=True,
)
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Latents To Image",
- "tags": ["latents", "image"],
- },
- }
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -574,24 +579,30 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
+@title("Resize Latents")
+@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
- latents: Optional[LatentsField] = Field(description="The latents to resize")
- width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
- height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
- mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
- antialias: bool = Field(
- default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
)
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
- }
+ width: int = InputField(
+ ge=64,
+ multiple_of=8,
+ description=FieldDescriptions.width,
+ )
+ height: int = InputField(
+ ge=64,
+ multiple_of=8,
+ description=FieldDescriptions.width,
+ )
+ mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+ antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@@ -616,23 +627,21 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
+@title("Scale Latents")
+@tags("latents", "resize")
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
- latents: Optional[LatentsField] = Field(description="The latents to scale")
- scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
- mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
- antialias: bool = Field(
- default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.latents,
+ input=Input.Connection,
)
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
- }
+ scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
+ mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+ antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@@ -658,22 +667,23 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
+@title("Image to Latents")
+@tags("latents", "image", "vae")
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
- image: Optional[ImageField] = Field(description="The image to encode")
- vae: VaeField = Field(default=None, description="Vae submodel")
- tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
- fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
- }
+ image: ImageField = InputField(
+ description="The image to encode",
+ )
+ vae: VaeField = InputField(
+ description=FieldDescriptions.vae,
+ input=Input.Connection,
+ )
+ tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+ fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py
index 32b1ab2a39..81c032ca89 100644
--- a/invokeai/app/invocations/math.py
+++ b/invokeai/app/invocations/math.py
@@ -2,134 +2,104 @@
from typing import Literal
-from pydantic import BaseModel, Field
import numpy as np
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
+ FieldDescriptions,
+ InputField,
InvocationContext,
- InvocationConfig,
+ OutputField,
+ tags,
+ title,
)
-class MathInvocationConfig(BaseModel):
- """Helper class to provide all math invocations with additional config"""
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["math"],
- }
- }
-
-
class IntOutput(BaseInvocationOutput):
"""An integer output"""
- # fmt: off
type: Literal["int_output"] = "int_output"
- a: int = Field(default=None, description="The output integer")
- # fmt: on
+ a: int = OutputField(default=None, description="The output integer")
class FloatOutput(BaseInvocationOutput):
"""A float output"""
- # fmt: off
type: Literal["float_output"] = "float_output"
- param: float = Field(default=None, description="The output float")
- # fmt: on
+ a: float = OutputField(default=None, description="The output float")
-class AddInvocation(BaseInvocation, MathInvocationConfig):
+@title("Add Integers")
+@tags("math")
+class AddInvocation(BaseInvocation):
"""Adds two numbers"""
- # fmt: off
type: Literal["add"] = "add"
- a: int = Field(default=0, description="The first number")
- b: int = Field(default=0, description="The second number")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Add", "tags": ["math", "add"]},
- }
+ # Inputs
+ a: int = InputField(default=0, description=FieldDescriptions.num_1)
+ b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
-class SubtractInvocation(BaseInvocation, MathInvocationConfig):
+@title("Subtract Integers")
+@tags("math")
+class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
- # fmt: off
type: Literal["sub"] = "sub"
- a: int = Field(default=0, description="The first number")
- b: int = Field(default=0, description="The second number")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Subtract", "tags": ["math", "subtract"]},
- }
+ # Inputs
+ a: int = InputField(default=0, description=FieldDescriptions.num_1)
+ b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
-class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
+@title("Multiply Integers")
+@tags("math")
+class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
- # fmt: off
type: Literal["mul"] = "mul"
- a: int = Field(default=0, description="The first number")
- b: int = Field(default=0, description="The second number")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Multiply", "tags": ["math", "multiply"]},
- }
+ # Inputs
+ a: int = InputField(default=0, description=FieldDescriptions.num_1)
+ b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
-class DivideInvocation(BaseInvocation, MathInvocationConfig):
+@title("Divide Integers")
+@tags("math")
+class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
- # fmt: off
type: Literal["div"] = "div"
- a: int = Field(default=0, description="The first number")
- b: int = Field(default=0, description="The second number")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Divide", "tags": ["math", "divide"]},
- }
+ # Inputs
+ a: int = InputField(default=0, description=FieldDescriptions.num_1)
+ b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
+@title("Random Integer")
+@tags("math")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
- # fmt: off
type: Literal["rand_int"] = "rand_int"
- low: int = Field(default=0, description="The inclusive low value")
- high: int = Field(
- default=np.iinfo(np.int32).max, description="The exclusive high value"
- )
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
- }
+ # Inputs
+ low: int = InputField(default=0, description="The inclusive low value")
+ high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=np.random.randint(self.low, self.high))
diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py
index d0549f8539..b0e7c13d43 100644
--- a/invokeai/app/invocations/metadata.py
+++ b/invokeai/app/invocations/metadata.py
@@ -1,18 +1,21 @@
-from typing import Literal, Optional, Union
+from typing import Literal, Optional
from pydantic import Field
-from ...version import __version__
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
- InvocationConfig,
+ InputField,
InvocationContext,
+ tags,
+ title,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
+from ...version import __version__
+
class LoRAMetadataField(BaseModelExcludeNull):
"""LoRA metadata for an image generated in InvokeAI."""
@@ -43,37 +46,37 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
- vae: Union[VAEModelField, None] = Field(
+ vae: Optional[VAEModelField] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# Latents-to-Latents
- strength: Union[float, None] = Field(
+ strength: Optional[float] = Field(
default=None,
description="The strength used for latents-to-latents",
)
- init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
+ init_image: Optional[str] = Field(default=None, description="The name of the initial image")
# SDXL
- positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
- negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
+ positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
+ negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
- refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
- refiner_cfg_scale: Union[float, None] = Field(
+ refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
+ refiner_cfg_scale: Optional[float] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
- refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
- refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
- refiner_positive_aesthetic_store: Union[float, None] = Field(
+ refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
+ refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
+ refiner_positive_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
- refiner_negative_aesthetic_store: Union[float, None] = Field(
+ refiner_negative_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
- refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
+ refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModelExcludeNull):
@@ -94,66 +97,83 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
metadata: CoreMetadata = Field(description="The core metadata for the image")
+@title("Metadata Accumulator")
+@tags("metadata")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
- generation_mode: str = Field(
+ generation_mode: str = InputField(
description="The generation mode that output this image",
)
- positive_prompt: str = Field(description="The positive prompt parameter")
- negative_prompt: str = Field(description="The negative prompt parameter")
- width: int = Field(description="The width parameter")
- height: int = Field(description="The height parameter")
- seed: int = Field(description="The seed used for noise generation")
- rand_device: str = Field(description="The device used for random number generation")
- cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
- steps: int = Field(description="The number of steps used for inference")
- scheduler: str = Field(description="The scheduler used for inference")
- clip_skip: int = Field(
+ positive_prompt: str = InputField(description="The positive prompt parameter")
+ negative_prompt: str = InputField(description="The negative prompt parameter")
+ width: int = InputField(description="The width parameter")
+ height: int = InputField(description="The height parameter")
+ seed: int = InputField(description="The seed used for noise generation")
+ rand_device: str = InputField(description="The device used for random number generation")
+ cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
+ steps: int = InputField(description="The number of steps used for inference")
+ scheduler: str = InputField(description="The scheduler used for inference")
+ clip_skip: int = InputField(
description="The number of skipped CLIP layers",
)
- model: MainModelField = Field(description="The main model used for inference")
- controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
- loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
- strength: Union[float, None] = Field(
+ model: MainModelField = InputField(description="The main model used for inference")
+ controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
+ loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
+ strength: Optional[float] = InputField(
default=None,
description="The strength used for latents-to-latents",
)
- init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
- vae: Union[VAEModelField, None] = Field(
+ init_image: Optional[str] = InputField(
+ default=None,
+ description="The name of the initial image",
+ )
+ vae: Optional[VAEModelField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# SDXL
- positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
- negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
+ positive_style_prompt: Optional[str] = InputField(
+ default=None,
+ description="The positive style prompt parameter",
+ )
+ negative_style_prompt: Optional[str] = InputField(
+ default=None,
+ description="The negative style prompt parameter",
+ )
# SDXL Refiner
- refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
- refiner_cfg_scale: Union[float, None] = Field(
+ refiner_model: Optional[MainModelField] = InputField(
+ default=None,
+ description="The SDXL Refiner model used",
+ )
+ refiner_cfg_scale: Optional[float] = InputField(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
- refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
- refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
- refiner_positive_aesthetic_score: Union[float, None] = Field(
- default=None, description="The aesthetic score used for the refiner"
+ refiner_steps: Optional[int] = InputField(
+ default=None,
+ description="The number of steps used for the refiner",
)
- refiner_negative_aesthetic_score: Union[float, None] = Field(
- default=None, description="The aesthetic score used for the refiner"
+ refiner_scheduler: Optional[str] = InputField(
+ default=None,
+ description="The scheduler used for the refiner",
+ )
+ refiner_positive_aesthetic_store: Optional[float] = InputField(
+ default=None,
+ description="The aesthetic score used for the refiner",
+ )
+ refiner_negative_aesthetic_store: Optional[float] = InputField(
+ default=None,
+ description="The aesthetic score used for the refiner",
+ )
+ refiner_start: Optional[float] = InputField(
+ default=None,
+ description="The start value used for refiner denoising",
)
- refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Metadata Accumulator",
- "tags": ["image", "metadata", "generation"],
- },
- }
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 0d21f8f0ce..de32a9948f 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ InputField,
+ Input,
+ InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
+)
class ModelInfo(BaseModel):
@@ -39,13 +50,11 @@ class VaeField(BaseModel):
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
- # fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
- unet: UNetField = Field(default=None, description="UNet submodel")
- clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
- vae: VaeField = Field(default=None, description="Vae submodel")
- # fmt: on
+ unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
+ clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+ vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class MainModelField(BaseModel):
@@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
+@title("Main Model Loader")
+@tags("model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader"
- model: MainModelField = Field(description="The model to load")
+ # Inputs
+ model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Model Loader",
- "tags": ["model", "loader"],
- "type_hints": {"model": "model"},
- },
- }
-
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
@@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
loras=[],
skipped_layers=0,
),
- clip2=ClipField(
- tokenizer=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=model_type,
- submodel=SubModelType.Tokenizer2,
- ),
- text_encoder=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=model_type,
- submodel=SubModelType.TextEncoder2,
- ),
- loras=[],
- skipped_layers=0,
- ),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
# fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
- unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
- clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+ unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
+ clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
# fmt: on
+@title("LoRA Loader")
+@tags("lora", "model")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
- lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
- weight: float = Field(default=0.75, description="With what weight to apply lora")
-
- unet: Optional[UNetField] = Field(description="UNet model for applying lora")
- clip: Optional[ClipField] = Field(description="Clip model for applying lora")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Lora Loader",
- "tags": ["lora", "loader"],
- "type_hints": {"lora": "lora_model"},
- },
- }
+ # Inputs
+ lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
+ weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
+ unet: Optional[UNetField] = InputField(
+ default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
+ )
+ clip: Optional[ClipField] = InputField(
+ default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
+ )
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
@@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
class SDXLLoraLoaderOutput(BaseInvocationOutput):
- """Model loader output"""
+ """SDXL LoRA Loader Output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
- unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
- clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
- clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
+ unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
+ clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
+ clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
# fmt: on
+@title("SDXL LoRA Loader")
+@tags("sdxl", "lora", "model")
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
- lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
- weight: float = Field(default=0.75, description="With what weight to apply lora")
-
- unet: Optional[UNetField] = Field(description="UNet model for applying lora")
- clip: Optional[ClipField] = Field(description="Clip model for applying lora")
- clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "SDXL Lora Loader",
- "tags": ["lora", "loader"],
- "type_hints": {"lora": "lora_model"},
- },
- }
+ lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
+ weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
+ unet: Optional[UNetField] = Field(
+ default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
+ )
+ clip: Optional[ClipField] = Field(
+ default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
+ )
+ clip2: Optional[ClipField] = Field(
+ default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
+ )
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
@@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
- # fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
- vae: VaeField = Field(default=None, description="Vae model")
- # fmt: on
+ # Outputs
+ vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
+@title("VAE Loader")
+@tags("vae", "model")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
- vae_model: VAEModelField = Field(description="The VAE to load")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "VAE Loader",
- "tags": ["vae", "loader"],
- "type_hints": {"vae_model": "vae_model"},
- },
- }
+ # Inputs
+ vae_model: VAEModelField = InputField(
+ description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModelField, title="VAE"
+ )
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index db64e5b6e5..7049dad61a 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -1,19 +1,24 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
-import math
from typing import Literal
-from pydantic import Field, validator
import torch
-from invokeai.app.invocations.latent import LatentsField
+from pydantic import validator
+from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
+
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
- InvocationConfig,
+ FieldDescriptions,
+ InputField,
InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
)
"""
@@ -61,14 +66,12 @@ Nodes
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
- # fmt: off
- type: Literal["noise_output"] = "noise_output"
+ type: Literal["noise_output"] = "noise_output"
# Inputs
- noise: LatentsField = Field(default=None, description="The output noise")
- width: int = Field(description="The width of the noise in pixels")
- height: int = Field(description="The height of the noise in pixels")
- # fmt: on
+ noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
+ width: int = OutputField(description=FieldDescriptions.width)
+ height: int = OutputField(description=FieldDescriptions.height)
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
@@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)
+@title("Noise")
+@tags("latents", "noise")
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
- seed: int = Field(
+ seed: int = InputField(
ge=0,
le=SEED_MAX,
- description="The seed to use",
+ description=FieldDescriptions.seed,
default_factory=get_random_seed,
)
- width: int = Field(
+ width: int = InputField(
default=512,
multiple_of=8,
gt=0,
- description="The width of the resulting noise",
+ description=FieldDescriptions.width,
)
- height: int = Field(
+ height: int = InputField(
default=512,
multiple_of=8,
gt=0,
- description="The height of the resulting noise",
+ description=FieldDescriptions.height,
)
- use_cpu: bool = Field(
+ use_cpu: bool = InputField(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Noise",
- "tags": ["latents", "noise"],
- },
- }
-
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index 4f04a4f023..cd73d35d78 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -1,37 +1,44 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
+import inspect
+import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
-import re
-import inspect
-
-from pydantic import BaseModel, Field, validator
-import torch
import numpy as np
+import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
-
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from ...backend.model_management import ONNXModelPatcher
-from ...backend.util import choose_torch_device
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
+from pydantic import BaseModel, Field, validator
+from tqdm import tqdm
from invokeai.app.invocations.metadata import CoreMetadata
-from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from invokeai.backend import BaseModelType, ModelType, SubModelType
+
+from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
-
-from tqdm import tqdm
-from .model import ClipField
-from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
-from .compel import CompelOutput
-
+from ...backend.util import choose_torch_device
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ InputField,
+ Input,
+ InvocationContext,
+ OutputField,
+ UIComponent,
+ UITypeHint,
+ tags,
+ title,
+)
+from .compel import CompelOutput, ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
+from .model import ClipField, ModelInfo, UNetField, VaeField
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
@@ -51,11 +58,13 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
+@title("ONNX Prompt (Raw)")
+@tags("onnx", "prompt")
class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx"
- prompt: str = Field(default="", description="Prompt")
- clip: ClipField = Field(None, description="Clip to use")
+ prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
+ clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
@@ -134,25 +143,48 @@ class ONNXPromptInvocation(BaseInvocation):
# Text to image
+@title("ONNX Text to Latents")
+@tags("latents", "inference", "txt2img", "onnx")
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs
- # fmt: off
- positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
- negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
- noise: Optional[LatentsField] = Field(description="The noise to use")
- steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
- cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
- scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
- precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
- unet: UNetField = Field(default=None, description="UNet submodel")
- control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
- # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
- # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
- # fmt: on
+ positive_conditioning: ConditioningField = InputField(
+ description=FieldDescriptions.positive_cond,
+ input=Input.Connection,
+ )
+ negative_conditioning: ConditioningField = InputField(
+ description=FieldDescriptions.negative_cond,
+ input=Input.Connection,
+ )
+ noise: LatentsField = InputField(
+ description=FieldDescriptions.noise,
+ input=Input.Connection,
+ )
+ steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
+ cfg_scale: Union[float, List[float]] = InputField(
+ default=7.5,
+ ge=1,
+ description=FieldDescriptions.cfg_scale,
+ ui_type_hint=UITypeHint.Float,
+ )
+ scheduler: SAMPLER_NAME_VALUES = InputField(
+ default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
+ )
+ precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
+ unet: UNetField = InputField(
+ description=FieldDescriptions.unet,
+ input=Input.Connection,
+ )
+ control: Optional[Union[ControlField, list[ControlField]]] = InputField(
+ default=None,
+ description=FieldDescriptions.control,
+ ui_type_hint=UITypeHint.ControlField,
+ )
+ # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
+ # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@validator("cfg_scale")
def ge_one(cls, v):
@@ -166,20 +198,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["latents"],
- "type_hints": {
- "model": "model",
- "control": "control",
- # "cfg_scale": "float",
- "cfg_scale": "number",
- },
- },
- }
-
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -300,26 +318,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# Latent to image
+@title("ONNX Latents to Image")
+@tags("latents", "image", "vae", "onnx")
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs
- latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
- vae: VaeField = Field(default=None, description="Vae submodel")
- metadata: Optional[CoreMetadata] = Field(
- default=None, description="Optional core metadata to be written to the image"
+ latents: LatentsField = InputField(
+ description=FieldDescriptions.denoised_latents,
+ input=Input.Connection,
)
- # tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["latents", "image"],
- },
- }
+ vae: VaeField = InputField(
+ description=FieldDescriptions.vae,
+ input=Input.Connection,
+ )
+ metadata: Optional[CoreMetadata] = InputField(
+ default=None,
+ description=FieldDescriptions.core_metadata,
+ ui_hidden=True,
+ )
+ # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
@@ -373,89 +393,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
# fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
- unet: UNetField = Field(default=None, description="UNet submodel")
- clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
- vae_decoder: VaeField = Field(default=None, description="Vae submodel")
- vae_encoder: VaeField = Field(default=None, description="Vae submodel")
+ unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
+ clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
+ vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
+ vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
# fmt: on
-class ONNXSD1ModelLoaderInvocation(BaseInvocation):
- """Loading submodels of selected model."""
-
- type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
-
- model_name: str = Field(default="", description="Model to load")
- # TODO: precision?
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
- }
-
- def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
- model_name = "stable-diffusion-v1-5"
- base_model = BaseModelType.StableDiffusion1
-
- # TODO: not found exceptions
- if not context.services.model_manager.model_exists(
- model_name=model_name,
- base_model=BaseModelType.StableDiffusion1,
- model_type=ModelType.ONNX,
- ):
- raise Exception(f"Unkown model name: {model_name}!")
-
- return ONNXModelLoaderOutput(
- unet=UNetField(
- unet=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.UNet,
- ),
- scheduler=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.Scheduler,
- ),
- loras=[],
- ),
- clip=ClipField(
- tokenizer=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.Tokenizer,
- ),
- text_encoder=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.TextEncoder,
- ),
- loras=[],
- ),
- vae_decoder=VaeField(
- vae=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.VaeDecoder,
- ),
- ),
- vae_encoder=VaeField(
- vae=ModelInfo(
- model_name=model_name,
- base_model=base_model,
- model_type=ModelType.ONNX,
- submodel=SubModelType.VaeEncoder,
- ),
- ),
- )
-
-
class OnnxModelField(BaseModel):
"""Onnx model field"""
@@ -464,22 +408,17 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type")
+@title("ONNX Model Loader")
+@tags("onnx", "model")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader"
- model: OnnxModelField = Field(description="The model to load")
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "Onnx Model Loader",
- "tags": ["model", "loader"],
- "type_hints": {"model": "model"},
- },
- }
+ # Inputs
+ model: OnnxModelField = InputField(
+ description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModelField
+ )
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model
diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py
index f910e5379c..d35dba5df3 100644
--- a/invokeai/app/invocations/param_easing.py
+++ b/invokeai/app/invocations/param_easing.py
@@ -1,73 +1,63 @@
import io
-from typing import Literal, Optional, Any
+from typing import Literal, Optional
-# from PIL.Image import Image
-import PIL.Image
-from matplotlib.ticker import MaxNLocator
-from matplotlib.figure import Figure
-
-from pydantic import BaseModel, Field
-import numpy as np
import matplotlib.pyplot as plt
+import numpy as np
+import PIL.Image
from easing_functions import (
- LinearInOut,
- QuadEaseInOut,
- QuadEaseIn,
- QuadEaseOut,
- CubicEaseInOut,
- CubicEaseIn,
- CubicEaseOut,
- QuarticEaseInOut,
- QuarticEaseIn,
- QuarticEaseOut,
- QuinticEaseInOut,
- QuinticEaseIn,
- QuinticEaseOut,
- SineEaseInOut,
- SineEaseIn,
- SineEaseOut,
- CircularEaseIn,
- CircularEaseInOut,
- CircularEaseOut,
- ExponentialEaseInOut,
- ExponentialEaseIn,
- ExponentialEaseOut,
- ElasticEaseIn,
- ElasticEaseInOut,
- ElasticEaseOut,
BackEaseIn,
BackEaseInOut,
BackEaseOut,
BounceEaseIn,
BounceEaseInOut,
BounceEaseOut,
+ CircularEaseIn,
+ CircularEaseInOut,
+ CircularEaseOut,
+ CubicEaseIn,
+ CubicEaseInOut,
+ CubicEaseOut,
+ ElasticEaseIn,
+ ElasticEaseInOut,
+ ElasticEaseOut,
+ ExponentialEaseIn,
+ ExponentialEaseInOut,
+ ExponentialEaseOut,
+ LinearInOut,
+ QuadEaseIn,
+ QuadEaseInOut,
+ QuadEaseOut,
+ QuarticEaseIn,
+ QuarticEaseInOut,
+ QuarticEaseOut,
+ QuinticEaseIn,
+ QuinticEaseInOut,
+ QuinticEaseOut,
+ SineEaseIn,
+ SineEaseInOut,
+ SineEaseOut,
)
+from matplotlib.figure import Figure
+from matplotlib.ticker import MaxNLocator
+from pydantic import BaseModel, Field
-from .baseinvocation import (
- BaseInvocation,
- BaseInvocationOutput,
- InvocationContext,
- InvocationConfig,
-)
from ...backend.util.logging import InvokeAILogger
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .collections import FloatCollectionOutput
+@title("Float Range")
+@tags("math", "range")
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
- start: float = Field(default=5, description="The first value of the range")
- stop: float = Field(default=10, description="The last value of the range")
- steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
- }
+ start: float = InputField(default=5, description="The first value of the range")
+ stop: float = InputField(default=10, description="The last value of the range")
+ steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
@@ -108,37 +98,32 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
-EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
+EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now could just use CollectionOutput (which is list[Any]
+@title("Step Param Easing")
+@tags("step", "easing")
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
- # fmt: off
- easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
- num_steps: int = Field(default=20, description="number of denoising steps")
- start_value: float = Field(default=0.0, description="easing starting value")
- end_value: float = Field(default=1.0, description="easing ending value")
- start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
- end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
+ easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
+ num_steps: int = InputField(default=20, description="number of denoising steps")
+ start_value: float = InputField(default=0.0, description="easing starting value")
+ end_value: float = InputField(default=1.0, description="easing ending value")
+ start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
+ end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
- pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
+ pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
- post_end_value: Optional[float] = Field(default=None, description="value after easing end")
- mirror: bool = Field(default=False, description="include mirror of easing function")
+ post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
+ mirror: bool = InputField(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
- # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
- show_easing_plot: bool = Field(default=False, description="show easing plot")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
- }
+ # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
+ show_easing_plot: bool = InputField(default=False, description="show easing plot")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py
index 513eb8762f..27382d8f8d 100644
--- a/invokeai/app/invocations/params.py
+++ b/invokeai/app/invocations/params.py
@@ -2,82 +2,80 @@
from typing import Literal
-from pydantic import Field
-
from invokeai.app.invocations.prompt import PromptOutput
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ InputField,
+ InvocationContext,
+ OutputField,
+ tags,
+ title,
+)
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
+@title("Integer Parameter")
+@tags("integer")
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
- # fmt: off
type: Literal["param_int"] = "param_int"
- a: int = Field(default=0, description="The integer value")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
- }
+ # Inputs
+ a: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
+@title("Float Parameter")
+@tags("float")
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
- # fmt: off
type: Literal["param_float"] = "param_float"
- param: float = Field(default=0.0, description="The float value")
- # fmt: on
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
- }
+ # Inputs
+ param: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
- return FloatOutput(param=self.param)
+ return FloatOutput(a=self.param)
class StringOutput(BaseInvocationOutput):
"""A string output"""
type: Literal["string_output"] = "string_output"
- text: str = Field(default=None, description="The output string")
+ text: str = OutputField(description="The output string")
+@title("String Parameter")
+@tags("string")
class ParamStringInvocation(BaseInvocation):
"""A string parameter"""
type: Literal["param_string"] = "param_string"
- text: str = Field(default="", description="The string value")
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"tags": ["param", "string"], "title": "String Parameter"},
- }
+ # Inputs
+ text: str = InputField(default="", description="The string value")
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
+@title("Prompt Parameter")
+@tags("prompt")
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
- prompt: str = Field(default="", description="The prompt value")
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"tags": ["param", "prompt"], "title": "Prompt"},
- }
+ # Inputs
+ prompt: str = InputField(default="", description="The prompt value")
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)
diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py
index 83a397ddcf..57320c695c 100644
--- a/invokeai/app/invocations/prompt.py
+++ b/invokeai/app/invocations/prompt.py
@@ -2,56 +2,52 @@ from os.path import exists
from typing import Literal, Optional
import numpy as np
-from pydantic import Field, validator
+from pydantic import validator
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ InputField,
+ InvocationContext,
+ OutputField,
+ UIComponent,
+ UITypeHint,
+ title,
+ tags,
+)
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
- # fmt: off
type: Literal["prompt"] = "prompt"
- prompt: str = Field(default=None, description="The output prompt")
- # fmt: on
-
- class Config:
- schema_extra = {
- "required": [
- "type",
- "prompt",
- ]
- }
+ prompt: str = OutputField(description="The output prompt")
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
- # fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
- prompt_collection: list[str] = Field(description="The output prompt collection")
- count: int = Field(description="The size of the prompt collection")
- # fmt: on
-
- class Config:
- schema_extra = {"required": ["type", "prompt_collection", "count"]}
+ prompt_collection: list[str] = OutputField(
+ description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection
+ )
+ count: int = OutputField(description="The size of the prompt collection")
+@title("Dynamic Prompt")
+@tags("prompt", "collection")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
- prompt: str = Field(description="The prompt to parse with dynamicprompts")
- max_prompts: int = Field(default=1, description="The number of prompts to generate")
- combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
- }
+ # Inputs
+ prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
+ max_prompts: int = InputField(default=1, description="The number of prompts to generate")
+ combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
@@ -64,24 +60,23 @@ class DynamicPromptInvocation(BaseInvocation):
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
+@title("Prompts from File")
+@tags("prompt", "file")
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""
- # fmt: off
- type: Literal['prompt_from_file'] = 'prompt_from_file'
+ type: Literal["prompt_from_file"] = "prompt_from_file"
# Inputs
- file_path: str = Field(description="Path to prompt text file")
- pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
- post_prompt: Optional[str] = Field(description="String to append to each prompt")
- start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
- max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
- # fmt: on
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
- }
+ file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
+ pre_prompt: Optional[str] = InputField(
+ description="String to prepend to each prompt", ui_component=UIComponent.Textarea
+ )
+ post_prompt: Optional[str] = InputField(
+ description="String to append to each prompt", ui_component=UIComponent.Textarea
+ )
+ start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
+ max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
@validator("file_path")
def file_path_exists(cls, v):
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index a5a1c2c641..d25a37327e 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,55 +1,55 @@
-import torch
from typing import Literal
-from pydantic import Field
from ...backend.model_management import ModelType, SubModelType
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ FieldDescriptions,
+ Input,
+ InputField,
+ InvocationContext,
+ OutputField,
+ UITypeHint,
+ tags,
+ title,
+)
+from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
- # fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
- unet: UNetField = Field(default=None, description="UNet submodel")
- clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
- clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
- vae: VaeField = Field(default=None, description="Vae submodel")
- # fmt: on
+ unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
+ clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
+ clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+ vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
- # fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
- unet: UNetField = Field(default=None, description="UNet submodel")
- clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
- vae: VaeField = Field(default=None, description="Vae submodel")
- # fmt: on
- # fmt: on
+
+ unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
+ clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+ vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
+@title("SDXL Main Model Loader")
+@tags("model", "sdxl")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
- model: MainModelField = Field(description="The model to load")
+ # Inputs
+ model: MainModelField = InputField(
+ description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModelField
+ )
# TODO: precision?
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "SDXL Model Loader",
- "tags": ["model", "loader", "sdxl"],
- "type_hints": {"model": "model"},
- },
- }
-
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
@@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
)
+@title("SDXL Refiner Model Loader")
+@tags("model", "sdxl", "refiner")
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
- model: MainModelField = Field(description="The model to load")
+ # Inputs
+ model: MainModelField = InputField(
+ description=FieldDescriptions.sdxl_refiner_model,
+ input=Input.Direct,
+ ui_type_hint=UITypeHint.SDXLRefinerModelField,
+ )
# TODO: precision?
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "title": "SDXL Refiner Model Loader",
- "tags": ["model", "loader", "sdxl_refiner"],
- "type_hints": {"model": "refiner_model"},
- },
- }
-
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index fd220223db..4e9c9fac2f 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -6,12 +6,11 @@ import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
-from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .image import ImageOutput
# TODO: Populate this from disk?
@@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
]
+@title("Upscale (RealESRGAN)")
+@tags("esrgan", "upscale")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
type: Literal["esrgan"] = "esrgan"
- image: Union[ImageField, None] = Field(default=None, description="The input image")
- model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
- }
+ # Inputs
+ image: ImageField = InputField(description="The input image")
+ model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py
index 2a5a0f9d3b..3a34a7d6da 100644
--- a/invokeai/app/models/image.py
+++ b/invokeai/app/models/image.py
@@ -5,14 +5,13 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
- InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
- image_name: Optional[str] = Field(default=None, description="The name of the image")
+ image_name: str = Field(description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
@@ -36,17 +35,6 @@ class ProgressImage(BaseModel):
dataURL: str = Field(description="The image data as a b64 data URL")
-class PILInvocationConfig(BaseModel):
- """Helper class to provide all PIL invocations with additional config"""
-
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["PIL", "image"],
- },
- }
-
-
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py
index d7f021df14..5dacfe1ec1 100644
--- a/invokeai/app/services/graph.py
+++ b/invokeai/app/services/graph.py
@@ -3,16 +3,7 @@
import copy
import itertools
import uuid
-from typing import (
- Annotated,
- Any,
- Literal,
- Optional,
- Union,
- get_args,
- get_origin,
- get_type_hints,
-)
+from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
@@ -22,7 +13,11 @@ from ..invocations import *
from ..invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
+ Input,
+ InputField,
InvocationContext,
+ OutputField,
+ UITypeHint,
)
# in 3.10 this would be "from types import NoneType"
@@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
type: Literal["iterate_output"] = "iterate_output"
- item: Any = Field(description="The item being iterated over")
-
- class Config:
- schema_extra = {
- "required": [
- "type",
- "item",
- ]
- }
+ item: Any = OutputField(
+ description="The item being iterated over", title="Collection Item", ui_type_hint=UITypeHint.CollectionItem
+ )
# TODO: Fill this out and move to invocations
@@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
type: Literal["iterate"] = "iterate"
- collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
- index: int = Field(description="The index, will be provided on executed iterators", default=0)
+ collection: list[Any] = InputField(
+ description="The list of items to iterate over", default_factory=list, ui_type_hint=UITypeHint.Collection
+ )
+ index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
@@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
class CollectInvocationOutput(BaseInvocationOutput):
type: Literal["collect_output"] = "collect_output"
- collection: list[Any] = Field(description="The collection of input items")
-
- class Config:
- schema_extra = {
- "required": [
- "type",
- "collection",
- ]
- }
+ collection: list[Any] = OutputField(
+ description="The collection of input items", title="Collection", ui_type_hint=UITypeHint.Collection
+ )
class CollectInvocation(BaseInvocation):
@@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
type: Literal["collect"] = "collect"
- item: Any = Field(
+ item: Any = InputField(
description="The item to collect (all inputs must be of the same type)",
- default=None,
+ ui_type_hint=UITypeHint.CollectionItem,
+ title="Collection Item",
+ input=Input.Connection,
)
- collection: list[Any] = Field(
- description="The collection, will be provided on execution",
- default_factory=list,
+ collection: list[Any] = InputField(
+ description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index 41170a304b..b8c2f93e93 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
- outputs = invocation.invoke(
+ # use the internal invoke_internal(), which wraps the node's invoke() method in
+ # this accomodates nodes which require a value, but get it only from a
+ # connection
+ outputs = invocation.invoke_internal(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py
index 855f3f1939..3c46b1c2a0 100644
--- a/invokeai/app/services/sqlite.py
+++ b/invokeai/app/services/sqlite.py
@@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
- return parse_raw_as(item_type, item)
+ parsed = parse_raw_as(item_type, item)
+ return parsed
def set(self, item: T):
try:
diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json
index 8cc2c158be..6c9db74bbc 100644
--- a/invokeai/frontend/web/package.json
+++ b/invokeai/frontend/web/package.json
@@ -61,6 +61,7 @@
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
+ "@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1",
diff --git a/invokeai/frontend/web/scripts/colors.js b/invokeai/frontend/web/scripts/colors.js
new file mode 100644
index 0000000000..3fc8f8d751
--- /dev/null
+++ b/invokeai/frontend/web/scripts/colors.js
@@ -0,0 +1,34 @@
+export const COLORS = {
+ reset: '\x1b[0m',
+ bright: '\x1b[1m',
+ dim: '\x1b[2m',
+ underscore: '\x1b[4m',
+ blink: '\x1b[5m',
+ reverse: '\x1b[7m',
+ hidden: '\x1b[8m',
+
+ fg: {
+ black: '\x1b[30m',
+ red: '\x1b[31m',
+ green: '\x1b[32m',
+ yellow: '\x1b[33m',
+ blue: '\x1b[34m',
+ magenta: '\x1b[35m',
+ cyan: '\x1b[36m',
+ white: '\x1b[37m',
+ gray: '\x1b[90m',
+ crimson: '\x1b[38m',
+ },
+ bg: {
+ black: '\x1b[40m',
+ red: '\x1b[41m',
+ green: '\x1b[42m',
+ yellow: '\x1b[43m',
+ blue: '\x1b[44m',
+ magenta: '\x1b[45m',
+ cyan: '\x1b[46m',
+ white: '\x1b[47m',
+ gray: '\x1b[100m',
+ crimson: '\x1b[48m',
+ },
+};
diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js
index ec67c48f2d..d105917e66 100644
--- a/invokeai/frontend/web/scripts/typegen.js
+++ b/invokeai/frontend/web/scripts/typegen.js
@@ -1,23 +1,83 @@
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';
+import { COLORS } from './colors.js';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
async function main() {
process.stdout.write(
- `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
+ `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
- transform: (schemaObject) => {
+ transform: (schemaObject, metadata) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
+
+ /**
+ * Because invocations may have required fields that accept connection input, the generated
+ * types may be incorrect.
+ *
+ * For example, the ImageResizeInvocation has a required `image` field, but because it accepts
+ * connection input, it should be optional on instantiation of the field.
+ *
+ * To handle this, the schema exposes an `input` property that can be used to determine if the
+ * field accepts connection input. If it does, we can make the field optional.
+ */
+
+ // Check if we are generating types for an invocation
+ const isInvocationPath = metadata.path.match(
+ /^#\/components\/schemas\/\w*Invocation$/
+ );
+
+ const hasInvocationProperties =
+ schemaObject.properties &&
+ ['id', 'is_intermediate', 'type'].every(
+ (prop) => prop in schemaObject.properties
+ );
+
+ if (isInvocationPath && hasInvocationProperties) {
+ // We only want to make fields optional if they are required
+ if (!Array.isArray(schemaObject?.required)) {
+ schemaObject.required = ['id', 'type'];
+ return;
+ }
+
+ schemaObject.required.forEach((prop) => {
+ const acceptsConnection = ['any', 'connection'].includes(
+ schemaObject.properties?.[prop]?.['input']
+ );
+
+ if (acceptsConnection) {
+ // remove this prop from the required array
+ const invocationName = metadata.path.split('/').pop();
+ console.log(
+ `Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
+ );
+ schemaObject.required = schemaObject.required.filter(
+ (r) => r !== prop
+ );
+ }
+ });
+
+ schemaObject.required = [
+ ...new Set(schemaObject.required.concat(['id', 'type'])),
+ ];
+
+ return;
+ }
+ // if (
+ // 'input' in schemaObject &&
+ // (schemaObject.input === 'any' || schemaObject.input === 'connection')
+ // ) {
+ // schemaObject.required = false;
+ // }
},
});
fs.writeFileSync(OUTPUT_FILE, types);
- process.stdout.write(` OK!\r\n`);
+ process.stdout.write(`\nOK!\r\n`);
}
main();
diff --git a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts
index 9827e7f2b3..bbe77dc698 100644
--- a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts
+++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts
@@ -1,8 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
+import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
-import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
+import {
+ ctrlKeyPressed,
+ metaKeyPressed,
+ shiftKeyPressed,
+} from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
@@ -16,11 +20,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
- [(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
- (hotkeys, ui) => {
- const { shift } = hotkeys;
+ [stateSelector],
+ ({ hotkeys, ui }) => {
+ const { shift, ctrl, meta } = hotkeys;
const { shouldPinParametersPanel, shouldPinGallery } = ui;
- return { shift, shouldPinGallery, shouldPinParametersPanel };
+ return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
},
{
memoizeOptions: {
@@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
- const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
- globalHotkeysSelector
- );
+ const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
+ useAppSelector(globalHotkeysSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
useHotkeys(
@@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
} else {
shift && dispatch(shiftKeyPressed(false));
}
+ if (isHotkeyPressed('ctrl')) {
+ !ctrl && dispatch(ctrlKeyPressed(true));
+ } else {
+ ctrl && dispatch(ctrlKeyPressed(false));
+ }
+ if (isHotkeyPressed('meta')) {
+ !meta && dispatch(metaKeyPressed(true));
+ } else {
+ meta && dispatch(metaKeyPressed(false));
+ }
},
{ keyup: true, keydown: true },
- [shift]
+ [shift, ctrl, meta]
);
useHotkeys('o', () => {
diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
index 93b7825db7..7e2ed7f571 100644
--- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
+++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
@@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import '../../i18n';
-import ImageDndContext from './ImageDnd/ImageDndContext';
+import AppDndContext from '../../features/dnd/components/AppDndContext';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@@ -80,9 +80,9 @@ const InvokeAIUI = ({
}>
-
+
-
+
diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts
index ef27c98d1f..7797b8dc92 100644
--- a/invokeai/frontend/web/src/app/logging/logger.ts
+++ b/invokeai/frontend/web/src/app/logging/logger.ts
@@ -19,7 +19,8 @@ type LoggerNamespace =
| 'nodes'
| 'system'
| 'socketio'
- | 'session';
+ | 'session'
+ | 'dnd';
export const logger = (namespace: LoggerNamespace) =>
$logger.get().child({ namespace });
diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
index 6d41d488c8..a596fce931 100644
--- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
@@ -15,7 +15,7 @@ export const actionsDenylist = [
'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
// every time user presses shift
- 'hotkeys/shiftKeyPressed',
+ // 'hotkeys/shiftKeyPressed',
// this happens after every state change
'@@REMEMBER_PERSISTED',
];
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts
index 043105cb66..fc0b44653d 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts
@@ -1,16 +1,20 @@
import { createAction } from '@reduxjs/toolkit';
-import {
- TypesafeDraggableData,
- TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
+import {
+ TypesafeDraggableData,
+ TypesafeDroppableData,
+} from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import {
+ fieldImageValueChanged,
+ workflowExposedFieldAdded,
+} from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
+import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
@@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
startAppListening({
actionCreator: dndDropped,
effect: async (action, { dispatch }) => {
- const log = logger('images');
+ const log = logger('dnd');
const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') {
@@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
+ } else if (activeData.payloadType === 'NODE_FIELD') {
+ log.debug(
+ { activeData: parseify(activeData), overData: parseify(overData) },
+ 'Node field dropped'
+ );
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
+ if (
+ overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
+ activeData.payloadType === 'NODE_FIELD'
+ ) {
+ const { nodeId, field } = activeData.payload;
+ dispatch(
+ workflowExposedFieldAdded({
+ nodeId,
+ fieldName: field.name,
+ })
+ );
+ }
+
/**
* Image dropped on current image
*/
@@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
) {
const { fieldName, nodeId } = overData.context;
dispatch(
- fieldValueChanged({
+ fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
index 6dc2d482a9..0c55908748 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
@@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
@@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
- dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
+ dispatch(
+ fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
+ );
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
index 436a58aa8e..4d30ee3b8b 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
@@ -15,12 +15,21 @@ import {
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
-import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
+import {
+ mainModelsAdapter,
+ modelsApi,
+ vaeModelsAdapter,
+} from 'services/api/endpoints/models';
+import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
- predicate: (state, action) =>
+ predicate: (
+ action
+ ): action is TypeGuardFor<
+ typeof modelsApi.endpoints.getMainModels.matchFulfilled
+ > =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().generation.model;
+ const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
- const isCurrentModelAvailable = some(
- action.payload.entities,
- (m) =>
- m?.model_name === currentModel?.model_name &&
- m?.base_model === currentModel?.base_model &&
- m?.model_type === currentModel?.model_type
- );
-
- if (isCurrentModelAvailable) {
- return;
- }
-
- const firstModelId = action.payload.ids[0];
- const firstModel = action.payload.entities[firstModelId];
-
- if (!firstModel) {
+ if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
- const result = zMainOrOnnxModel.safeParse(firstModel);
+ const isCurrentModelAvailable = currentModel
+ ? models.some(
+ (m) =>
+ m.model_name === currentModel.model_name &&
+ m.base_model === currentModel.base_model &&
+ m.model_type === currentModel.model_type
+ )
+ : false;
+
+ if (isCurrentModelAvailable) {
+ return;
+ }
+
+ const result = zMainOrOnnxModel.safeParse(models[0]);
if (!result.success) {
log.error(
@@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
},
});
startAppListening({
- predicate: (state, action) =>
+ predicate: (
+ action
+ ): action is TypeGuardFor<
+ typeof modelsApi.endpoints.getMainModels.matchFulfilled
+ > =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().sdxl.refinerModel;
+ const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
- const isCurrentModelAvailable = some(
- action.payload.entities,
- (m) =>
- m?.model_name === currentModel?.model_name &&
- m?.base_model === currentModel?.base_model &&
- m?.model_type === currentModel?.model_type
- );
-
- if (isCurrentModelAvailable) {
- return;
- }
-
- const firstModelId = action.payload.ids[0];
- const firstModel = action.payload.entities[firstModelId];
-
- if (!firstModel) {
+ if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
- const result = zSDXLRefinerModel.safeParse(firstModel);
+ const isCurrentModelAvailable = currentModel
+ ? models.some(
+ (m) =>
+ m.model_name === currentModel.model_name &&
+ m.base_model === currentModel.base_model &&
+ m.model_type === currentModel.model_type
+ )
+ : false;
+
+ if (isCurrentModelAvailable) {
+ return;
+ }
+
+ const result = zSDXLRefinerModel.safeParse(models[0]);
if (!result.success) {
log.error(
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts
index 44729f215a..dd86c77735 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts
@@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
const log = logger('system');
const schemaJSON = action.payload;
- log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
+ log.debug({ schemaJSON }, 'Received OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
@@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
- effect: () => {
+ effect: (action) => {
const log = logger('system');
- log.error('Problem dereferencing OpenAPI Schema');
+ log.error(
+ { error: parseify(action.error) },
+ 'Problem retrieving OpenAPI Schema'
+ );
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
index 5b3b9424b6..5501f208fd 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
@@ -19,7 +19,7 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';
-const nodeDenylist = ['dataURL_image'];
+const nodeDenylist = ['load_image'];
export const addInvocationCompleteEventListener = () => {
startAppListening({
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
index 0c298cbb24..5894bba5df 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
@@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
const log = logger('session');
const state = getState();
- const graph = buildNodesGraph(state);
+ const graph = buildNodesGraph(state.nodes);
dispatch(nodesGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index 827424fa7f..a39ed2ca7b 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -1,86 +1,7 @@
-import {
- // CONTROLNET_MODELS,
- CONTROLNET_PROCESSORS,
-} from 'features/controlNet/store/constants';
+import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';
-// These are old types from the model management UI
-
-// export type ModelStatus = 'active' | 'cached' | 'not loaded';
-
-// export type Model = {
-// status: ModelStatus;
-// description: string;
-// weights: string;
-// config?: string;
-// vae?: string;
-// width?: number;
-// height?: number;
-// default?: boolean;
-// format?: string;
-// };
-
-// export type DiffusersModel = {
-// status: ModelStatus;
-// description: string;
-// repo_id?: string;
-// path?: string;
-// vae?: {
-// repo_id?: string;
-// path?: string;
-// };
-// format?: string;
-// default?: boolean;
-// };
-
-// export type ModelList = Record;
-
-// export type FoundModel = {
-// name: string;
-// location: string;
-// };
-
-// export type InvokeModelConfigProps = {
-// name: string | undefined;
-// description: string | undefined;
-// config: string | undefined;
-// weights: string | undefined;
-// vae: string | undefined;
-// width: number | undefined;
-// height: number | undefined;
-// default: boolean | undefined;
-// format: string | undefined;
-// };
-
-// export type InvokeDiffusersModelConfigProps = {
-// name: string | undefined;
-// description: string | undefined;
-// repo_id: string | undefined;
-// path: string | undefined;
-// default: boolean | undefined;
-// format: string | undefined;
-// vae: {
-// repo_id: string | undefined;
-// path: string | undefined;
-// };
-// };
-
-// export type InvokeModelConversionProps = {
-// model_name: string;
-// save_location: string;
-// custom_location: string | null;
-// };
-
-// export type InvokeModelMergingProps = {
-// models_to_merge: string[];
-// alpha: number;
-// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
-// force: boolean;
-// merged_model_name: string;
-// model_merge_save_path: string | null;
-// };
-
/**
* A disable-able application feature
*/
diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
index 780447aba6..defe600b78 100644
--- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
@@ -6,10 +6,6 @@ import {
useColorMode,
useColorModeValue,
} from '@chakra-ui/react';
-import {
- TypesafeDraggableData,
- TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
import IAIIconButton from 'common/components/IAIIconButton';
import {
IAILoadingImageFallback,
@@ -17,6 +13,10 @@ import {
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
+import {
+ TypesafeDraggableData,
+ TypesafeDroppableData,
+} from 'features/dnd/types';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import {
MouseEvent,
@@ -157,11 +157,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
)
}
- width={imageDTO.width}
- height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
+ w: imageDTO.width,
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
@@ -213,13 +212,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
- {!isDropDisabled && (
-
- )}
{onClickReset && withResetIcon && imageDTO && (
{
}}
/>
)}
+ {!isDropDisabled && (
+
+ )}
)}
diff --git a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx
index 482a8ac604..363799a573 100644
--- a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx
@@ -1,22 +1,19 @@
-import { Box } from '@chakra-ui/react';
-import {
- TypesafeDraggableData,
- useDraggable,
-} from 'app/components/ImageDnd/typesafeDnd';
-import { MouseEvent, memo, useRef } from 'react';
+import { Box, BoxProps } from '@chakra-ui/react';
+import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
+import { TypesafeDraggableData } from 'features/dnd/types';
+import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
-type IAIDraggableProps = {
+type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
- onClick?: (event: MouseEvent) => void;
};
const IAIDraggable = (props: IAIDraggableProps) => {
- const { data, disabled, onClick } = props;
+ const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
- const { attributes, listeners, setNodeRef } = useDraggable({
+ const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
@@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
return (
{
insetInlineStart={0}
{...attributes}
{...listeners}
+ {...rest}
/>
);
};
diff --git a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx
index 1038f36840..e4fb121c78 100644
--- a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx
@@ -1,9 +1,7 @@
import { Box } from '@chakra-ui/react';
-import {
- TypesafeDroppableData,
- isValidDrop,
- useDroppable,
-} from 'app/components/ImageDnd/typesafeDnd';
+import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
+import { TypesafeDroppableData } from 'features/dnd/types';
+import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
@@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
- const { isOver, setNodeRef, active } = useDroppable({
+ const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,
diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
index 2057525b7a..a150e4ed0c 100644
--- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
@@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
type IAINoImageFallbackProps = {
label?: string;
- icon?: As;
+ icon?: As | null;
boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
};
@@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
...props.sx,
}}
>
-
+ {icon && }
{props.label && {props.label}}
);
diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx
index 9803626397..da0883d77e 100644
--- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx
+++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx
@@ -1,10 +1,13 @@
import {
+ Flex,
FormControl,
FormControlProps,
+ FormHelperText,
FormLabel,
FormLabelProps,
Switch,
SwitchProps,
+ Text,
Tooltip,
} from '@chakra-ui/react';
import { memo } from 'react';
@@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
formControlProps?: FormControlProps;
formLabelProps?: FormLabelProps;
tooltip?: string;
+ helperText?: string;
}
/**
@@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
formControlProps,
formLabelProps,
tooltip,
+ helperText,
...rest
} = props;
return (
@@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
- {label && (
-
- {label}
-
- )}
-
+
+
+ {label && (
+
+ {label}
+
+ )}
+
+
+ {helperText && (
+
+ {helperText}
+
+ )}
+
);
diff --git a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts
index 770add7253..0afb7e7e5d 100644
--- a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts
+++ b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts
@@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
+ baseAlpha50,
+ baseAlpha100,
+ baseAlpha150,
+ baseAlpha200,
+ baseAlpha250,
+ baseAlpha300,
+ baseAlpha350,
+ baseAlpha400,
+ baseAlpha450,
+ baseAlpha500,
+ baseAlpha550,
+ baseAlpha600,
+ baseAlpha650,
+ baseAlpha700,
+ baseAlpha750,
+ baseAlpha800,
+ baseAlpha850,
+ baseAlpha900,
+ baseAlpha950,
+ accentAlpha50,
+ accentAlpha100,
+ accentAlpha150,
+ accentAlpha200,
+ accentAlpha250,
+ accentAlpha300,
+ accentAlpha350,
+ accentAlpha400,
+ accentAlpha450,
+ accentAlpha500,
+ accentAlpha550,
+ accentAlpha600,
+ accentAlpha650,
+ accentAlpha700,
+ accentAlpha750,
+ accentAlpha800,
+ accentAlpha850,
+ accentAlpha900,
+ accentAlpha950,
] = useToken('colors', [
'base.50',
'base.100',
@@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
'accent.850',
'accent.900',
'accent.950',
+ 'baseAlpha.50',
+ 'baseAlpha.100',
+ 'baseAlpha.150',
+ 'baseAlpha.200',
+ 'baseAlpha.250',
+ 'baseAlpha.300',
+ 'baseAlpha.350',
+ 'baseAlpha.400',
+ 'baseAlpha.450',
+ 'baseAlpha.500',
+ 'baseAlpha.550',
+ 'baseAlpha.600',
+ 'baseAlpha.650',
+ 'baseAlpha.700',
+ 'baseAlpha.750',
+ 'baseAlpha.800',
+ 'baseAlpha.850',
+ 'baseAlpha.900',
+ 'baseAlpha.950',
+ 'accentAlpha.50',
+ 'accentAlpha.100',
+ 'accentAlpha.150',
+ 'accentAlpha.200',
+ 'accentAlpha.250',
+ 'accentAlpha.300',
+ 'accentAlpha.350',
+ 'accentAlpha.400',
+ 'accentAlpha.450',
+ 'accentAlpha.500',
+ 'accentAlpha.550',
+ 'accentAlpha.600',
+ 'accentAlpha.650',
+ 'accentAlpha.700',
+ 'accentAlpha.750',
+ 'accentAlpha.800',
+ 'accentAlpha.850',
+ 'accentAlpha.900',
+ 'accentAlpha.950',
]);
return {
@@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
+ baseAlpha50,
+ baseAlpha100,
+ baseAlpha150,
+ baseAlpha200,
+ baseAlpha250,
+ baseAlpha300,
+ baseAlpha350,
+ baseAlpha400,
+ baseAlpha450,
+ baseAlpha500,
+ baseAlpha550,
+ baseAlpha600,
+ baseAlpha650,
+ baseAlpha700,
+ baseAlpha750,
+ baseAlpha800,
+ baseAlpha850,
+ baseAlpha900,
+ baseAlpha950,
+ accentAlpha50,
+ accentAlpha100,
+ accentAlpha150,
+ accentAlpha200,
+ accentAlpha250,
+ accentAlpha300,
+ accentAlpha350,
+ accentAlpha400,
+ accentAlpha450,
+ accentAlpha500,
+ accentAlpha550,
+ accentAlpha600,
+ accentAlpha650,
+ accentAlpha700,
+ accentAlpha750,
+ accentAlpha800,
+ accentAlpha850,
+ accentAlpha900,
+ accentAlpha950,
};
};
diff --git a/invokeai/frontend/web/src/common/util/serialize.ts b/invokeai/frontend/web/src/common/util/serialize.ts
index a9352a8228..a5db921f8d 100644
--- a/invokeai/frontend/web/src/common/util/serialize.ts
+++ b/invokeai/frontend/web/src/common/util/serialize.ts
@@ -1,4 +1,10 @@
/**
* Serialize an object to JSON and back to a new object
*/
-export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
+export const parseify = (obj: unknown) => {
+ try {
+ return JSON.parse(JSON.stringify(obj));
+ } catch {
+ return 'Error parsing object';
+ }
+};
diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
index cdab176cd2..4fffb82275 100644
--- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
+} from 'features/dnd/types';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
diff --git a/invokeai/frontend/web/src/features/controlNet/store/types.ts b/invokeai/frontend/web/src/features/controlNet/store/types.ts
index 2d028fd0bb..80edb41699 100644
--- a/invokeai/frontend/web/src/features/controlNet/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlNet/store/types.ts
@@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
-export type RequiredControlNetProcessorNode =
+export type RequiredControlNetProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
@@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
- | RequiredZoeDepthImageProcessorInvocation;
+ | RequiredZoeDepthImageProcessorInvocation,
+ 'id'
+>;
/**
* Type guard for CannyImageProcessorInvocation
diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts
index 310521f32a..37be06bad6 100644
--- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts
+++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts
@@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
+import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
- const isNodesImage = nodes.nodes.some((node) => {
+ const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
return some(
node.data.inputs,
(input) =>
- input.type === 'image' && input.value?.image_name === image_name
+ input.type === 'ImageField' && input.value?.image_name === image_name
);
});
diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx
similarity index 70%
rename from invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx
rename to invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx
index 56eeb9b5db..bffe738aa9 100644
--- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx
+++ b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx
@@ -6,23 +6,18 @@ import {
useSensor,
useSensors,
} from '@dnd-kit/core';
-import { snapCenterToCursor } from '@dnd-kit/modifiers';
+import { logger } from 'app/logging/logger';
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { useAppDispatch } from 'app/store/storeHooks';
+import { parseify } from 'common/util/serialize';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
+import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
+import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
+import { DndContextTypesafe } from './DndContextTypesafe';
import DragPreview from './DragPreview';
-import {
- DndContext,
- DragEndEvent,
- DragStartEvent,
- TypesafeDraggableData,
-} from './typesafeDnd';
-import { logger } from 'app/logging/logger';
-type ImageDndContextProps = PropsWithChildren;
-
-const ImageDndContext = (props: ImageDndContextProps) => {
+const AppDndContext = (props: PropsWithChildren) => {
const [activeDragData, setActiveDragData] =
useState(null);
const log = logger('images');
@@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragStart = useCallback(
(event: DragStartEvent) => {
- log.trace({ dragData: event.active.data.current }, 'Drag started');
+ log.trace(
+ { dragData: parseify(event.active.data.current) },
+ 'Drag started'
+ );
const activeData = event.active.data.current;
if (!activeData) {
return;
@@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
- log.trace({ dragData: event.active.data.current }, 'Drag ended');
+ log.trace(
+ { dragData: parseify(event.active.data.current) },
+ 'Drag ended'
+ );
const overData = event.over?.data.current;
if (!activeDragData || !overData) {
return;
@@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const sensors = useSensors(mouseSensor, touchSensor);
+ const scaledModifier = useScaledModifer();
+
return (
-
{props.children}
-
+
{activeDragData && (
{
)}
-
+
);
};
-export default memo(ImageDndContext);
+export default memo(AppDndContext);
diff --git a/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx
new file mode 100644
index 0000000000..06fede4dc8
--- /dev/null
+++ b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx
@@ -0,0 +1,6 @@
+import { DndContext } from '@dnd-kit/core';
+import { DndContextTypesafeProps } from '../types';
+
+export function DndContextTypesafe(props: DndContextTypesafeProps) {
+ return ;
+}
diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx
similarity index 69%
rename from invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx
rename to invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx
index c97778ffcd..0ee5d34b1a 100644
--- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx
+++ b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx
@@ -1,6 +1,6 @@
-import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
+import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
import { memo } from 'react';
-import { TypesafeDraggableData } from './typesafeDnd';
+import { TypesafeDraggableData } from '../types';
type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null;
@@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
return null;
}
+ if (props.dragData.payloadType === 'NODE_FIELD') {
+ const { field, fieldTemplate } = props.dragData.payload;
+ return (
+
+ {field.label || fieldTemplate.title}
+
+ );
+ }
+
if (props.dragData.payloadType === 'IMAGE_DTO') {
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
return (
{
return (
(activeTabName === 'nodes' ? nodes.zoom : 1)
+);
+
+/**
+ * Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
+ */
+export const useScaledModifer = () => {
+ const zoom = useAppSelector(selectZoom);
+ const modifier: Modifier = useCallback(
+ ({ activatorEvent, draggingNodeRect, transform }) => {
+ if (draggingNodeRect && activatorEvent) {
+ const activatorCoordinates = getEventCoordinates(activatorEvent);
+
+ if (!activatorCoordinates) {
+ return transform;
+ }
+
+ const offsetX = activatorCoordinates.x - draggingNodeRect.left;
+ const offsetY = activatorCoordinates.y - draggingNodeRect.top;
+
+ const x = transform.x + offsetX - draggingNodeRect.width / 2;
+ const y = transform.y + offsetY - draggingNodeRect.height / 2;
+ const scaleX = transform.scaleX * zoom;
+ const scaleY = transform.scaleY * zoom;
+
+ return {
+ x,
+ y,
+ scaleX,
+ scaleY,
+ };
+ }
+
+ return transform;
+ },
+ [zoom]
+ );
+
+ return modifier;
+};
diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/features/dnd/types/index.ts
similarity index 51%
rename from invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx
rename to invokeai/frontend/web/src/features/dnd/types/index.ts
index 6f24302070..294132d0a3 100644
--- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx
+++ b/invokeai/frontend/web/src/features/dnd/types/index.ts
@@ -3,7 +3,6 @@ import {
Active,
Collision,
DndContextProps,
- DndContext as OriginalDndContext,
Over,
Translate,
UseDraggableArguments,
@@ -11,6 +10,10 @@ import {
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+} from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
+export type AddFieldToLinearViewDropData = BaseDropData & {
+ actionType: 'ADD_FIELD_TO_LINEAR';
+};
+
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
@@ -71,12 +78,22 @@ export type TypesafeDroppableData =
| AddToBatchDropData
| NodesMultiImageDropData
| AddToBoardDropData
- | RemoveFromBoardDropData;
+ | RemoveFromBoardDropData
+ | AddFieldToLinearViewDropData;
type BaseDragData = {
id: string;
};
+export type NodeFieldDraggableData = BaseDragData & {
+ payloadType: 'NODE_FIELD';
+ payload: {
+ nodeId: string;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+ };
+};
+
export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO };
@@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
payload: { imageDTOs: ImageDTO[] };
};
-export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
+export type TypesafeDraggableData =
+ | NodeFieldDraggableData
+ | ImageDraggableData
+ | ImageDTOsDraggableData;
-interface UseDroppableTypesafeArguments
+export interface UseDroppableTypesafeArguments
extends Omit {
data?: TypesafeDroppableData;
}
-type UseDroppableTypesafeReturnValue = Omit<
+export type UseDroppableTypesafeReturnValue = Omit<
ReturnType,
'active' | 'over'
> & {
@@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
-export function useDroppable(props: UseDroppableTypesafeArguments) {
- return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
-}
-
-interface UseDraggableTypesafeArguments
+export interface UseDraggableTypesafeArguments
extends Omit {
data?: TypesafeDraggableData;
}
-type UseDraggableTypesafeReturnValue = Omit<
+export type UseDraggableTypesafeReturnValue = Omit<
ReturnType,
'active' | 'over'
> & {
@@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
-export function useDraggable(props: UseDraggableTypesafeArguments) {
- return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
-}
-
-interface TypesafeActive extends Omit {
+export interface TypesafeActive extends Omit {
data: React.MutableRefObject;
}
-interface TypesafeOver extends Omit {
+export interface TypesafeOver extends Omit {
data: React.MutableRefObject;
}
-export const isValidDrop = (
- overData: TypesafeDroppableData | undefined,
- active: TypesafeActive | null
-) => {
- if (!overData || !active?.data.current) {
- return false;
- }
-
- const { actionType } = overData;
- const { payloadType } = active.data.current;
-
- if (overData.id === active.data.current.id) {
- return false;
- }
-
- switch (actionType) {
- case 'SET_CURRENT_IMAGE':
- return payloadType === 'IMAGE_DTO';
- case 'SET_INITIAL_IMAGE':
- return payloadType === 'IMAGE_DTO';
- case 'SET_CONTROLNET_IMAGE':
- return payloadType === 'IMAGE_DTO';
- case 'SET_CANVAS_INITIAL_IMAGE':
- return payloadType === 'IMAGE_DTO';
- case 'SET_NODES_IMAGE':
- return payloadType === 'IMAGE_DTO';
- case 'SET_MULTI_NODES_IMAGE':
- return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
- case 'ADD_TO_BATCH':
- return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
- case 'ADD_TO_BOARD': {
- // If the board is the same, don't allow the drop
-
- // Check the payload types
- const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
- if (!isPayloadValid) {
- return false;
- }
-
- // Check if the image's board is the board we are dragging onto
- if (payloadType === 'IMAGE_DTO') {
- const { imageDTO } = active.data.current.payload;
- const currentBoard = imageDTO.board_id ?? 'none';
- const destinationBoard = overData.context.boardId;
-
- return currentBoard !== destinationBoard;
- }
-
- if (payloadType === 'IMAGE_DTOS') {
- // TODO (multi-select)
- return true;
- }
-
- return false;
- }
- case 'REMOVE_FROM_BOARD': {
- // If the board is the same, don't allow the drop
-
- // Check the payload types
- const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
- if (!isPayloadValid) {
- return false;
- }
-
- // Check if the image's board is the board we are dragging onto
- if (payloadType === 'IMAGE_DTO') {
- const { imageDTO } = active.data.current.payload;
- const currentBoard = imageDTO.board_id;
-
- return currentBoard !== 'none';
- }
-
- if (payloadType === 'IMAGE_DTOS') {
- // TODO (multi-select)
- return true;
- }
-
- return false;
- }
- default:
- return false;
- }
-};
-
interface DragEvent {
activatorEvent: Event;
active: TypesafeActive;
@@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void;
}
-export function DndContext(props: DndContextTypesafeProps) {
- return ;
-}
diff --git a/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
new file mode 100644
index 0000000000..f704d22dff
--- /dev/null
+++ b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
@@ -0,0 +1,87 @@
+import { TypesafeActive, TypesafeDroppableData } from '../types';
+
+export const isValidDrop = (
+ overData: TypesafeDroppableData | undefined,
+ active: TypesafeActive | null
+) => {
+ if (!overData || !active?.data.current) {
+ return false;
+ }
+
+ const { actionType } = overData;
+ const { payloadType } = active.data.current;
+
+ if (overData.id === active.data.current.id) {
+ return false;
+ }
+
+ switch (actionType) {
+ case 'ADD_FIELD_TO_LINEAR':
+ return payloadType === 'NODE_FIELD';
+ case 'SET_CURRENT_IMAGE':
+ return payloadType === 'IMAGE_DTO';
+ case 'SET_INITIAL_IMAGE':
+ return payloadType === 'IMAGE_DTO';
+ case 'SET_CONTROLNET_IMAGE':
+ return payloadType === 'IMAGE_DTO';
+ case 'SET_CANVAS_INITIAL_IMAGE':
+ return payloadType === 'IMAGE_DTO';
+ case 'SET_NODES_IMAGE':
+ return payloadType === 'IMAGE_DTO';
+ case 'SET_MULTI_NODES_IMAGE':
+ return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
+ case 'ADD_TO_BATCH':
+ return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
+ case 'ADD_TO_BOARD': {
+ // If the board is the same, don't allow the drop
+
+ // Check the payload types
+ const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
+ if (!isPayloadValid) {
+ return false;
+ }
+
+ // Check if the image's board is the board we are dragging onto
+ if (payloadType === 'IMAGE_DTO') {
+ const { imageDTO } = active.data.current.payload;
+ const currentBoard = imageDTO.board_id ?? 'none';
+ const destinationBoard = overData.context.boardId;
+
+ return currentBoard !== destinationBoard;
+ }
+
+ if (payloadType === 'IMAGE_DTOS') {
+ // TODO (multi-select)
+ return true;
+ }
+
+ return false;
+ }
+ case 'REMOVE_FROM_BOARD': {
+ // If the board is the same, don't allow the drop
+
+ // Check the payload types
+ const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
+ if (!isPayloadValid) {
+ return false;
+ }
+
+ // Check if the image's board is the board we are dragging onto
+ if (payloadType === 'IMAGE_DTO') {
+ const { imageDTO } = active.data.current.payload;
+ const currentBoard = imageDTO.board_id;
+
+ return currentBoard !== 'none';
+ }
+
+ if (payloadType === 'IMAGE_DTOS') {
+ // TODO (multi-select)
+ return true;
+ }
+
+ return false;
+ }
+ default:
+ return false;
+ }
+};
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
index 228ce7080c..696a8b748b 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx
@@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
-import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import AutoAddIcon from '../AutoAddIcon';
import BoardContextMenu from '../BoardContextMenu';
+import { AddToBoardDropData } from 'features/dnd/types';
interface GalleryBoardProps {
board: BoardDTO;
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx
index 0d630c524d..1698a81ac0 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx
@@ -1,7 +1,7 @@
import { As, Badge, Flex } from '@chakra-ui/react';
-import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types';
import { ReactNode } from 'react';
import BoardContextMenu from '../BoardContextMenu';
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx
index f1341b1146..fec280db0f 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx
@@ -1,15 +1,15 @@
import { Box, Flex, Image, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
-import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import InvokeAILogoImage from 'assets/images/logo.png';
import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
+import { RemoveFromBoardDropData } from 'features/dnd/types';
import {
- boardIdSelected,
autoAddBoardIdChanged,
+ boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx
index f78ee286ef..2576c8e9e3 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx
@@ -1,14 +1,14 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
-import {
- TypesafeDraggableData,
- TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import {
+ TypesafeDraggableData,
+ TypesafeDroppableData,
+} from 'features/dnd/types';
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { AnimatePresence, motion } from 'framer-motion';
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
index f2ff2ad30b..804df49b8e 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
@@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
return (
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
index c9eee5f1f5..97f8199aed 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
@@ -1,9 +1,4 @@
import { Box, Flex } from '@chakra-ui/react';
-import {
- ImageDTOsDraggableData,
- ImageDraggableData,
- TypesafeDraggableData,
-} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
@@ -12,6 +7,11 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
+import {
+ ImageDTOsDraggableData,
+ ImageDraggableData,
+ TypesafeDraggableData,
+} from 'features/dnd/types';
interface HoverableImageProps {
imageName: string;
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx
index 4a56fe0e9a..bacd5c38ad 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx
@@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
options: {
scrollbars: {
visibility: 'auto',
- autoHide: 'leave',
+ autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx
index 590d40438b..69385607de 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx
@@ -1,26 +1,40 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
-import { useMemo } from 'react';
-import { FaCopy } from 'react-icons/fa';
+import { useCallback, useMemo } from 'react';
+import { FaCopy, FaSave } from 'react-icons/fa';
type Props = {
- copyTooltip: string;
+ label: string;
jsonObject: object;
+ fileName?: string;
};
const ImageMetadataJSON = (props: Props) => {
- const { copyTooltip, jsonObject } = props;
+ const { label, jsonObject, fileName } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
+ const handleCopy = useCallback(() => {
+ navigator.clipboard.writeText(jsonString);
+ }, [jsonString]);
+
+ const handleSave = useCallback(() => {
+ const blob = new Blob([jsonString]);
+ const a = document.createElement('a');
+ a.href = URL.createObjectURL(blob);
+ a.download = `${fileName || label}.json`;
+ document.body.appendChild(a);
+ a.click();
+ a.remove();
+ }, [jsonString, label, fileName]);
+
return (
{
bottom: 0,
overflow: 'auto',
p: 4,
+ fontSize: 'sm',
}}
>
{
options={{
scrollbars: {
visibility: 'auto',
- autoHide: 'move',
+ autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
@@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => {
-
+
}
+ variant="ghost"
+ opacity={0.7}
+ onClick={handleSave}
+ />
+
+
+ }
variant="ghost"
- onClick={() => navigator.clipboard.writeText(jsonString)}
+ opacity={0.7}
+ onClick={handleCopy}
/>
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
index e1f2a9e46a..d70aea8a8d 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx
@@ -10,7 +10,8 @@ import {
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
-import { memo, useMemo } from 'react';
+import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
@@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const metadata = currentData?.metadata;
const graph = currentData?.graph;
- const tabData = useMemo(() => {
- const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
-
- if (metadata) {
- _tabData.push({
- label: 'Core Metadata',
- data: metadata,
- copyTooltip: 'Copy Core Metadata JSON',
- });
- }
-
- if (image) {
- _tabData.push({
- label: 'Image Details',
- data: image,
- copyTooltip: 'Copy Image Details JSON',
- });
- }
-
- if (graph) {
- _tabData.push({
- label: 'Graph',
- data: graph,
- copyTooltip: 'Copy Graph JSON',
- });
- }
- return _tabData;
- }, [metadata, graph, image]);
-
return (
{
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
- {tabData.map((tab) => (
-
-
- {tab.label}
-
-
- ))}
+ Core Metadata
+ Image Details
+ Graph
-
- {tabData.map((tab) => (
-
-
-
- ))}
+
+
+ {metadata ? (
+
+ ) : (
+
+ )}
+
+
+ {image ? (
+
+ ) : (
+
+ )}
+
+
+ {graph ? (
+
+ ) : (
+
+ )}
+
diff --git a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx
index a1a1acf1f8..a816762d0f 100644
--- a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx
@@ -9,30 +9,40 @@ import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
-import { useBuildInvocation } from '../hooks/useBuildInvocation';
+import { useBuildNodeData } from '../hooks/useBuildNodeData';
import { nodeAdded } from '../store/nodesSlice';
type NodeTemplate = {
label: string;
value: string;
description: string;
+ tags: string[];
};
const selector = createSelector(
[stateSelector],
({ nodes }) => {
- const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
+ const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
+ tags: template.tags,
};
});
data.push({
label: 'Progress Image',
- value: 'progress_image',
- description: 'Displays the progress image in the Node Editor',
+ value: 'current_image',
+ description: 'Displays the current image in the Node Editor',
+ tags: ['progress'],
+ });
+
+ data.push({
+ label: 'Notes',
+ value: 'notes',
+ description: 'Add notes about your workflow',
+ tags: ['notes'],
});
return { data };
@@ -44,7 +54,7 @@ const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
- const buildInvocation = useBuildInvocation();
+ const buildInvocation = useBuildNodeData();
const toaster = useAppToaster();
@@ -89,11 +99,12 @@ const AddNodeMenu = () => {
filter={(value, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
- item.description.toLowerCase().includes(value.toLowerCase().trim())
+ item.description.toLowerCase().includes(value.toLowerCase().trim()) ||
+ item.tags.includes(value.toLowerCase().trim())
}
onChange={handleChange}
sx={{
- width: '18rem',
+ width: '24rem',
}}
/>
diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx
new file mode 100644
index 0000000000..678d8e3d1d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx
@@ -0,0 +1,61 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
+import { FIELDS, colorTokenToCssVar } from '../types/constants';
+
+const selector = createSelector(stateSelector, ({ nodes }) => {
+ const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
+ nodes;
+
+ const stroke =
+ currentConnectionFieldType && shouldColorEdges
+ ? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
+ : colorTokenToCssVar('base.500');
+
+ let className = 'react-flow__custom_connection-path';
+
+ if (shouldAnimateEdges) {
+ className = className.concat(' animated');
+ }
+
+ return {
+ stroke,
+ className,
+ };
+});
+
+export const CustomConnectionLine = ({
+ fromX,
+ fromY,
+ fromPosition,
+ toX,
+ toY,
+ toPosition,
+}: ConnectionLineComponentProps) => {
+ const { stroke, className } = useAppSelector(selector);
+
+ const pathParams = {
+ sourceX: fromX,
+ sourceY: fromY,
+ sourcePosition: fromPosition,
+ targetX: toX,
+ targetY: toY,
+ targetPosition: toPosition,
+ };
+
+ const [dAttr] = getBezierPath(pathParams);
+
+ return (
+
+
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx
new file mode 100644
index 0000000000..e0ccc6e323
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx
@@ -0,0 +1,183 @@
+import { Badge, Flex } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
+import { useMemo } from 'react';
+import {
+ BaseEdge,
+ EdgeLabelRenderer,
+ EdgeProps,
+ getBezierPath,
+} from 'reactflow';
+import { FIELDS, colorTokenToCssVar } from '../types/constants';
+import { isInvocationNode } from '../types/types';
+
+const makeEdgeSelector = (
+ source: string,
+ sourceHandleId: string | null | undefined,
+ target: string,
+ targetHandleId: string | null | undefined,
+ selected?: boolean
+) =>
+ createSelector(stateSelector, ({ nodes }) => {
+ const sourceNode = nodes.nodes.find((node) => node.id === source);
+ const targetNode = nodes.nodes.find((node) => node.id === target);
+
+ const isInvocationToInvocationEdge =
+ isInvocationNode(sourceNode) && isInvocationNode(targetNode);
+
+ const isSelected = sourceNode?.selected || targetNode?.selected || selected;
+ const sourceType = isInvocationToInvocationEdge
+ ? sourceNode?.data?.outputs[sourceHandleId || '']?.type
+ : undefined;
+
+ const stroke =
+ sourceType && nodes.shouldColorEdges
+ ? colorTokenToCssVar(FIELDS[sourceType].color)
+ : colorTokenToCssVar('base.500');
+
+ return {
+ isSelected,
+ shouldAnimate: nodes.shouldAnimateEdges && isSelected,
+ stroke,
+ };
+ });
+
+const CollapsedEdge = ({
+ sourceX,
+ sourceY,
+ targetX,
+ targetY,
+ sourcePosition,
+ targetPosition,
+ markerEnd,
+ data,
+ selected,
+ source,
+ target,
+ sourceHandleId,
+ targetHandleId,
+}: EdgeProps<{ count: number }>) => {
+ const selector = useMemo(
+ () =>
+ makeEdgeSelector(
+ source,
+ sourceHandleId,
+ target,
+ targetHandleId,
+ selected
+ ),
+ [selected, source, sourceHandleId, target, targetHandleId]
+ );
+
+ const { isSelected, shouldAnimate } = useAppSelector(selector);
+
+ const [edgePath, labelX, labelY] = getBezierPath({
+ sourceX,
+ sourceY,
+ sourcePosition,
+ targetX,
+ targetY,
+ targetPosition,
+ });
+
+ const { base500 } = useChakraThemeTokens();
+
+ return (
+ <>
+
+ {data?.count && data.count > 1 && (
+
+
+
+ {data.count}
+
+
+
+ )}
+ >
+ );
+};
+
+const DefaultEdge = ({
+ sourceX,
+ sourceY,
+ targetX,
+ targetY,
+ sourcePosition,
+ targetPosition,
+ markerEnd,
+ selected,
+ source,
+ target,
+ sourceHandleId,
+ targetHandleId,
+}: EdgeProps) => {
+ const selector = useMemo(
+ () =>
+ makeEdgeSelector(
+ source,
+ sourceHandleId,
+ target,
+ targetHandleId,
+ selected
+ ),
+ [source, sourceHandleId, target, targetHandleId, selected]
+ );
+
+ const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
+
+ const [edgePath] = getBezierPath({
+ sourceX,
+ sourceY,
+ sourcePosition,
+ targetX,
+ targetY,
+ targetPosition,
+ });
+
+ return (
+
+ );
+};
+
+export const edgeTypes = {
+ collapsed: CollapsedEdge,
+ default: DefaultEdge,
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx
new file mode 100644
index 0000000000..3aacb3cd58
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx
@@ -0,0 +1,9 @@
+import InvocationNode from './nodes/InvocationNode';
+import CurrentImageNode from './nodes/CurrentImageNode';
+import NotesNode from './nodes/NotesNode';
+
+export const nodeTypes = {
+ invocation: InvocationNode,
+ current_image: CurrentImageNode,
+ notes: NotesNode,
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx
deleted file mode 100644
index 86099a7315..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx
+++ /dev/null
@@ -1,64 +0,0 @@
-import { Tooltip } from '@chakra-ui/react';
-import { CSSProperties, memo } from 'react';
-import { Handle, Position, Connection, HandleType } from 'reactflow';
-import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
-// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
-import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
-
-const handleBaseStyles: CSSProperties = {
- position: 'absolute',
- width: '1rem',
- height: '1rem',
- borderWidth: 0,
-};
-
-const inputHandleStyles: CSSProperties = {
- left: '-1rem',
-};
-
-const outputHandleStyles: CSSProperties = {
- right: '-0.5rem',
-};
-
-// const requiredConnectionStyles: CSSProperties = {
-// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
-// };
-
-type FieldHandleProps = {
- nodeId: string;
- field: InputFieldTemplate | OutputFieldTemplate;
- isValidConnection: (connection: Connection) => boolean;
- handleType: HandleType;
- styles?: CSSProperties;
-};
-
-const FieldHandle = (props: FieldHandleProps) => {
- const { field, isValidConnection, handleType, styles } = props;
- const { name, type } = field;
-
- return (
-
-
-
- );
-};
-
-export default memo(FieldHandle);
diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx
index 78316cc694..a523cc29fe 100644
--- a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx
@@ -1,8 +1,8 @@
-import 'reactflow/dist/style.css';
-import { Tooltip, Badge, Flex } from '@chakra-ui/react';
+import { Badge, Flex, Tooltip } from '@chakra-ui/react';
import { map } from 'lodash-es';
-import { FIELDS } from '../types/constants';
import { memo } from 'react';
+import 'reactflow/dist/style.css';
+import { FIELDS } from '../types/constants';
const FieldTypeLegend = () => {
return (
@@ -10,8 +10,14 @@ const FieldTypeLegend = () => {
{map(FIELDS, ({ title, description, color }, key) => (
{title}
diff --git a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx
index 7b0718182b..71062e9774 100644
--- a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx
@@ -1,4 +1,3 @@
-import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import {
@@ -7,35 +6,49 @@ import {
OnConnectEnd,
OnConnectStart,
OnEdgesChange,
+ OnEdgesDelete,
OnInit,
+ OnMove,
OnNodesChange,
+ OnNodesDelete,
+ OnSelectionChangeFunc,
+ ProOptions,
ReactFlow,
} from 'reactflow';
+import { useIsValidConnection } from '../hooks/useIsValidConnection';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
+ edgesDeleted,
nodesChanged,
- setEditorInstance,
+ nodesDeleted,
+ selectedEdgesChanged,
+ selectedNodesChanged,
+ zoomChanged,
} from '../store/nodesSlice';
-import { InvocationComponent } from './InvocationComponent';
-import ProgressImageNode from './ProgressImageNode';
-import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
-import MinimapPanel from './panels/MinimapPanel';
-import TopCenterPanel from './panels/TopCenterPanel';
-import TopLeftPanel from './panels/TopLeftPanel';
-import TopRightPanel from './panels/TopRightPanel';
+import { CustomConnectionLine } from './CustomConnectionLine';
+import { edgeTypes } from './CustomEdges';
+import { nodeTypes } from './CustomNodes';
+import BottomLeftPanel from './editorPanels/BottomLeftPanel';
+import MinimapPanel from './editorPanels/MinimapPanel';
+import TopCenterPanel from './editorPanels/TopCenterPanel';
+import TopLeftPanel from './editorPanels/TopLeftPanel';
+import TopRightPanel from './editorPanels/TopRightPanel';
-const nodeTypes = {
- invocation: InvocationComponent,
- progress_image: ProgressImageNode,
-};
+// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
+const proOptions: ProOptions = { hideAttribution: true };
export const Flow = () => {
const dispatch = useAppDispatch();
- const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
- const edges = useAppSelector((state: RootState) => state.nodes.edges);
+ const nodes = useAppSelector((state) => state.nodes.nodes);
+ const edges = useAppSelector((state) => state.nodes.edges);
+ const shouldSnapToGrid = useAppSelector(
+ (state) => state.nodes.shouldSnapToGrid
+ );
+
+ const isValidConnection = useIsValidConnection();
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
@@ -69,10 +82,36 @@ export const Flow = () => {
dispatch(connectionEnded());
}, [dispatch]);
- const onInit: OnInit = useCallback(
- (v) => {
- dispatch(setEditorInstance(v));
- if (v) v.fitView();
+ const onInit: OnInit = useCallback((v) => {
+ v.fitView();
+ }, []);
+
+ const onEdgesDelete: OnEdgesDelete = useCallback(
+ (edges) => {
+ dispatch(edgesDeleted(edges));
+ },
+ [dispatch]
+ );
+
+ const onNodesDelete: OnNodesDelete = useCallback(
+ (nodes) => {
+ dispatch(nodesDeleted(nodes));
+ },
+ [dispatch]
+ );
+
+ const handleSelectionChange: OnSelectionChangeFunc = useCallback(
+ ({ nodes, edges }) => {
+ dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
+ dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
+ },
+ [dispatch]
+ );
+
+ const handleMove: OnMove = useCallback(
+ (e, viewport) => {
+ const { zoom } = viewport;
+ dispatch(zoomChanged(zoom));
},
[dispatch]
);
@@ -80,24 +119,33 @@ export const Flow = () => {
return (
-
+
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx
deleted file mode 100644
index 7b56bc95b4..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx
+++ /dev/null
@@ -1,55 +0,0 @@
-import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
-import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
-import { memo } from 'react';
-import { FaInfoCircle } from 'react-icons/fa';
-
-interface IAINodeHeaderProps {
- nodeId?: string;
- title?: string;
- description?: string;
-}
-
-const IAINodeHeader = (props: IAINodeHeaderProps) => {
- const { nodeId, title, description } = props;
- return (
-
-
-
- {title}
-
-
-
-
-
-
- );
-};
-
-export default memo(IAINodeHeader);
diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx
deleted file mode 100644
index 6f779e4295..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx
+++ /dev/null
@@ -1,149 +0,0 @@
-import {
- Box,
- Divider,
- Flex,
- FormControl,
- FormLabel,
- HStack,
- Tooltip,
-} from '@chakra-ui/react';
-import { RootState } from 'app/store/store';
-import { useAppSelector } from 'app/store/storeHooks';
-import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
-import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
-import {
- InputFieldTemplate,
- InputFieldValue,
- InvocationTemplate,
-} from 'features/nodes/types/types';
-import { map } from 'lodash-es';
-import { ReactNode, memo, useCallback } from 'react';
-import FieldHandle from '../FieldHandle';
-import InputFieldComponent from '../InputFieldComponent';
-
-interface IAINodeInputProps {
- nodeId: string;
-
- input: InputFieldValue;
- template?: InputFieldTemplate | undefined;
- connected: boolean;
-}
-
-function IAINodeInput(props: IAINodeInputProps) {
- const { nodeId, input, template, connected } = props;
- const isValidConnection = useIsValidConnection();
-
- return (
-
-
- {!template ? (
-
- Unknown input: {input.name}
-
- ) : (
- <>
-
-
-
- {template?.title}
-
-
-
-
-
- {!['never', 'directOnly'].includes(
- template?.inputRequirement ?? ''
- ) && (
-
- )}
- >
- )}
-
-
- );
-}
-
-interface IAINodeInputsProps {
- nodeId: string;
- template: InvocationTemplate;
- inputs: Record;
-}
-
-const IAINodeInputs = (props: IAINodeInputsProps) => {
- const { nodeId, template, inputs } = props;
-
- const edges = useAppSelector((state: RootState) => state.nodes.edges);
-
- const renderIAINodeInputs = useCallback(() => {
- const IAINodeInputsToRender: ReactNode[] = [];
- const inputSockets = map(inputs);
-
- inputSockets.forEach((inputSocket, index) => {
- const inputTemplate = template.inputs[inputSocket.name];
-
- const isConnected = Boolean(
- edges.filter((connectedInput) => {
- return (
- connectedInput.target === nodeId &&
- connectedInput.targetHandle === inputSocket.name
- );
- }).length
- );
-
- if (index < inputSockets.length) {
- IAINodeInputsToRender.push(
-
- );
- }
-
- IAINodeInputsToRender.push(
-
- );
- });
-
- return (
-
- {IAINodeInputsToRender}
-
- );
- }, [edges, inputs, nodeId, template.inputs]);
-
- return renderIAINodeInputs();
-};
-
-export default memo(IAINodeInputs);
diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx
deleted file mode 100644
index 2cb0bcde8d..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx
+++ /dev/null
@@ -1,97 +0,0 @@
-import {
- InvocationTemplate,
- OutputFieldTemplate,
- OutputFieldValue,
-} from 'features/nodes/types/types';
-import { memo, ReactNode, useCallback } from 'react';
-import { map } from 'lodash-es';
-import { useAppSelector } from 'app/store/storeHooks';
-import { RootState } from 'app/store/store';
-import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
-import FieldHandle from '../FieldHandle';
-import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
-
-interface IAINodeOutputProps {
- nodeId: string;
- output: OutputFieldValue;
- template?: OutputFieldTemplate | undefined;
- connected: boolean;
-}
-
-function IAINodeOutput(props: IAINodeOutputProps) {
- const { nodeId, output, template, connected } = props;
- const isValidConnection = useIsValidConnection();
-
- return (
-
-
- {!template ? (
-
-
- Unknown Output: {output.name}
-
-
- ) : (
- <>
-
- {template?.title}
-
-
- >
- )}
-
-
- );
-}
-
-interface IAINodeOutputsProps {
- nodeId: string;
- template: InvocationTemplate;
- outputs: Record;
-}
-
-const IAINodeOutputs = (props: IAINodeOutputsProps) => {
- const { nodeId, template, outputs } = props;
-
- const edges = useAppSelector((state: RootState) => state.nodes.edges);
-
- const renderIAINodeOutputs = useCallback(() => {
- const IAINodeOutputsToRender: ReactNode[] = [];
- const outputSockets = map(outputs);
-
- outputSockets.forEach((outputSocket) => {
- const outputTemplate = template.outputs[outputSocket.name];
-
- const isConnected = Boolean(
- edges.filter((connectedInput) => {
- return (
- connectedInput.source === nodeId &&
- connectedInput.sourceHandle === outputSocket.name
- );
- }).length
- );
-
- IAINodeOutputsToRender.push(
-
- );
- });
-
- return {IAINodeOutputsToRender};
- }, [edges, nodeId, outputs, template.outputs]);
-
- return renderIAINodeOutputs();
-};
-
-export default memo(IAINodeOutputs);
diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
deleted file mode 100644
index 0ecc43ef9c..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
+++ /dev/null
@@ -1,252 +0,0 @@
-import { Box } from '@chakra-ui/react';
-import { memo } from 'react';
-import { InputFieldTemplate, InputFieldValue } from '../types/types';
-import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
-import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
-import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
-import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
-import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
-import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
-import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
-import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
-import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
-import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
-import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
-import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
-import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
-import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
-import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
-import StringInputFieldComponent from './fields/StringInputFieldComponent';
-import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
-import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
-import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
-import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
-
-type InputFieldComponentProps = {
- nodeId: string;
- field: InputFieldValue;
- template: InputFieldTemplate;
-};
-
-// build an individual input element based on the schema
-const InputFieldComponent = (props: InputFieldComponentProps) => {
- const { nodeId, field, template } = props;
- const { type } = field;
-
- if (type === 'string' && template.type === 'string') {
- return (
-
- );
- }
-
- if (type === 'boolean' && template.type === 'boolean') {
- return (
-
- );
- }
-
- if (
- (type === 'integer' && template.type === 'integer') ||
- (type === 'float' && template.type === 'float')
- ) {
- return (
-
- );
- }
-
- if (type === 'enum' && template.type === 'enum') {
- return (
-
- );
- }
-
- if (type === 'image' && template.type === 'image') {
- return (
-
- );
- }
-
- if (type === 'latents' && template.type === 'latents') {
- return (
-
- );
- }
-
- if (type === 'conditioning' && template.type === 'conditioning') {
- return (
-
- );
- }
-
- if (type === 'unet' && template.type === 'unet') {
- return (
-
- );
- }
-
- if (type === 'clip' && template.type === 'clip') {
- return (
-
- );
- }
-
- if (type === 'vae' && template.type === 'vae') {
- return (
-
- );
- }
-
- if (type === 'control' && template.type === 'control') {
- return (
-
- );
- }
-
- if (type === 'model' && template.type === 'model') {
- return (
-
- );
- }
-
- if (type === 'refiner_model' && template.type === 'refiner_model') {
- return (
-
- );
- }
-
- if (type === 'vae_model' && template.type === 'vae_model') {
- return (
-
- );
- }
-
- if (type === 'lora_model' && template.type === 'lora_model') {
- return (
-
- );
- }
-
- if (type === 'controlnet_model' && template.type === 'controlnet_model') {
- return (
-
- );
- }
-
- if (type === 'array' && template.type === 'array') {
- return (
-
- );
- }
-
- if (type === 'item' && template.type === 'item') {
- return (
-
- );
- }
-
- if (type === 'color' && template.type === 'color') {
- return (
-
- );
- }
-
- if (type === 'item' && template.type === 'item') {
- return (
-
- );
- }
-
- if (type === 'image_collection' && template.type === 'image_collection') {
- return (
-
- );
- }
-
- return Unknown field type: {type};
-};
-
-export default memo(InputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx
new file mode 100644
index 0000000000..d67ca10dcc
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx
@@ -0,0 +1,57 @@
+import { ChevronUpIcon } from '@chakra-ui/icons';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
+import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
+import { NodeData } from 'features/nodes/types/types';
+import { memo, useCallback } from 'react';
+import { NodeProps, useUpdateNodeInternals } from 'reactflow';
+
+interface Props {
+ nodeProps: NodeProps;
+}
+
+const NodeCollapseButton = (props: Props) => {
+ const { id: nodeId, isOpen } = props.nodeProps.data;
+ const dispatch = useAppDispatch();
+ const updateNodeInternals = useUpdateNodeInternals();
+
+ const handleClick = useCallback(() => {
+ dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen }));
+ updateNodeInternals(nodeId);
+ }, [dispatch, isOpen, nodeId, updateNodeInternals]);
+
+ return (
+
+ }
+ />
+ );
+};
+
+export default memo(NodeCollapseButton);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx
new file mode 100644
index 0000000000..ece24f6f8c
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx
@@ -0,0 +1,74 @@
+import { useColorModeValue } from '@chakra-ui/react';
+import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
+import {
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import { map } from 'lodash-es';
+import { CSSProperties, memo, useMemo } from 'react';
+import { Handle, NodeProps, Position } from 'reactflow';
+
+interface Props {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+}
+
+const NodeCollapsedHandles = (props: Props) => {
+ const { data } = props.nodeProps;
+ const { base400, base600 } = useChakraThemeTokens();
+ const backgroundColor = useColorModeValue(base400, base600);
+
+ const dummyHandleStyles: CSSProperties = useMemo(
+ () => ({
+ borderWidth: 0,
+ borderRadius: '3px',
+ width: '1rem',
+ height: '1rem',
+ backgroundColor,
+ zIndex: -1,
+ }),
+ [backgroundColor]
+ );
+
+ return (
+ <>
+
+ {map(data.inputs, (input) => (
+ false}
+ position={Position.Left}
+ style={{ visibility: 'hidden' }}
+ />
+ ))}
+ false}
+ isConnectable={false}
+ position={Position.Right}
+ style={{ ...dummyHandleStyles, right: '-0.5rem' }}
+ />
+ {map(data.outputs, (output) => (
+ false}
+ position={Position.Right}
+ style={{ visibility: 'hidden' }}
+ />
+ ))}
+ >
+ );
+};
+
+export default memo(NodeCollapsedHandles);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx
new file mode 100644
index 0000000000..3c513ed29a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx
@@ -0,0 +1,77 @@
+import {
+ Checkbox,
+ Flex,
+ FormControl,
+ FormLabel,
+ Spacer,
+} from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
+import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
+import {
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import { some } from 'lodash-es';
+import { ChangeEvent, memo, useCallback, useMemo } from 'react';
+import { NodeProps } from 'reactflow';
+
+type Props = {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+};
+
+const NodeFooter = (props: Props) => {
+ const { nodeProps, nodeTemplate } = props;
+ const dispatch = useAppDispatch();
+
+ const hasImageOutput = useMemo(
+ () =>
+ some(nodeTemplate?.outputs, (output) =>
+ ['ImageField', 'ImageCollection'].includes(output.type)
+ ),
+ [nodeTemplate?.outputs]
+ );
+
+ const handleChangeIsIntermediate = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(
+ fieldBooleanValueChanged({
+ nodeId: nodeProps.data.id,
+ fieldName: 'is_intermediate',
+ value: !e.target.checked,
+ })
+ );
+ },
+ [dispatch, nodeProps.data.id]
+ );
+
+ return (
+
+
+ {hasImageOutput && (
+
+ Save Output
+
+
+ )}
+
+ );
+};
+
+export default memo(NodeFooter);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx
new file mode 100644
index 0000000000..ab54ca2c44
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx
@@ -0,0 +1,113 @@
+import {
+ Flex,
+ FormControl,
+ FormLabel,
+ Icon,
+ Modal,
+ ModalBody,
+ ModalCloseButton,
+ ModalContent,
+ ModalFooter,
+ ModalHeader,
+ ModalOverlay,
+ Text,
+ Tooltip,
+ useDisclosure,
+} from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAITextarea from 'common/components/IAITextarea';
+import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
+import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
+import {
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import { ChangeEvent, memo, useCallback } from 'react';
+import { FaInfoCircle } from 'react-icons/fa';
+import { NodeProps } from 'reactflow';
+
+interface Props {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+}
+
+const NodeNotesEdit = (props: Props) => {
+ const { nodeProps, nodeTemplate } = props;
+ const { data } = nodeProps;
+ const { isOpen, onOpen, onClose } = useDisclosure();
+ const dispatch = useAppDispatch();
+ const handleNotesChanged = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
+ },
+ [data.id, dispatch]
+ );
+
+ return (
+ <>
+
+ ) : undefined
+ }
+ placement="top"
+ shouldWrapChildren
+ >
+
+
+
+
+
+
+
+
+
+ {data.label || nodeTemplate?.title || 'Unknown Node'}
+
+
+
+
+ Notes
+
+
+
+
+
+
+ >
+ );
+};
+
+export default memo(NodeNotesEdit);
+
+type TooltipContentProps = Props;
+
+const TooltipContent = (props: TooltipContentProps) => {
+ return (
+
+ {props.nodeTemplate?.title}
+
+ {props.nodeTemplate?.description}
+
+ {props.nodeProps.data.notes && {props.nodeProps.data.notes}}
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx
similarity index 73%
rename from invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx
rename to invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx
index 1aca32ec70..6391e86471 100644
--- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx
@@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
import { memo } from 'react';
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
-const IAINodeResizer = (props: NodeResizerProps) => {
+// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
+// not using it for now
+
+const NodeResizer = (props: NodeResizerProps) => {
const { ...rest } = props;
return (
{
);
};
-export default memo(IAINodeResizer);
+export default memo(NodeResizer);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx
new file mode 100644
index 0000000000..bf12358871
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx
@@ -0,0 +1,69 @@
+import { Flex } from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
+import IAIPopover from 'common/components/IAIPopover';
+import IAISwitch from 'common/components/IAISwitch';
+import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
+import { InvocationNodeData } from 'features/nodes/types/types';
+import { ChangeEvent, memo, useCallback } from 'react';
+import { FaBars } from 'react-icons/fa';
+
+interface Props {
+ data: InvocationNodeData;
+}
+
+const NodeSettings = (props: Props) => {
+ const { data } = props;
+ const dispatch = useAppDispatch();
+
+ const handleChangeIsIntermediate = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(
+ fieldBooleanValueChanged({
+ nodeId: data.id,
+ fieldName: 'is_intermediate',
+ value: e.target.checked,
+ })
+ );
+ },
+ [data.id, dispatch]
+ );
+
+ return (
+ }
+ />
+ }
+ >
+
+
+
+
+ );
+};
+
+export default memo(NodeSettings);
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx
new file mode 100644
index 0000000000..6695c4fd3b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx
@@ -0,0 +1,185 @@
+import {
+ Badge,
+ CircularProgress,
+ Flex,
+ Icon,
+ Image,
+ Text,
+ Tooltip,
+} from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
+import {
+ InvocationNodeData,
+ NodeExecutionState,
+ NodeStatus,
+} from 'features/nodes/types/types';
+import { memo, useMemo } from 'react';
+import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
+import { NodeProps } from 'reactflow';
+
+type Props = {
+ nodeProps: NodeProps;
+};
+
+const iconBoxSize = 3;
+const circleStyles = {
+ circle: {
+ transitionProperty: 'none',
+ transitionDuration: '0s',
+ },
+ '.chakra-progress__track': { stroke: 'transparent' },
+};
+
+const NodeStatusIndicator = (props: Props) => {
+ const nodeId = props.nodeProps.data.id;
+ const selectNodeExecutionState = useMemo(
+ () =>
+ createSelector(
+ stateSelector,
+ ({ nodes }) => nodes.nodeExecutionStates[nodeId]
+ ),
+ [nodeId]
+ );
+
+ const nodeExecutionState = useAppSelector(selectNodeExecutionState);
+
+ if (!nodeExecutionState) {
+ return null;
+ }
+
+ return (
+ }
+ placement="top"
+ >
+
+
+
+
+ );
+};
+
+export default memo(NodeStatusIndicator);
+
+type TooltipLabelProps = {
+ nodeExecutionState: NodeExecutionState;
+};
+
+const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
+ const { status, progress, progressImage } = nodeExecutionState;
+ if (status === NodeStatus.PENDING) {
+ return Pending;
+ }
+
+ if (status === NodeStatus.IN_PROGRESS) {
+ if (progressImage) {
+ return (
+
+
+ {progress !== null && (
+
+ {Math.round(progress * 100)}%
+
+ )}
+
+ );
+ }
+
+ if (progress !== null) {
+ return In Progress ({Math.round(progress * 100)}%);
+ }
+
+ return In Progress;
+ }
+
+ if (status === NodeStatus.COMPLETED) {
+ return Completed;
+ }
+
+ if (status === NodeStatus.FAILED) {
+ return nodeExecutionState.error;
+ }
+
+ return null;
+};
+
+type StatusIconProps = {
+ nodeExecutionState: NodeExecutionState;
+};
+
+const StatusIcon = (props: StatusIconProps) => {
+ const { progress, status } = props.nodeExecutionState;
+ if (status === NodeStatus.PENDING) {
+ return (
+
+ );
+ }
+ if (status === NodeStatus.IN_PROGRESS) {
+ return progress === null ? (
+
+ ) : (
+
+ );
+ }
+ if (status === NodeStatus.COMPLETED) {
+ return (
+
+ );
+ }
+ if (status === NodeStatus.FAILED) {
+ return (
+
+ );
+ }
+ return null;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx
new file mode 100644
index 0000000000..fa6a8ea224
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx
@@ -0,0 +1,123 @@
+import {
+ Box,
+ Editable,
+ EditableInput,
+ EditablePreview,
+ Flex,
+ useEditableControls,
+} from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
+import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
+import { NodeData } from 'features/nodes/types/types';
+import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
+
+type Props = {
+ nodeData: NodeData;
+ title: string;
+};
+
+const NodeTitle = (props: Props) => {
+ const { title } = props;
+ const { id: nodeId, label } = props.nodeData;
+ const dispatch = useAppDispatch();
+ const [localTitle, setLocalTitle] = useState(label || title);
+
+ const handleSubmit = useCallback(
+ async (newTitle: string) => {
+ dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
+ setLocalTitle(newTitle || title);
+ },
+ [nodeId, dispatch, title]
+ );
+
+ const handleChange = useCallback((newTitle: string) => {
+ setLocalTitle(newTitle);
+ }, []);
+
+ useEffect(() => {
+ // Another component may change the title; sync local title with global state
+ setLocalTitle(label || title);
+ }, [label, title]);
+
+ return (
+
+
+
+
+
+
+
+ );
+};
+
+export default memo(NodeTitle);
+
+function EditableControls() {
+ const { isEditing, getEditButtonProps } = useEditableControls();
+ const handleDoubleClick = useCallback(
+ (e: MouseEvent) => {
+ const { onClick } = getEditButtonProps();
+ if (!onClick) {
+ return;
+ }
+ onClick(e);
+ },
+ [getEditButtonProps]
+ );
+
+ if (isEditing) {
+ return null;
+ }
+
+ return (
+
+ );
+}
diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx
new file mode 100644
index 0000000000..2f555d700a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx
@@ -0,0 +1,96 @@
+import {
+ Box,
+ ChakraProps,
+ useColorModeValue,
+ useToken,
+} from '@chakra-ui/react';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { nodeClicked } from 'features/nodes/store/nodesSlice';
+import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
+import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
+import { NodeData } from 'features/nodes/types/types';
+import { NodeProps } from 'reactflow';
+
+const useNodeSelect = (nodeId: string) => {
+ const dispatch = useAppDispatch();
+
+ const selectNode = useCallback(
+ (e: MouseEvent) => {
+ dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey }));
+ },
+ [dispatch, nodeId]
+ );
+
+ return selectNode;
+};
+
+type NodeWrapperProps = PropsWithChildren & {
+ nodeProps: NodeProps;
+ width?: NonNullable['w'];
+};
+
+const NodeWrapper = (props: NodeWrapperProps) => {
+ const { width, children, nodeProps } = props;
+ const { data, selected } = nodeProps;
+ const nodeId = data.id;
+
+ const [
+ nodeSelectedOutlineLight,
+ nodeSelectedOutlineDark,
+ shadowsXl,
+ shadowsBase,
+ ] = useToken('shadows', [
+ 'nodeSelectedOutline.light',
+ 'nodeSelectedOutline.dark',
+ 'shadows.xl',
+ 'shadows.base',
+ ]);
+
+ const selectNode = useNodeSelect(nodeId);
+
+ const shadow = useColorModeValue(
+ nodeSelectedOutlineLight,
+ nodeSelectedOutlineDark
+ );
+
+ const shift = useAppSelector((state) => state.hotkeys.shift);
+ const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
+ const className = useMemo(
+ () => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'),
+ [shift]
+ );
+
+ return (
+
+
+ {children}
+
+ );
+};
+
+export default NodeWrapper;
diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx
deleted file mode 100644
index 4c031afaff..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx
+++ /dev/null
@@ -1,74 +0,0 @@
-import { Flex, Icon } from '@chakra-ui/react';
-import { FaExclamationCircle } from 'react-icons/fa';
-import { NodeProps } from 'reactflow';
-import { InvocationValue } from '../types/types';
-
-import { useAppSelector } from 'app/store/storeHooks';
-import { memo, useMemo } from 'react';
-import { makeTemplateSelector } from '../store/util/makeTemplateSelector';
-import IAINodeHeader from './IAINode/IAINodeHeader';
-import IAINodeInputs from './IAINode/IAINodeInputs';
-import IAINodeOutputs from './IAINode/IAINodeOutputs';
-import IAINodeResizer from './IAINode/IAINodeResizer';
-import NodeWrapper from './NodeWrapper';
-
-export const InvocationComponent = memo((props: NodeProps) => {
- const { id: nodeId, data, selected } = props;
- const { type, inputs, outputs } = data;
-
- const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
-
- const template = useAppSelector(templateSelector);
-
- if (!template) {
- return (
-
-
-
-
-
-
- );
- }
-
- return (
-
-
-
-
-
-
-
-
- );
-});
-
-InvocationComponent.displayName = 'InvocationComponent';
diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx
index 8c0480774c..8af9fefa90 100644
--- a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx
@@ -1,25 +1,45 @@
import { Box } from '@chakra-ui/react';
-import { ReactFlowProvider } from 'reactflow';
+import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
+import { memo, useState } from 'react';
+import { Panel, PanelGroup } from 'react-resizable-panels';
import 'reactflow/dist/style.css';
-
-import { memo } from 'react';
import { Flow } from './Flow';
+import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup';
const NodeEditor = () => {
+ const [isPanelCollapsed, setIsPanelCollapsed] = useState(false);
return (
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx
new file mode 100644
index 0000000000..58e2e3564e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx
@@ -0,0 +1,139 @@
+import {
+ Divider,
+ Flex,
+ Heading,
+ Modal,
+ ModalBody,
+ ModalCloseButton,
+ ModalContent,
+ ModalHeader,
+ ModalOverlay,
+ useDisclosure,
+} from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
+import IAISwitch from 'common/components/IAISwitch';
+import { ChangeEvent, useCallback } from 'react';
+import { FaCog } from 'react-icons/fa';
+import {
+ shouldAnimateEdgesChanged,
+ shouldColorEdgesChanged,
+ shouldSnapToGridChanged,
+ shouldValidateGraphChanged,
+} from '../store/nodesSlice';
+
+const selector = createSelector(stateSelector, ({ nodes }) => {
+ const {
+ shouldAnimateEdges,
+ shouldValidateGraph,
+ shouldSnapToGrid,
+ shouldColorEdges,
+ } = nodes;
+ return {
+ shouldAnimateEdges,
+ shouldValidateGraph,
+ shouldSnapToGrid,
+ shouldColorEdges,
+ };
+});
+
+const NodeEditorSettings = () => {
+ const { isOpen, onOpen, onClose } = useDisclosure();
+ const dispatch = useAppDispatch();
+ const {
+ shouldAnimateEdges,
+ shouldValidateGraph,
+ shouldSnapToGrid,
+ shouldColorEdges,
+ } = useAppSelector(selector);
+
+ const handleChangeShouldValidate = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(shouldValidateGraphChanged(e.target.checked));
+ },
+ [dispatch]
+ );
+
+ const handleChangeShouldAnimate = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(shouldAnimateEdgesChanged(e.target.checked));
+ },
+ [dispatch]
+ );
+
+ const handleChangeShouldSnap = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(shouldSnapToGridChanged(e.target.checked));
+ },
+ [dispatch]
+ );
+
+ const handleChangeShouldColor = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(shouldColorEdgesChanged(e.target.checked));
+ },
+ [dispatch]
+ );
+
+ return (
+ <>
+ }
+ onClick={onOpen}
+ />
+
+
+
+
+ Node Editor Settings
+
+
+
+ General
+
+
+
+
+
+
+ Advanced
+
+
+
+
+
+
+ >
+ );
+};
+
+export default NodeEditorSettings;
diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx
index 1d498f19f5..4525dc5f6b 100644
--- a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx
@@ -1,34 +1,26 @@
-import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
-import { memo } from 'react';
+import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
+import { omit } from 'lodash-es';
+import { useMemo } from 'react';
+import { useDebounce } from 'use-debounce';
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
-const NodeGraphOverlay = () => {
- const state = useAppSelector((state: RootState) => state);
- const graph = buildNodesGraph(state);
-
- return (
-
- {JSON.stringify(graph, null, 2)}
-
+const useNodesGraph = () => {
+ const nodes = useAppSelector((state: RootState) => state.nodes);
+ const [debouncedNodes] = useDebounce(nodes, 300);
+ const graph = useMemo(
+ () => omit(buildNodesGraph(debouncedNodes), 'id'),
+ [debouncedNodes]
);
+
+ return graph;
};
-export default memo(NodeGraphOverlay);
+const NodeGraph = () => {
+ const graph = useNodesGraph();
+
+ return ;
+};
+
+export default NodeGraph;
diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx
new file mode 100644
index 0000000000..693940859f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx
@@ -0,0 +1,42 @@
+import {
+ Box,
+ Slider,
+ SliderFilledTrack,
+ SliderThumb,
+ SliderTrack,
+} from '@chakra-ui/react';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useCallback } from 'react';
+import { nodeOpacityChanged } from '../store/nodesSlice';
+
+export default function NodeOpacitySlider() {
+ const dispatch = useAppDispatch();
+ const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity);
+
+ const handleChange = useCallback(
+ (v: number) => {
+ dispatch(nodeOpacityChanged(v));
+ },
+ [dispatch]
+ );
+
+ return (
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx
deleted file mode 100644
index bc7944a28b..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx
+++ /dev/null
@@ -1,36 +0,0 @@
-import { Box, useToken } from '@chakra-ui/react';
-import { useAppSelector } from 'app/store/storeHooks';
-import { PropsWithChildren } from 'react';
-import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
-import { NODE_MIN_WIDTH } from '../types/constants';
-
-type NodeWrapperProps = PropsWithChildren & {
- selected: boolean;
-};
-
-const NodeWrapper = (props: NodeWrapperProps) => {
- const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
- 'nodeSelectedOutline',
- 'dark-lg',
- ]);
-
- const shift = useAppSelector((state) => state.hotkeys.shift);
-
- return (
-
- {props.children}
-
- );
-};
-
-export default NodeWrapper;
diff --git a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx
deleted file mode 100644
index 142e2a2990..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx
+++ /dev/null
@@ -1,73 +0,0 @@
-import { Flex, Image } from '@chakra-ui/react';
-import { RootState } from 'app/store/store';
-import { IAINoContentFallback } from 'common/components/IAIImageFallback';
-import { memo } from 'react';
-import { useDispatch, useSelector } from 'react-redux';
-import { NodeProps, OnResize } from 'reactflow';
-import { setProgressNodeSize } from '../store/nodesSlice';
-import IAINodeHeader from './IAINode/IAINodeHeader';
-import IAINodeResizer from './IAINode/IAINodeResizer';
-import NodeWrapper from './NodeWrapper';
-
-const ProgressImageNode = (props: NodeProps) => {
- const progressImage = useSelector(
- (state: RootState) => state.system.progressImage
- );
- const progressNodeSize = useSelector(
- (state: RootState) => state.nodes.progressNodeSize
- );
- const dispatch = useDispatch();
- const { selected } = props;
-
- const handleResize: OnResize = (_, newSize) => {
- dispatch(setProgressNodeSize(newSize));
- };
-
- return (
-
-
-
- {progressImage ? (
-
- ) : (
-
-
-
- )}
-
-
-
- );
-};
-
-export default memo(ProgressImageNode);
diff --git a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx
index 796cdb010e..7416c6c555 100644
--- a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx
@@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
-import {
- FaCode,
- FaExpand,
- FaMinus,
- FaPlus,
- FaInfo,
- FaMapMarkerAlt,
-} from 'react-icons/fa';
-import { useReactFlow } from 'reactflow';
import { useTranslation } from 'react-i18next';
import {
- shouldShowGraphOverlayChanged,
+ FaExpand,
+ FaInfo,
+ FaMapMarkerAlt,
+ FaMinus,
+ FaPlus,
+} from 'react-icons/fa';
+import { useReactFlow } from 'reactflow';
+import {
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from '../store/nodesSlice';
@@ -22,9 +20,6 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
- const shouldShowGraphOverlay = useAppSelector(
- (state) => state.nodes.shouldShowGraphOverlay
- );
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
@@ -44,10 +39,6 @@ const ViewportControls = () => {
fitView();
}, [fitView]);
- const handleClickedToggleGraphOverlay = useCallback(() => {
- dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
- }, [shouldShowGraphOverlay, dispatch]);
-
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
@@ -79,20 +70,6 @@ const ViewportControls = () => {
icon={}
/>
-
- }
- />
-
(
-
+
+
+
+
);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx
similarity index 91%
rename from invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx
rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx
index 39142ed48e..8b7fb942a6 100644
--- a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx
@@ -20,7 +20,7 @@ const MinimapPanel = () => {
const nodeColor = useColorModeValue(
'var(--invokeai-colors-accent-300)',
- 'var(--invokeai-colors-accent-700)'
+ 'var(--invokeai-colors-accent-600)'
);
const maskColor = useColorModeValue(
@@ -32,10 +32,9 @@ const MinimapPanel = () => {
<>
{shouldShowMinimapPanel && (
{
return (
@@ -15,9 +14,8 @@ const TopCenterPanel = () => {
-
-
+
);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx
similarity index 100%
rename from invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx
rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx
similarity index 55%
rename from invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx
rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx
index e3e3a871c8..7facf3973f 100644
--- a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx
@@ -1,22 +1,16 @@
-import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { Panel } from 'reactflow';
import FieldTypeLegend from '../FieldTypeLegend';
-import NodeGraphOverlay from '../NodeGraphOverlay';
const TopRightPanel = () => {
- const shouldShowGraphOverlay = useAppSelector(
- (state: RootState) => state.nodes.shouldShowGraphOverlay
- );
const shouldShowFieldTypeLegend = useAppSelector(
- (state: RootState) => state.nodes.shouldShowFieldTypeLegend
+ (state) => state.nodes.shouldShowFieldTypeLegend
);
return (
{shouldShowFieldTypeLegend && }
- {shouldShowGraphOverlay && }
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx
deleted file mode 100644
index 8e478c907c..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx
+++ /dev/null
@@ -1,15 +0,0 @@
-import {
- ArrayInputFieldTemplate,
- ArrayInputFieldValue,
-} from 'features/nodes/types/types';
-import { memo } from 'react';
-import { FaList } from 'react-icons/fa';
-import { FieldComponentProps } from './types';
-
-const ArrayInputFieldComponent = (
- _props: FieldComponentProps
-) => {
- return ;
-};
-
-export default memo(ArrayInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx
deleted file mode 100644
index 5f26bc4f2a..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx
+++ /dev/null
@@ -1,37 +0,0 @@
-import { Select } from '@chakra-ui/react';
-import { useAppDispatch } from 'app/store/storeHooks';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
-import {
- EnumInputFieldTemplate,
- EnumInputFieldValue,
-} from 'features/nodes/types/types';
-import { ChangeEvent, memo } from 'react';
-import { FieldComponentProps } from './types';
-
-const EnumInputFieldComponent = (
- props: FieldComponentProps
-) => {
- const { nodeId, field, template } = props;
-
- const dispatch = useAppDispatch();
-
- const handleValueChanged = (e: ChangeEvent) => {
- dispatch(
- fieldValueChanged({
- nodeId,
- fieldName: field.name,
- value: e.target.value,
- })
- );
- };
-
- return (
-
- );
-};
-
-export default memo(EnumInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx
new file mode 100644
index 0000000000..d9f8f951bc
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx
@@ -0,0 +1,47 @@
+import { MenuItem, MenuList } from '@chakra-ui/react';
+import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+} from 'features/nodes/types/types';
+import { MouseEvent, useCallback } from 'react';
+import { menuListMotionProps } from 'theme/components/menu';
+
+type Props = {
+ nodeId: string;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+ children: ContextMenuProps['children'];
+};
+
+const FieldContextMenu = (props: Props) => {
+ const skipEvent = useCallback((e: MouseEvent) => {
+ e.preventDefault();
+ }, []);
+
+ return (
+
+ menuProps={{
+ size: 'sm',
+ isLazy: true,
+ }}
+ menuButtonProps={{
+ bg: 'transparent',
+ _hover: { bg: 'transparent' },
+ }}
+ renderMenu={() => (
+
+
+
+ )}
+ >
+ {props.children}
+
+ );
+};
+
+export default FieldContextMenu;
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx
new file mode 100644
index 0000000000..f47e68976d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx
@@ -0,0 +1,122 @@
+import { Tooltip } from '@chakra-ui/react';
+import { CSSProperties, memo, useMemo } from 'react';
+import { Handle, HandleType, NodeProps, Position } from 'reactflow';
+import {
+ FIELDS,
+ HANDLE_TOOLTIP_OPEN_DELAY,
+ colorTokenToCssVar,
+} from '../../types/constants';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+ OutputFieldTemplate,
+ OutputFieldValue,
+} from '../../types/types';
+
+export const handleBaseStyles: CSSProperties = {
+ position: 'absolute',
+ width: '1rem',
+ height: '1rem',
+ borderWidth: 0,
+ zIndex: 1,
+};
+
+export const inputHandleStyles: CSSProperties = {
+ left: '-1rem',
+};
+
+export const outputHandleStyles: CSSProperties = {
+ right: '-0.5rem',
+};
+
+type FieldHandleProps = {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue | OutputFieldValue;
+ fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
+ handleType: HandleType;
+ isConnectionInProgress: boolean;
+ isConnectionStartField: boolean;
+ connectionError: string | null;
+};
+
+const FieldHandle = (props: FieldHandleProps) => {
+ const {
+ fieldTemplate,
+ handleType,
+ isConnectionInProgress,
+ isConnectionStartField,
+ connectionError,
+ } = props;
+ const { name, type } = fieldTemplate;
+ const { color, title } = FIELDS[type];
+
+ const styles: CSSProperties = useMemo(() => {
+ const s: CSSProperties = {
+ backgroundColor: colorTokenToCssVar(color),
+ position: 'absolute',
+ width: '1rem',
+ height: '1rem',
+ borderWidth: 0,
+ zIndex: 1,
+ };
+
+ if (handleType === 'target') {
+ s.insetInlineStart = '-1rem';
+ } else {
+ s.insetInlineEnd = '-1rem';
+ }
+
+ if (isConnectionInProgress && !isConnectionStartField && connectionError) {
+ s.filter = 'opacity(0.4) grayscale(0.7)';
+ }
+
+ if (isConnectionInProgress && connectionError) {
+ if (isConnectionStartField) {
+ s.cursor = 'grab';
+ } else {
+ s.cursor = 'not-allowed';
+ }
+ } else {
+ s.cursor = 'crosshair';
+ }
+
+ return s;
+ }, [
+ color,
+ connectionError,
+ handleType,
+ isConnectionInProgress,
+ isConnectionStartField,
+ ]);
+
+ const tooltip = useMemo(() => {
+ if (isConnectionInProgress && isConnectionStartField) {
+ return title;
+ }
+ if (isConnectionInProgress && connectionError) {
+ return connectionError ?? title;
+ }
+ return title;
+ }, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
+
+ return (
+
+
+
+ );
+};
+
+export default memo(FieldHandle);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx
new file mode 100644
index 0000000000..fc239addf3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx
@@ -0,0 +1,161 @@
+import {
+ Editable,
+ EditableInput,
+ EditablePreview,
+ Flex,
+ useEditableControls,
+} from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIDraggable from 'common/components/IAIDraggable';
+import { NodeFieldDraggableData } from 'features/dnd/types';
+import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import {
+ MouseEvent,
+ memo,
+ useCallback,
+ useEffect,
+ useMemo,
+ useState,
+} from 'react';
+
+interface Props {
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+ isDraggable?: boolean;
+}
+
+const FieldTitle = (props: Props) => {
+ const { nodeData, field, fieldTemplate, isDraggable = false } = props;
+ const { label } = field;
+ const { title, input } = fieldTemplate;
+ const { id: nodeId } = nodeData;
+ const dispatch = useAppDispatch();
+ const [localTitle, setLocalTitle] = useState(label || title);
+
+ const draggableData: NodeFieldDraggableData | undefined = useMemo(
+ () =>
+ input !== 'connection' && isDraggable
+ ? {
+ id: `${nodeId}-${field.name}`,
+ payloadType: 'NODE_FIELD',
+ payload: { nodeId, field, fieldTemplate },
+ }
+ : undefined,
+ [field, fieldTemplate, input, isDraggable, nodeId]
+ );
+
+ const handleSubmit = useCallback(
+ async (newTitle: string) => {
+ dispatch(
+ fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle })
+ );
+ setLocalTitle(newTitle || title);
+ },
+ [dispatch, nodeId, field.name, title]
+ );
+
+ const handleChange = useCallback((newTitle: string) => {
+ setLocalTitle(newTitle);
+ }, []);
+
+ useEffect(() => {
+ // Another component may change the title; sync local title with global state
+ setLocalTitle(label || title);
+ }, [label, title]);
+
+ return (
+
+
+
+
+
+
+
+ );
+};
+
+export default memo(FieldTitle);
+
+type EditableControlsProps = {
+ draggableData?: NodeFieldDraggableData;
+};
+
+function EditableControls(props: EditableControlsProps) {
+ const { isEditing, getEditButtonProps } = useEditableControls();
+ const handleDoubleClick = useCallback(
+ (e: MouseEvent) => {
+ const { onClick } = getEditButtonProps();
+ if (!onClick) {
+ return;
+ }
+ onClick(e);
+ },
+ [getEditButtonProps]
+ );
+
+ if (isEditing) {
+ return null;
+ }
+
+ if (props.draggableData) {
+ return (
+
+ );
+ }
+
+ return (
+
+ );
+}
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx
new file mode 100644
index 0000000000..bf5cd3cd9b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx
@@ -0,0 +1,41 @@
+import { Flex, Text } from '@chakra-ui/react';
+import { FIELDS } from 'features/nodes/types/constants';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+ OutputFieldTemplate,
+ OutputFieldValue,
+ isInputFieldTemplate,
+ isInputFieldValue,
+} from 'features/nodes/types/types';
+import { startCase } from 'lodash-es';
+
+interface Props {
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue | OutputFieldValue;
+ fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
+}
+
+const FieldTooltipContent = ({ field, fieldTemplate }: Props) => {
+ const isInputTemplate = isInputFieldTemplate(fieldTemplate);
+
+ return (
+
+
+ {isInputFieldValue(field) && field.label
+ ? `${field.label} (${fieldTemplate.title})`
+ : fieldTemplate.title}
+
+
+ {fieldTemplate.description}
+
+ Type: {FIELDS[fieldTemplate.type].title}
+ {isInputTemplate && Input: {startCase(fieldTemplate.input)}}
+
+ );
+};
+
+export default FieldTooltipContent;
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx
new file mode 100644
index 0000000000..67f4369384
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx
@@ -0,0 +1,153 @@
+import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
+import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
+import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
+import {
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import { PropsWithChildren, useMemo } from 'react';
+import { NodeProps } from 'reactflow';
+import FieldHandle from './FieldHandle';
+import FieldTitle from './FieldTitle';
+import FieldTooltipContent from './FieldTooltipContent';
+import InputFieldRenderer from './InputFieldRenderer';
+
+interface Props {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue;
+}
+
+const InputField = (props: Props) => {
+ const { nodeProps, nodeTemplate, field } = props;
+ const { id: nodeId } = nodeProps.data;
+
+ const {
+ isConnected,
+ isConnectionInProgress,
+ isConnectionStartField,
+ connectionError,
+ shouldDim,
+ } = useConnectionState({ nodeId, field, kind: 'input' });
+
+ const fieldTemplate = useMemo(
+ () => nodeTemplate.inputs[field.name],
+ [field.name, nodeTemplate.inputs]
+ );
+
+ const isMissingInput = useMemo(() => {
+ if (!fieldTemplate) {
+ return false;
+ }
+
+ if (!fieldTemplate.required) {
+ return false;
+ }
+
+ if (!isConnected && fieldTemplate.input === 'connection') {
+ return true;
+ }
+
+ if (!field.value && !isConnected && fieldTemplate.input === 'any') {
+ return true;
+ }
+ }, [fieldTemplate, isConnected, field.value]);
+
+ if (!fieldTemplate) {
+ return (
+
+
+ Unknown input: {field.name}
+
+
+ );
+ }
+
+ return (
+
+
+
+ }
+ openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
+ placement="top"
+ shouldWrapChildren
+ hasArrow
+ >
+
+
+
+
+
+
+
+ {fieldTemplate.input !== 'direct' && (
+
+ )}
+
+ );
+};
+
+export default InputField;
+
+type InputFieldWrapperProps = PropsWithChildren<{
+ shouldDim: boolean;
+}>;
+
+const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => (
+
+ {children}
+
+);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx
new file mode 100644
index 0000000000..ce9d88af0a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx
@@ -0,0 +1,293 @@
+import { Box } from '@chakra-ui/react';
+import { memo } from 'react';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+} from '../../types/types';
+import BooleanInputField from './fieldTypes/BooleanInputField';
+import ClipInputField from './fieldTypes/ClipInputField';
+import CollectionInputField from './fieldTypes/CollectionInputField';
+import CollectionItemInputField from './fieldTypes/CollectionItemInputField';
+import ColorInputField from './fieldTypes/ColorInputField';
+import ConditioningInputField from './fieldTypes/ConditioningInputField';
+import ControlInputField from './fieldTypes/ControlInputField';
+import ControlNetModelInputField from './fieldTypes/ControlNetModelInputField';
+import EnumInputField from './fieldTypes/EnumInputField';
+import ImageCollectionInputField from './fieldTypes/ImageCollectionInputField';
+import ImageInputField from './fieldTypes/ImageInputField';
+import LatentsInputField from './fieldTypes/LatentsInputField';
+import LoRAModelInputField from './fieldTypes/LoRAModelInputField';
+import MainModelInputField from './fieldTypes/MainModelInputField';
+import NumberInputField from './fieldTypes/NumberInputField';
+import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
+import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
+import StringInputField from './fieldTypes/StringInputField';
+import UnetInputField from './fieldTypes/UnetInputField';
+import VaeInputField from './fieldTypes/VaeInputField';
+import VaeModelInputField from './fieldTypes/VaeModelInputField';
+
+type InputFieldProps = {
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+};
+
+// build an individual input element based on the schema
+const InputFieldRenderer = (props: InputFieldProps) => {
+ const { nodeData, nodeTemplate, field, fieldTemplate } = props;
+ const { type } = field;
+
+ if (type === 'string' && fieldTemplate.type === 'string') {
+ return (
+
+ );
+ }
+
+ if (type === 'boolean' && fieldTemplate.type === 'boolean') {
+ return (
+
+ );
+ }
+
+ if (
+ (type === 'integer' && fieldTemplate.type === 'integer') ||
+ (type === 'float' && fieldTemplate.type === 'float') ||
+ (type === 'Seed' && fieldTemplate.type === 'Seed')
+ ) {
+ return (
+
+ );
+ }
+
+ if (type === 'enum' && fieldTemplate.type === 'enum') {
+ return (
+
+ );
+ }
+
+ if (type === 'ImageField' && fieldTemplate.type === 'ImageField') {
+ return (
+
+ );
+ }
+
+ if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') {
+ return (
+
+ );
+ }
+
+ if (
+ type === 'ConditioningField' &&
+ fieldTemplate.type === 'ConditioningField'
+ ) {
+ return (
+
+ );
+ }
+
+ if (type === 'UNetField' && fieldTemplate.type === 'UNetField') {
+ return (
+
+ );
+ }
+
+ if (type === 'ClipField' && fieldTemplate.type === 'ClipField') {
+ return (
+
+ );
+ }
+
+ if (type === 'VaeField' && fieldTemplate.type === 'VaeField') {
+ return (
+
+ );
+ }
+
+ if (type === 'ControlField' && fieldTemplate.type === 'ControlField') {
+ return (
+
+ );
+ }
+
+ if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') {
+ return (
+
+ );
+ }
+
+ if (
+ type === 'SDXLRefinerModelField' &&
+ fieldTemplate.type === 'SDXLRefinerModelField'
+ ) {
+ return (
+
+ );
+ }
+
+ if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') {
+ return (
+
+ );
+ }
+
+ if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') {
+ return (
+
+ );
+ }
+
+ if (
+ type === 'ControlNetModelField' &&
+ fieldTemplate.type === 'ControlNetModelField'
+ ) {
+ return (
+
+ );
+ }
+
+ if (type === 'Collection' && fieldTemplate.type === 'Collection') {
+ return (
+
+ );
+ }
+
+ if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') {
+ return (
+
+ );
+ }
+
+ if (type === 'ColorField' && fieldTemplate.type === 'ColorField') {
+ return (
+
+ );
+ }
+
+ if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') {
+ return (
+
+ );
+ }
+
+ if (
+ type === 'SDXLMainModelField' &&
+ fieldTemplate.type === 'SDXLMainModelField'
+ ) {
+ return (
+
+ );
+ }
+
+ return Unknown field type: {type};
+};
+
+export default memo(InputFieldRenderer);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx
deleted file mode 100644
index 6fa89345bf..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx
+++ /dev/null
@@ -1,15 +0,0 @@
-import {
- ItemInputFieldTemplate,
- ItemInputFieldValue,
-} from 'features/nodes/types/types';
-import { memo } from 'react';
-import { FaAddressCard } from 'react-icons/fa';
-import { FieldComponentProps } from './types';
-
-const ItemInputFieldComponent = (
- _props: FieldComponentProps
-) => {
- return ;
-};
-
-export default memo(ItemInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx
new file mode 100644
index 0000000000..98a8000b1a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx
@@ -0,0 +1,88 @@
+import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
+import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import FieldTitle from './FieldTitle';
+import FieldTooltipContent from './FieldTooltipContent';
+import InputFieldRenderer from './InputFieldRenderer';
+
+type Props = {
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+};
+
+const LinearViewField = ({
+ nodeData,
+ nodeTemplate,
+ field,
+ fieldTemplate,
+}: Props) => {
+ // const dispatch = useAppDispatch();
+ // const handleRemoveField = useCallback(() => {
+ // dispatch(
+ // workflowExposedFieldRemoved({
+ // nodeId: nodeData.id,
+ // fieldName: field.name,
+ // })
+ // );
+ // }, [dispatch, field.name, nodeData.id]);
+
+ return (
+
+
+
+ }
+ openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
+ placement="top"
+ shouldWrapChildren
+ hasArrow
+ >
+
+
+
+
+
+
+
+ );
+};
+
+export default memo(LinearViewField);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx
new file mode 100644
index 0000000000..5a29d1ab7e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx
@@ -0,0 +1,114 @@
+import {
+ Flex,
+ FormControl,
+ FormLabel,
+ Spacer,
+ Tooltip,
+} from '@chakra-ui/react';
+import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
+import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
+import {
+ InvocationNodeData,
+ InvocationTemplate,
+ OutputFieldValue,
+} from 'features/nodes/types/types';
+import { PropsWithChildren, useMemo } from 'react';
+import { NodeProps } from 'reactflow';
+import FieldHandle from './FieldHandle';
+import FieldTooltipContent from './FieldTooltipContent';
+
+interface Props {
+ nodeProps: NodeProps;
+ nodeTemplate: InvocationTemplate;
+ field: OutputFieldValue;
+}
+
+const OutputField = (props: Props) => {
+ const { nodeTemplate, nodeProps, field } = props;
+
+ const {
+ isConnected,
+ isConnectionInProgress,
+ isConnectionStartField,
+ connectionError,
+ shouldDim,
+ } = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' });
+
+ const fieldTemplate = useMemo(
+ () => nodeTemplate.outputs[field.name],
+ [field.name, nodeTemplate]
+ );
+
+ if (!fieldTemplate) {
+ return (
+
+
+ Unknown output: {field.name}
+
+
+ );
+ }
+
+ return (
+
+
+
+ }
+ openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
+ placement="top"
+ shouldWrapChildren
+ hasArrow
+ >
+
+
+ {fieldTemplate?.title}
+
+
+
+
+
+ );
+};
+
+export default OutputField;
+
+type OutputFieldWrapperProps = PropsWithChildren<{
+ shouldDim: boolean;
+}>;
+
+const OutputFieldWrapper = ({
+ shouldDim,
+ children,
+}: OutputFieldWrapperProps) => (
+
+ {children}
+
+);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx
deleted file mode 100644
index 18cf7e997f..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx
+++ /dev/null
@@ -1,36 +0,0 @@
-import { Input, Textarea } from '@chakra-ui/react';
-import { useAppDispatch } from 'app/store/storeHooks';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
-import {
- StringInputFieldTemplate,
- StringInputFieldValue,
-} from 'features/nodes/types/types';
-import { ChangeEvent, memo } from 'react';
-import { FieldComponentProps } from './types';
-
-const StringInputFieldComponent = (
- props: FieldComponentProps
-) => {
- const { nodeId, field } = props;
- const dispatch = useAppDispatch();
-
- const handleValueChanged = (
- e: ChangeEvent
- ) => {
- dispatch(
- fieldValueChanged({
- nodeId,
- fieldName: field.name,
- value: e.target.value,
- })
- );
- };
-
- return ['prompt', 'style'].includes(field.name.toLowerCase()) ? (
-
- ) : (
-
- );
-};
-
-export default memo(StringInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/BooleanInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx
similarity index 53%
rename from invokeai/frontend/web/src/features/nodes/components/fields/BooleanInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx
index 52a60253ba..00a2d2bd10 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/BooleanInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx
@@ -1,29 +1,33 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
} from 'features/nodes/types/types';
-import { ChangeEvent, memo } from 'react';
+import { ChangeEvent, memo, useCallback } from 'react';
import { FieldComponentProps } from './types';
const BooleanInputFieldComponent = (
props: FieldComponentProps
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
- const handleValueChanged = (e: ChangeEvent) => {
- dispatch(
- fieldValueChanged({
- nodeId,
- fieldName: field.name,
- value: e.target.checked,
- })
- );
- };
+ const handleValueChanged = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(
+ fieldBooleanValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: e.target.checked,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
return (
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ClipInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ClipInputField.tsx
similarity index 100%
rename from invokeai/frontend/web/src/features/nodes/components/fields/ClipInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ClipInputField.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx
new file mode 100644
index 0000000000..99c88af2cb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx
@@ -0,0 +1,17 @@
+import {
+ CollectionInputFieldTemplate,
+ CollectionInputFieldValue,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import { FieldComponentProps } from './types';
+
+const CollectionInputFieldComponent = (
+ _props: FieldComponentProps<
+ CollectionInputFieldValue,
+ CollectionInputFieldTemplate
+ >
+) => {
+ return null;
+};
+
+export default memo(CollectionInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx
new file mode 100644
index 0000000000..00f753d8d3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx
@@ -0,0 +1,17 @@
+import {
+ CollectionItemInputFieldTemplate,
+ CollectionItemInputFieldValue,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import { FieldComponentProps } from './types';
+
+const CollectionItemInputFieldComponent = (
+ _props: FieldComponentProps<
+ CollectionItemInputFieldValue,
+ CollectionItemInputFieldTemplate
+ >
+) => {
+ return null;
+};
+
+export default memo(CollectionItemInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ColorInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx
similarity index 57%
rename from invokeai/frontend/web/src/features/nodes/components/fields/ColorInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx
index c4884dcffc..c4a4d19a1e 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ColorInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx
@@ -1,23 +1,33 @@
+import { useAppDispatch } from 'app/store/storeHooks';
+import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice';
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
} from 'features/nodes/types/types';
-import { memo } from 'react';
-import { FieldComponentProps } from './types';
+import { memo, useCallback } from 'react';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
-import { useAppDispatch } from 'app/store/storeHooks';
+import { FieldComponentProps } from './types';
const ColorInputFieldComponent = (
props: FieldComponentProps
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
- const handleValueChanged = (value: RgbaColor) => {
- dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
- };
+ const handleValueChanged = useCallback(
+ (value: RgbaColor) => {
+ dispatch(
+ fieldColorValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
return (
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
@@ -73,7 +74,7 @@ const ControlNetModelInputFieldComponent = (
}
dispatch(
- fieldValueChanged({
+ fieldControlNetModelValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
@@ -85,10 +86,8 @@ const ControlNetModelInputFieldComponent = (
return (
+) => {
+ const { nodeData, field, fieldTemplate } = props;
+ const nodeId = nodeData.id;
+
+ const dispatch = useAppDispatch();
+
+ const handleValueChanged = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(
+ fieldEnumModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: e.target.value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ return (
+
+ );
+};
+
+export default memo(EnumInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageCollectionInputField.tsx
similarity index 86%
rename from invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageCollectionInputField.tsx
index 8ecd6c8cd9..1ca820939b 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageCollectionInputField.tsx
@@ -5,13 +5,11 @@ import {
import { memo } from 'react';
import { Flex } from '@chakra-ui/react';
-import {
- NodesMultiImageDropData,
- isValidDrop,
- useDroppable,
-} from 'app/components/ImageDnd/typesafeDnd';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
+import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
+import { NodesMultiImageDropData } from 'features/dnd/types';
+import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { FieldComponentProps } from './types';
@@ -21,7 +19,8 @@ const ImageCollectionInputFieldComponent = (
ImageCollectionInputFieldTemplate
>
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
// const dispatch = useAppDispatch();
@@ -41,14 +40,14 @@ const ImageCollectionInputFieldComponent = (
const droppableData: NodesMultiImageDropData = {
id: `node-${nodeId}-${field.name}`,
actionType: 'SET_MULTI_NODES_IMAGE',
- context: { nodeId, fieldName: field.name },
+ context: { nodeId: nodeId, fieldName: field.name },
};
const {
isOver,
setNodeRef: setDroppableRef,
active,
- } = useDroppable({
+ } = useDroppableTypesafe({
id: `node_${nodeId}`,
data: droppableData,
});
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageInputField.tsx
similarity index 88%
rename from invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageInputField.tsx
index b16897d889..a7305d889d 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ImageInputField.tsx
@@ -1,12 +1,12 @@
import { Flex } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIDndImage from 'common/components/IAIDndImage';
import {
TypesafeDraggableData,
TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
-import { useAppDispatch } from 'app/store/storeHooks';
-import IAIDndImage from 'common/components/IAIDndImage';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+} from 'features/dnd/types';
+import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
ImageInputFieldTemplate,
ImageInputFieldValue,
@@ -19,8 +19,8 @@ import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
props: FieldComponentProps
) => {
- const { nodeId, field } = props;
-
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery(
@@ -29,7 +29,7 @@ const ImageInputFieldComponent = (
const handleReset = useCallback(() => {
dispatch(
- fieldValueChanged({
+ fieldImageValueChanged({
nodeId,
fieldName: field.name,
value: undefined,
@@ -79,6 +79,9 @@ const ImageInputFieldComponent = (
droppableData={droppableData}
draggableData={draggableData}
onClickReset={handleReset}
+ withResetIcon
+ thumbnail
+ useThumbailFallback
postUploadAction={postUploadAction}
/>
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LatentsInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/LatentsInputField.tsx
similarity index 100%
rename from invokeai/frontend/web/src/features/nodes/components/fields/LatentsInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/LatentsInputField.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/LoRAModelInputField.tsx
similarity index 92%
rename from invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/LoRAModelInputField.tsx
index 27d15fb93e..8aae6ee9a4 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/LoRAModelInputField.tsx
@@ -3,7 +3,7 @@ import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
LoRAModelInputFieldTemplate,
LoRAModelInputFieldValue,
@@ -21,7 +21,8 @@ const LoRAModelInputFieldComponent = (
LoRAModelInputFieldTemplate
>
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const lora = field.value;
const dispatch = useAppDispatch();
const { data: loraModels } = useGetLoRAModelsQuery();
@@ -68,7 +69,7 @@ const LoRAModelInputFieldComponent = (
}
dispatch(
- fieldValueChanged({
+ fieldLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value: newLoRAModel,
@@ -90,11 +91,8 @@ const LoRAModelInputFieldComponent = (
return (
0 ? 'Select a LoRA' : 'No LoRAs available'}
data={data}
nothingFound="No matching LoRAs"
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/MainModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/MainModelInputField.tsx
new file mode 100644
index 0000000000..f1047f52cb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/MainModelInputField.tsx
@@ -0,0 +1,144 @@
+import { Flex, Text } from '@chakra-ui/react';
+import { SelectItem } from '@mantine/core';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
+import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
+import {
+ MainModelInputFieldTemplate,
+ MainModelInputFieldValue,
+} from 'features/nodes/types/types';
+import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
+import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
+import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
+import { forEach } from 'lodash-es';
+import { memo, useCallback, useMemo } from 'react';
+import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
+import {
+ useGetMainModelsQuery,
+ useGetOnnxModelsQuery,
+} from 'services/api/endpoints/models';
+import { FieldComponentProps } from './types';
+
+const MainModelInputFieldComponent = (
+ props: FieldComponentProps<
+ MainModelInputFieldValue,
+ MainModelInputFieldTemplate
+ >
+) => {
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
+ const dispatch = useAppDispatch();
+ const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
+
+ const { data: onnxModels, isLoading: isLoadingOnnxModels } =
+ useGetOnnxModelsQuery(NON_SDXL_MAIN_MODELS);
+ const { data: mainModels, isLoading: isLoadingMainModels } =
+ useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
+
+ const isLoadingModels = useMemo(
+ () => isLoadingOnnxModels || isLoadingMainModels,
+ [isLoadingOnnxModels, isLoadingMainModels]
+ );
+
+ const data = useMemo(() => {
+ if (!mainModels) {
+ return [];
+ }
+
+ const data: SelectItem[] = [];
+
+ forEach(mainModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ data.push({
+ value: id,
+ label: model.model_name,
+ group: MODEL_TYPE_MAP[model.base_model],
+ });
+ });
+
+ if (onnxModels) {
+ forEach(onnxModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ data.push({
+ value: id,
+ label: model.model_name,
+ group: MODEL_TYPE_MAP[model.base_model],
+ });
+ });
+ }
+ return data;
+ }, [mainModels, onnxModels]);
+
+ // grab the full model entity from the RTK Query cache
+ // TODO: maybe we should just store the full model entity in state?
+ const selectedModel = useMemo(
+ () =>
+ (mainModels?.entities[
+ `${field.value?.base_model}/main/${field.value?.model_name}`
+ ] ||
+ onnxModels?.entities[
+ `${field.value?.base_model}/onnx/${field.value?.model_name}`
+ ]) ??
+ null,
+ [
+ field.value?.base_model,
+ field.value?.model_name,
+ mainModels?.entities,
+ onnxModels?.entities,
+ ]
+ );
+
+ const handleChangeModel = useCallback(
+ (v: string | null) => {
+ if (!v) {
+ return;
+ }
+
+ const newModel = modelIdToMainModelParam(v);
+
+ if (!newModel) {
+ return;
+ }
+
+ dispatch(
+ fieldMainModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: newModel,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ return (
+
+ {isLoadingModels ? (
+ Loading...
+ ) : (
+ 0 ? 'Select a model' : 'No models available'
+ }
+ data={data}
+ error={!selectedModel}
+ disabled={data.length === 0}
+ onChange={handleChangeModel}
+ />
+ )}
+ {isSyncModelEnabled && }
+
+ );
+};
+
+export default memo(MainModelInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/NumberInputField.tsx
similarity index 67%
rename from invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/NumberInputField.tsx
index 50d69a6496..3d82003882 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/NumberInputField.tsx
@@ -7,27 +7,34 @@ import {
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FloatInputFieldTemplate,
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
+ SeedInputFieldTemplate,
+ SeedInputFieldValue,
} from 'features/nodes/types/types';
-import { memo, useEffect, useState } from 'react';
+import { memo, useEffect, useMemo, useState } from 'react';
import { FieldComponentProps } from './types';
const NumberInputFieldComponent = (
props: FieldComponentProps<
- IntegerInputFieldValue | FloatInputFieldValue,
- IntegerInputFieldTemplate | FloatInputFieldTemplate
+ IntegerInputFieldValue | FloatInputFieldValue | SeedInputFieldValue,
+ IntegerInputFieldTemplate | FloatInputFieldTemplate | SeedInputFieldTemplate
>
) => {
- const { nodeId, field } = props;
+ const { nodeData, field, fieldTemplate } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState(
String(field.value)
);
+ const isIntegerField = useMemo(
+ () => fieldTemplate.type === 'integer' || fieldTemplate.type === 'Seed',
+ [fieldTemplate.type]
+ );
const handleValueChanged = (v: string) => {
setValueAsString(v);
@@ -35,13 +42,10 @@ const NumberInputFieldComponent = (
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
dispatch(
- fieldValueChanged({
+ fieldNumberValueChanged({
nodeId,
fieldName: field.name,
- value:
- props.template.type === 'integer'
- ? Math.floor(Number(v))
- : Number(v),
+ value: isIntegerField ? Math.floor(Number(v)) : Number(v),
})
);
}
@@ -60,8 +64,8 @@ const NumberInputFieldComponent = (
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/RefinerModelInputField.tsx
similarity index 89%
rename from invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/RefinerModelInputField.tsx
index 28c6567e8d..4a419b51d6 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/RefinerModelInputField.tsx
@@ -2,13 +2,14 @@ import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
- RefinerModelInputFieldTemplate,
- RefinerModelInputFieldValue,
+ SDXLRefinerModelInputFieldTemplate,
+ SDXLRefinerModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
@@ -16,16 +17,15 @@ import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
-import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const RefinerModelInputFieldComponent = (
props: FieldComponentProps<
- RefinerModelInputFieldValue,
- RefinerModelInputFieldTemplate
+ SDXLRefinerModelInputFieldValue,
+ SDXLRefinerModelInputFieldTemplate
>
) => {
- const { nodeId, field } = props;
-
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
@@ -77,7 +77,7 @@ const RefinerModelInputFieldComponent = (
}
dispatch(
- fieldValueChanged({
+ fieldRefinerModelValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
@@ -97,10 +97,8 @@ const RefinerModelInputFieldComponent = (
) : (
0 ? 'Select a model' : 'No models available'}
data={data}
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/SDXLMainModelInputField.tsx
similarity index 78%
rename from invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/SDXLMainModelInputField.tsx
index 154c3c1cb0..89bd6b2b65 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/SDXLMainModelInputField.tsx
@@ -1,40 +1,41 @@
-import { useAppDispatch } from 'app/store/storeHooks';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
-import {
- MainModelInputFieldValue,
- ModelInputFieldTemplate,
-} from 'features/nodes/types/types';
-
-import { Box, Flex } from '@chakra-ui/react';
+import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
+import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
+import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
+import {
+ SDXLMainModelInputFieldTemplate,
+ SDXLMainModelInputFieldValue,
+} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
+import { SDXL_MAIN_MODELS } from 'services/api/constants';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
-import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { FieldComponentProps } from './types';
-import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
const ModelInputFieldComponent = (
- props: FieldComponentProps
+ props: FieldComponentProps<
+ SDXLMainModelInputFieldValue,
+ SDXLMainModelInputFieldTemplate
+ >
) => {
- const { nodeId, field } = props;
-
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
- const { data: onnxModels } = useGetOnnxModelsQuery(NON_REFINER_BASE_MODELS);
- const { data: mainModels, isLoading } = useGetMainModelsQuery(
- NON_REFINER_BASE_MODELS
- );
+ const { data: onnxModels } = useGetOnnxModelsQuery(SDXL_MAIN_MODELS);
+ const { data: mainModels, isLoading } =
+ useGetMainModelsQuery(SDXL_MAIN_MODELS);
const data = useMemo(() => {
if (!mainModels) {
@@ -44,7 +45,7 @@ const ModelInputFieldComponent = (
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
- if (!model) {
+ if (!model || model.base_model !== 'sdxl') {
return;
}
@@ -57,7 +58,7 @@ const ModelInputFieldComponent = (
if (onnxModels) {
forEach(onnxModels.entities, (model, id) => {
- if (!model) {
+ if (!model || model.base_model !== 'sdxl') {
return;
}
@@ -103,7 +104,7 @@ const ModelInputFieldComponent = (
}
dispatch(
- fieldValueChanged({
+ fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
@@ -123,10 +124,8 @@ const ModelInputFieldComponent = (
) : (
0 ? 'Select a model' : 'No models available'}
data={data}
@@ -134,11 +133,7 @@ const ModelInputFieldComponent = (
disabled={data.length === 0}
onChange={handleChangeModel}
/>
- {isSyncModelEnabled && (
-
-
-
- )}
+ {isSyncModelEnabled && }
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/StringInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/StringInputField.tsx
new file mode 100644
index 0000000000..8cc0cf774f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/StringInputField.tsx
@@ -0,0 +1,46 @@
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIInput from 'common/components/IAIInput';
+import IAITextarea from 'common/components/IAITextarea';
+import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice';
+import {
+ StringInputFieldTemplate,
+ StringInputFieldValue,
+} from 'features/nodes/types/types';
+import { ChangeEvent, memo, useCallback } from 'react';
+import { FieldComponentProps } from './types';
+
+const StringInputFieldComponent = (
+ props: FieldComponentProps
+) => {
+ const { nodeData, field, fieldTemplate } = props;
+ const nodeId = nodeData.id;
+ const dispatch = useAppDispatch();
+
+ const handleValueChanged = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(
+ fieldStringValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: e.target.value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ if (fieldTemplate.ui_component === 'textarea') {
+ return (
+
+ );
+ }
+
+ return ;
+};
+
+export default memo(StringInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/UnetInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/UnetInputField.tsx
similarity index 100%
rename from invokeai/frontend/web/src/features/nodes/components/fields/UnetInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/UnetInputField.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/VaeInputField.tsx
similarity index 100%
rename from invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/VaeInputField.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/VaeModelInputField.tsx
similarity index 93%
rename from invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/VaeModelInputField.tsx
index 3fe04b6e29..a8f6a24de4 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/VaeModelInputField.tsx
@@ -2,7 +2,7 @@ import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
-import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
@@ -20,7 +20,8 @@ const VaeModelInputFieldComponent = (
VaeModelInputFieldTemplate
>
) => {
- const { nodeId, field } = props;
+ const { nodeData, field } = props;
+ const nodeId = nodeData.id;
const vae = field.value;
const dispatch = useAppDispatch();
const { data: vaeModels } = useGetVaeModelsQuery();
@@ -73,7 +74,7 @@ const VaeModelInputFieldComponent = (
}
dispatch(
- fieldValueChanged({
+ fieldVaeModelValueChanged({
nodeId,
fieldName: field.name,
value: newVaeModel,
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/types.ts b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts
similarity index 60%
rename from invokeai/frontend/web/src/features/nodes/components/fields/types.ts
rename to invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts
index cd8c076c89..b1d14c9018 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts
@@ -1,13 +1,16 @@
import {
InputFieldTemplate,
InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
} from 'features/nodes/types/types';
export type FieldComponentProps<
V extends InputFieldValue,
T extends InputFieldTemplate
> = {
- nodeId: string;
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
field: V;
- template: T;
+ fieldTemplate: T;
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx
new file mode 100644
index 0000000000..04e51159c6
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx
@@ -0,0 +1,93 @@
+import { Flex, Image, Text } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import IAIDndImage from 'common/components/IAIDndImage';
+import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
+import { PropsWithChildren, memo } from 'react';
+import { useSelector } from 'react-redux';
+import { NodeProps } from 'reactflow';
+import NodeWrapper from '../Invocation/NodeWrapper';
+
+const selector = createSelector(stateSelector, ({ system, gallery }) => {
+ const imageDTO = gallery.selection[gallery.selection.length - 1];
+
+ return {
+ imageDTO,
+ progressImage: system.progressImage,
+ };
+});
+
+const CurrentImageNode = (props: NodeProps) => {
+ const { progressImage, imageDTO } = useSelector(selector);
+
+ if (progressImage) {
+ return (
+
+
+
+ );
+ }
+
+ if (imageDTO) {
+ return (
+
+
+
+ );
+ }
+
+ return (
+
+
+
+ );
+};
+
+export default memo(CurrentImageNode);
+
+const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
+
+
+
+
+ Current Image
+
+
+
+ {props.children}
+
+
+
+);
diff --git a/invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNode.tsx
new file mode 100644
index 0000000000..f5274bf966
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNode.tsx
@@ -0,0 +1,127 @@
+import { Flex, Icon } from '@chakra-ui/react';
+import { useAppSelector } from 'app/store/storeHooks';
+import { makeTemplateSelector } from 'features/nodes/store/util/makeTemplateSelector';
+import { InvocationNodeData } from 'features/nodes/types/types';
+import { map } from 'lodash-es';
+import { memo, useMemo } from 'react';
+import { FaExclamationCircle } from 'react-icons/fa';
+import { NodeProps } from 'reactflow';
+import NodeCollapseButton from '../Invocation/NodeCollapseButton';
+import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
+import NodeFooter from '../Invocation/NodeFooter';
+import NodeNotesEdit from '../Invocation/NodeNotesEdit';
+import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
+import NodeTitle from '../Invocation/NodeTitle';
+import NodeWrapper from '../Invocation/NodeWrapper';
+import InputField from '../fields/InputField';
+import OutputField from '../fields/OutputField';
+
+const InvocationNode = (props: NodeProps) => {
+ const { id: nodeId, data } = props;
+ const { type, inputs, outputs, isOpen } = data;
+
+ const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
+ const nodeTemplate = useAppSelector(templateSelector);
+ const inputFields = useMemo(
+ () => map(inputs).filter((i) => i.name !== 'is_intermediate'),
+ [inputs]
+ );
+ const outputFields = useMemo(() => map(outputs), [outputs]);
+
+ if (!nodeTemplate) {
+ return (
+
+
+
+
+
+ );
+ }
+
+ return (
+
+
+
+
+
+
+
+
+ {!isOpen && (
+
+ )}
+
+ {isOpen && (
+ <>
+
+
+ {outputFields.map((field) => (
+
+ ))}
+ {inputFields.map((field) => (
+
+ ))}
+
+
+
+ >
+ )}
+
+ );
+};
+
+export default memo(InvocationNode);
diff --git a/invokeai/frontend/web/src/features/nodes/components/nodes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/nodes/NotesNode.tsx
new file mode 100644
index 0000000000..c3b035c6f3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/nodes/NotesNode.tsx
@@ -0,0 +1,73 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAITextarea from 'common/components/IAITextarea';
+import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice';
+import { NotesNodeData } from 'features/nodes/types/types';
+import { ChangeEvent, memo, useCallback } from 'react';
+import { NodeProps } from 'reactflow';
+import NodeCollapseButton from '../Invocation/NodeCollapseButton';
+import NodeTitle from '../Invocation/NodeTitle';
+import NodeWrapper from '../Invocation/NodeWrapper';
+
+const NotesNode = (props: NodeProps) => {
+ const { id: nodeId, data } = props;
+ const { notes, isOpen } = data;
+ const dispatch = useAppDispatch();
+ const handleChange = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(notesNodeValueChanged({ nodeId, value: e.target.value }));
+ },
+ [dispatch, nodeId]
+ );
+
+ return (
+
+
+
+
+
+
+ {isOpen && (
+ <>
+
+
+
+
+
+ >
+ )}
+
+ );
+};
+
+export default memo(NotesNode);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/InspectorPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/InspectorPanel.tsx
new file mode 100644
index 0000000000..654b076eb8
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/InspectorPanel.tsx
@@ -0,0 +1,101 @@
+import {
+ Flex,
+ Tab,
+ TabList,
+ TabPanel,
+ TabPanels,
+ Tabs,
+} from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
+import { memo } from 'react';
+
+const selector = createSelector(
+ stateSelector,
+ ({ nodes }) => {
+ const lastSelectedNodeId =
+ nodes.selectedNodes[nodes.selectedNodes.length - 1];
+
+ const lastSelectedNode = nodes.nodes.find(
+ (node) => node.id === lastSelectedNodeId
+ );
+
+ const lastSelectedNodeTemplate = lastSelectedNode
+ ? nodes.nodeTemplates[lastSelectedNode.data.type]
+ : undefined;
+
+ return {
+ node: lastSelectedNode,
+ template: lastSelectedNodeTemplate,
+ };
+ },
+ defaultSelectorOptions
+);
+
+const InspectorPanel = () => {
+ const { node, template } = useAppSelector(selector);
+
+ return (
+
+
+
+ Node Template
+ Node Data
+
+
+
+
+ {template ? (
+
+
+
+ ) : (
+
+ )}
+
+
+ {node ? (
+
+ ) : (
+
+ )}
+
+
+
+
+ );
+};
+
+export default memo(InspectorPanel);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/NodeDataInspector.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/NodeDataInspector.tsx
new file mode 100644
index 0000000000..74b1620839
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/NodeDataInspector.tsx
@@ -0,0 +1,36 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
+import { memo } from 'react';
+
+const selector = createSelector(
+ stateSelector,
+ ({ nodes }) => {
+ const lastSelectedNodeId =
+ nodes.selectedNodes[nodes.selectedNodes.length - 1];
+
+ const lastSelectedNode = nodes.nodes.find(
+ (node) => node.id === lastSelectedNodeId
+ );
+
+ return {
+ node: lastSelectedNode,
+ };
+ },
+ defaultSelectorOptions
+);
+
+const NodeDataInspector = () => {
+ const { node } = useAppSelector(selector);
+
+ return node ? (
+
+ ) : (
+
+ );
+};
+
+export default memo(NodeDataInspector);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/NodeEditorPanelGroup.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/NodeEditorPanelGroup.tsx
new file mode 100644
index 0000000000..269108e87a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/NodeEditorPanelGroup.tsx
@@ -0,0 +1,49 @@
+import InspectorPanel from 'features/nodes/components/panel/InspectorPanel';
+import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
+import { memo, useState } from 'react';
+import { Panel, PanelGroup } from 'react-resizable-panels';
+import 'reactflow/dist/style.css';
+import WorkflowPanel from './WorkflowPanel';
+
+const NodeEditorPanelGroup = () => {
+ const [isTopPanelCollapsed, setIsTopPanelCollapsed] = useState(false);
+ const [isBottomPanelCollapsed, setIsBottomPanelCollapsed] = useState(false);
+
+ return (
+
+
+
+
+
+
+
+
+
+ );
+};
+
+export default memo(NodeEditorPanelGroup);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/ScrollableContent.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/ScrollableContent.tsx
new file mode 100644
index 0000000000..3b8cf5d520
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/ScrollableContent.tsx
@@ -0,0 +1,45 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
+import { PropsWithChildren } from 'react';
+
+const ScrollableContent = (props: PropsWithChildren) => {
+ return (
+
+
+
+ {props.children}
+
+
+
+ );
+};
+
+export default ScrollableContent;
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/WorkflowPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/WorkflowPanel.tsx
new file mode 100644
index 0000000000..052cf15ad7
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/WorkflowPanel.tsx
@@ -0,0 +1,52 @@
+import {
+ Flex,
+ Tab,
+ TabList,
+ TabPanel,
+ TabPanels,
+ Tabs,
+} from '@chakra-ui/react';
+import { memo } from 'react';
+import GeneralTab from './workflow/GeneralTab';
+import LinearTab from './workflow/LinearTab';
+import WorkflowTab from './workflow/WorkflowTab';
+
+const WorkflowPanel = () => {
+ return (
+
+
+
+ Linear
+ Details
+ JSON
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ );
+};
+
+export default memo(WorkflowPanel);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/workflow/GeneralTab.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/GeneralTab.tsx
new file mode 100644
index 0000000000..8dab91c0d5
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/GeneralTab.tsx
@@ -0,0 +1,142 @@
+import { Flex, FormControl, FormLabel } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIInput from 'common/components/IAIInput';
+import IAITextarea from 'common/components/IAITextarea';
+import {
+ workflowAuthorChanged,
+ workflowContactChanged,
+ workflowDescriptionChanged,
+ workflowNameChanged,
+ workflowNotesChanged,
+ workflowTagsChanged,
+ workflowVersionChanged,
+} from 'features/nodes/store/nodesSlice';
+import { ChangeEvent, memo, useCallback } from 'react';
+import ScrollableContent from '../ScrollableContent';
+
+const selector = createSelector(
+ stateSelector,
+ ({ nodes }) => {
+ const { author, name, description, tags, version, contact, notes } =
+ nodes.workflow;
+
+ return {
+ name,
+ author,
+ description,
+ tags,
+ version,
+ contact,
+ notes,
+ };
+ },
+ defaultSelectorOptions
+);
+
+const WorkflowPanel = () => {
+ const { author, name, description, tags, version, contact, notes } =
+ useAppSelector(selector);
+ const dispatch = useAppDispatch();
+
+ const handleChangeName = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowNameChanged(e.target.value));
+ },
+ [dispatch]
+ );
+ const handleChangeAuthor = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowAuthorChanged(e.target.value));
+ },
+ [dispatch]
+ );
+ const handleChangeContact = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowContactChanged(e.target.value));
+ },
+ [dispatch]
+ );
+ const handleChangeVersion = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowVersionChanged(e.target.value));
+ },
+ [dispatch]
+ );
+ const handleChangeDescription = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowDescriptionChanged(e.target.value));
+ },
+ [dispatch]
+ );
+ const handleChangeTags = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowTagsChanged(e.target.value));
+ },
+ [dispatch]
+ );
+
+ const handleChangeNotes = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowNotesChanged(e.target.value));
+ },
+ [dispatch]
+ );
+
+ return (
+
+
+
+
+
+
+
+
+
+
+
+
+ Short Description
+
+
+
+ Notes
+
+
+
+
+ );
+};
+
+export default memo(WorkflowPanel);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/workflow/LinearTab.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/LinearTab.tsx
new file mode 100644
index 0000000000..833fcc6839
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/LinearTab.tsx
@@ -0,0 +1,114 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIDroppable from 'common/components/IAIDroppable';
+import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import { AddFieldToLinearViewDropData } from 'features/dnd/types';
+import {
+ InputFieldTemplate,
+ InputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+ isInvocationNode,
+} from 'features/nodes/types/types';
+import { forEach } from 'lodash-es';
+import { memo } from 'react';
+import LinearViewField from '../../fields/LinearViewField';
+import ScrollableContent from '../ScrollableContent';
+
+const selector = createSelector(
+ stateSelector,
+ ({ nodes }) => {
+ const fields: {
+ nodeData: InvocationNodeData;
+ nodeTemplate: InvocationTemplate;
+ field: InputFieldValue;
+ fieldTemplate: InputFieldTemplate;
+ }[] = [];
+ const { exposedFields } = nodes.workflow;
+ nodes.nodes.filter(isInvocationNode).forEach((node) => {
+ const nodeTemplate = nodes.nodeTemplates[node.data.type];
+ if (!nodeTemplate) {
+ return;
+ }
+ forEach(node.data.inputs, (field) => {
+ if (
+ !exposedFields.some(
+ (f) => f.nodeId === node.id && f.fieldName === field.name
+ )
+ ) {
+ return;
+ }
+ const fieldTemplate = nodeTemplate.inputs[field.name];
+ if (!fieldTemplate) {
+ return;
+ }
+ fields.push({
+ nodeData: node.data,
+ nodeTemplate,
+ field,
+ fieldTemplate,
+ });
+ });
+ });
+
+ return {
+ fields,
+ };
+ },
+ defaultSelectorOptions
+);
+
+const droppableData: AddFieldToLinearViewDropData = {
+ id: 'add-field-to-linear-view',
+ actionType: 'ADD_FIELD_TO_LINEAR',
+};
+
+const LinearTabContent = () => {
+ const { fields } = useAppSelector(selector);
+
+ return (
+
+
+
+ {fields.length ? (
+ fields.map(({ nodeData, nodeTemplate, field, fieldTemplate }) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+
+
+ );
+};
+
+export default memo(LinearTabContent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/workflow/NotesTab.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/NotesTab.tsx
new file mode 100644
index 0000000000..d8b19c1645
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/NotesTab.tsx
@@ -0,0 +1,51 @@
+import { Box, Text } from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAITextarea from 'common/components/IAITextarea';
+import { workflowNotesChanged } from 'features/nodes/store/nodesSlice';
+import { ChangeEvent, memo, useCallback } from 'react';
+
+const selector = createSelector(stateSelector, ({ nodes }) => {
+ const { notes } = nodes.workflow;
+
+ return {
+ notes,
+ };
+});
+
+const WorkflowPanel = () => {
+ const { notes } = useAppSelector(selector);
+ const dispatch = useAppDispatch();
+
+ const handleChangeNotes = useCallback(
+ (e: ChangeEvent) => {
+ dispatch(workflowNotesChanged(e.target.value));
+ },
+ [dispatch]
+ );
+
+ return (
+
+
+
+
+ {notes.length}
+
+
+
+ );
+};
+
+export default memo(WorkflowPanel);
diff --git a/invokeai/frontend/web/src/features/nodes/components/panel/workflow/WorkflowTab.tsx b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/WorkflowTab.tsx
new file mode 100644
index 0000000000..c9400ab5f6
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/panel/workflow/WorkflowTab.tsx
@@ -0,0 +1,43 @@
+import { Flex } from '@chakra-ui/react';
+import { RootState } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
+import { buildWorkflow } from 'features/nodes/util/buildWorkflow';
+import { memo, useMemo } from 'react';
+import { useDebounce } from 'use-debounce';
+
+const useWatchWorkflow = () => {
+ const nodes = useAppSelector((state: RootState) => state.nodes);
+ const [debouncedNodes] = useDebounce(nodes, 300);
+ const workflow = useMemo(
+ () => buildWorkflow(debouncedNodes),
+ [debouncedNodes]
+ );
+
+ return {
+ workflow,
+ };
+};
+
+const WorkflowWorkflowTab = () => {
+ const { workflow } = useWatchWorkflow();
+
+ return (
+
+
+
+ );
+};
+
+export default memo(WorkflowWorkflowTab);
diff --git a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx
index d4a4f8d31f..1b9dc38cb6 100644
--- a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx
@@ -1,10 +1,9 @@
import { Box, Flex } from '@chakra-ui/layout';
import { Tooltip } from '@chakra-ui/tooltip';
import { useAppToaster } from 'app/components/Toaster';
-import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIInput from 'common/components/IAIInput';
-import { useBuildInvocation } from 'features/nodes/hooks/useBuildInvocation';
+import { useBuildNodeData } from 'features/nodes/hooks/useBuildNodeData';
import { InvocationTemplate } from 'features/nodes/types/types';
import Fuse from 'fuse.js';
import { map } from 'lodash-es';
@@ -51,16 +50,15 @@ const NodeListItem = (props: NodeListItemProps) => {
NodeListItem.displayName = 'NodeListItem';
const NodeSearch = () => {
- const invocationTemplates = useAppSelector(
- (state: RootState) => state.nodes.invocationTemplates
+ const nodeTemplates = useAppSelector((state) =>
+ map(state.nodes.nodeTemplates)
);
- const nodes = map(invocationTemplates);
const [filteredNodes, setFilteredNodes] = useState<
Fuse.FuseResult[]
>([]);
- const buildInvocation = useBuildInvocation();
+ const buildInvocation = useBuildNodeData();
const dispatch = useAppDispatch();
const toaster = useAppToaster();
@@ -76,7 +74,7 @@ const NodeSearch = () => {
keys: ['title', 'type', 'tags'],
};
- const fuse = new Fuse(nodes, fuseOptions);
+ const fuse = new Fuse(nodeTemplates, fuseOptions);
const findNode = (e: ChangeEvent) => {
setSearchText(e.target.value);
@@ -121,7 +119,7 @@ const NodeSearch = () => {
}
});
} else {
- nodes.forEach(({ title, description, type }, index) => {
+ nodeTemplates.forEach(({ title, description, type }, index) => {
nodeListToRender.push(
{
if (searchText.length > 0) {
nextIndex = (focusedIndex + 1) % filteredNodes.length;
} else {
- nextIndex = (focusedIndex + 1) % nodes.length;
+ nextIndex = (focusedIndex + 1) % nodeTemplates.length;
}
}
@@ -161,7 +159,8 @@ const NodeSearch = () => {
nextIndex =
(focusedIndex + filteredNodes.length - 1) % filteredNodes.length;
} else {
- nextIndex = (focusedIndex + nodes.length - 1) % nodes.length;
+ nextIndex =
+ (focusedIndex + nodeTemplates.length - 1) % nodeTemplates.length;
}
}
@@ -175,7 +174,7 @@ const NodeSearch = () => {
if (searchText.length > 0) {
selectedNodeType = filteredNodes[focusedIndex]?.item.type;
} else {
- selectedNodeType = nodes[focusedIndex]?.type;
+ selectedNodeType = nodeTemplates[focusedIndex]?.type;
}
if (selectedNodeType) {
diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx
deleted file mode 100644
index 44d93bb8fe..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx
+++ /dev/null
@@ -1,161 +0,0 @@
-import { FileButton } from '@mantine/core';
-import { useAppDispatch } from 'app/store/storeHooks';
-import IAIIconButton from 'common/components/IAIIconButton';
-import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
-import { addToast } from 'features/system/store/systemSlice';
-import { makeToast } from 'features/system/util/makeToast';
-import i18n from 'i18n';
-import { memo, useCallback, useRef } from 'react';
-import { useTranslation } from 'react-i18next';
-import { FaUpload } from 'react-icons/fa';
-import { useReactFlow } from 'reactflow';
-
-interface JsonFile {
- [key: string]: unknown;
-}
-
-function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
- isValid: boolean;
- message: string;
-} {
- // Check if primary keys exist
- const keys = ['nodes', 'edges', 'viewport'];
- for (const key of keys) {
- if (!(key in jsonFile)) {
- return {
- isValid: false,
- message: i18n.t('toast.nodesNotValidGraph'),
- };
- }
- }
-
- // Check if nodes and edges are arrays
- if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
- return {
- isValid: false,
- message: i18n.t('toast.nodesNotValidGraph'),
- };
- }
-
- // Check if data is present in nodes
- const nodeKeys = ['data', 'type'];
- const nodeTypes = ['invocation', 'progress_image'];
- if (jsonFile.nodes.length > 0) {
- for (const node of jsonFile.nodes) {
- for (const nodeKey of nodeKeys) {
- if (!(nodeKey in node)) {
- return {
- isValid: false,
- message: i18n.t('toast.nodesNotValidGraph'),
- };
- }
- if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
- return {
- isValid: false,
- message: i18n.t('toast.nodesUnrecognizedTypes'),
- };
- }
- }
- }
- }
-
- // Check Edge Object
- const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
- if (jsonFile.edges.length > 0) {
- for (const edge of jsonFile.edges) {
- for (const edgeKey of edgeKeys) {
- if (!(edgeKey in edge)) {
- return {
- isValid: false,
- message: i18n.t('toast.nodesBrokenConnections'),
- };
- }
- }
- }
- }
-
- return {
- isValid: true,
- message: i18n.t('toast.nodesLoaded'),
- };
-}
-
-const LoadGraphButton = () => {
- const { t } = useTranslation();
- const dispatch = useAppDispatch();
- const { fitView } = useReactFlow();
-
- const uploadedFileRef = useRef<() => void>(null);
-
- const restoreJSONToEditor = useCallback(
- (v: File | null) => {
- if (!v) return;
- const reader = new FileReader();
- reader.onload = async () => {
- const json = reader.result;
-
- try {
- const retrievedNodeTree = await JSON.parse(String(json));
- const { isValid, message } =
- sanityCheckInvokeAIGraph(retrievedNodeTree);
-
- if (isValid) {
- dispatch(loadFileNodes(retrievedNodeTree.nodes));
- dispatch(loadFileEdges(retrievedNodeTree.edges));
- fitView();
-
- dispatch(
- addToast(makeToast({ title: message, status: 'success' }))
- );
- } else {
- dispatch(
- addToast(
- makeToast({
- title: message,
- status: 'error',
- })
- )
- );
- }
- // Cleanup
- reader.abort();
- } catch (error) {
- if (error) {
- dispatch(
- addToast(
- makeToast({
- title: t('toast.nodesNotValidJSON'),
- status: 'error',
- })
- )
- );
- }
- }
- };
-
- reader.readAsText(v);
-
- // Cleanup
- uploadedFileRef.current?.();
- },
- [fitView, dispatch, t]
- );
- return (
-
- {(props) => (
- }
- tooltip={t('nodes.loadGraph')}
- aria-label={t('nodes.loadGraph')}
- {...props}
- />
- )}
-
- );
-};
-
-export default memo(LoadGraphButton);
diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx
index 740fecc2a4..b735bce0f7 100644
--- a/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx
@@ -5,7 +5,7 @@ import IAIButton, { IAIButtonProps } from 'common/components/IAIButton';
import IAIIconButton, {
IAIIconButtonProps,
} from 'common/components/IAIIconButton';
-import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
+import { selectIsReadyNodes } from 'features/nodes/store/selectors';
import ProgressBar from 'features/system/components/ProgressBar';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
@@ -22,7 +22,7 @@ export default function NodeInvokeButton(props: InvokeButton) {
const { iconButton = false, ...rest } = props;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
- const isReady = useIsReadyToInvoke();
+ const isReady = useAppSelector(selectIsReadyNodes);
const handleInvoke = useCallback(() => {
dispatch(userInvoked('nodes'));
}, [dispatch]);
@@ -71,12 +71,6 @@ export default function NodeInvokeButton(props: InvokeButton) {
tooltipProps={{ placement: 'bottom' }}
colorScheme="accent"
id="invoke-button"
- _disabled={{
- background: 'none',
- _hover: {
- background: 'none',
- },
- }}
{...rest}
/>
) : (
@@ -90,12 +84,6 @@ export default function NodeInvokeButton(props: InvokeButton) {
colorScheme="accent"
id="invoke-button"
fontWeight={700}
- _disabled={{
- background: 'none',
- _hover: {
- background: 'none',
- },
- }}
{...rest}
>
Invoke
diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx
deleted file mode 100644
index 42e545258e..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx
+++ /dev/null
@@ -1,48 +0,0 @@
-import { RootState } from 'app/store/store';
-import { useAppSelector } from 'app/store/storeHooks';
-import IAIIconButton from 'common/components/IAIIconButton';
-import { map, omit } from 'lodash-es';
-import { memo, useCallback } from 'react';
-import { useTranslation } from 'react-i18next';
-import { FaSave } from 'react-icons/fa';
-
-const SaveGraphButton = () => {
- const { t } = useTranslation();
- const editorInstance = useAppSelector(
- (state: RootState) => state.nodes.editorInstance
- );
-
- const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
-
- const saveEditorToJSON = useCallback(() => {
- if (editorInstance) {
- const editorState = editorInstance.toObject();
-
- editorState.edges = map(editorState.edges, (edge) => {
- return omit(edge, ['style']);
- });
-
- const nodeSetupJSON = new Blob([JSON.stringify(editorState)]);
- const nodeDownloadElement = document.createElement('a');
- nodeDownloadElement.href = URL.createObjectURL(nodeSetupJSON);
- nodeDownloadElement.download = 'MyNodes.json';
- document.body.appendChild(nodeDownloadElement);
- nodeDownloadElement.click();
- // Cleanup
- nodeDownloadElement.remove();
- }
- }, [editorInstance]);
-
- return (
- }
- fontSize={18}
- tooltip={t('nodes.saveGraph')}
- aria-label={t('nodes.saveGraph')}
- onClick={saveEditorToJSON}
- isDisabled={nodes.length === 0}
- />
- );
-};
-
-export default memo(SaveGraphButton);
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildInvocation.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts
similarity index 69%
rename from invokeai/frontend/web/src/features/nodes/hooks/useBuildInvocation.ts
rename to invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts
index 50a19d6b45..e38c20c05a 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildInvocation.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts
@@ -7,42 +7,67 @@ import { Node, useReactFlow } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import {
+ CurrentImageNodeData,
InputFieldValue,
- InvocationValue,
+ InvocationNodeData,
+ NotesNodeData,
OutputFieldValue,
} from '../types/types';
import { buildInputFieldValue } from '../util/fieldValueBuilders';
+import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
const templatesSelector = createSelector(
[(state: RootState) => state.nodes],
- (nodes) => nodes.invocationTemplates
+ (nodes) => nodes.nodeTemplates
);
-export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
-
export const SHARED_NODE_PROPERTIES: Partial = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
-export const useBuildInvocation = () => {
+export const useBuildNodeData = () => {
const invocationTemplates = useAppSelector(templatesSelector);
const flow = useReactFlow();
return useCallback(
- (type: AnyInvocationType | 'progress_image') => {
- if (type === 'progress_image') {
- const { x, y } = flow.project({
- x: window.innerWidth / 2.5,
- y: window.innerHeight / 8,
- });
+ (type: AnyInvocationType | 'current_image' | 'notes') => {
+ const nodeId = uuidv4();
- const node: Node = {
+ const { x, y } = flow.project({
+ x: window.innerWidth / 2.5,
+ y: window.innerHeight / 8,
+ });
+ if (type === 'current_image') {
+ const node: Node = {
...SHARED_NODE_PROPERTIES,
- id: 'progress_image',
- type: 'progress_image',
+ id: nodeId,
+ type: 'current_image',
position: { x: x, y: y },
- data: {},
+ data: {
+ id: nodeId,
+ type: 'current_image',
+ isOpen: true,
+ label: 'Current Image',
+ },
+ };
+
+ return node;
+ }
+
+ if (type === 'notes') {
+ const node: Node = {
+ ...SHARED_NODE_PROPERTIES,
+ id: nodeId,
+ type: 'notes',
+ position: { x: x, y: y },
+ data: {
+ id: nodeId,
+ isOpen: true,
+ label: 'Notes',
+ notes: '',
+ type: 'notes',
+ },
};
return node;
@@ -55,8 +80,6 @@ export const useBuildInvocation = () => {
return;
}
- const nodeId = uuidv4();
-
const inputs = reduce(
template.inputs,
(inputsAccumulator, inputTemplate, inputName) => {
@@ -83,6 +106,7 @@ export const useBuildInvocation = () => {
id: fieldId,
name: outputName,
type: outputTemplate.type,
+ fieldKind: 'output',
};
outputsAccumulator[outputName] = outputFieldValue;
@@ -92,12 +116,7 @@ export const useBuildInvocation = () => {
{} as Record
);
- const { x, y } = flow.project({
- x: window.innerWidth / 2.5,
- y: window.innerHeight / 8,
- });
-
- const invocation: Node = {
+ const invocation: Node = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'invocation',
@@ -107,6 +126,9 @@ export const useBuildInvocation = () => {
type,
inputs,
outputs,
+ isOpen: true,
+ label: '',
+ notes: '',
},
};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts
new file mode 100644
index 0000000000..625736a933
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts
@@ -0,0 +1,92 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
+import { InputFieldValue, OutputFieldValue } from 'features/nodes/types/types';
+import { useMemo } from 'react';
+
+const selectIsConnectionInProgress = createSelector(
+ stateSelector,
+ ({ nodes }) =>
+ nodes.currentConnectionFieldType !== null &&
+ nodes.connectionStartParams !== null
+);
+
+export type UseConnectionStateProps =
+ | {
+ nodeId: string;
+ field: InputFieldValue;
+ kind: 'input';
+ }
+ | {
+ nodeId: string;
+ field: OutputFieldValue;
+ kind: 'output';
+ };
+
+export const useConnectionState = ({
+ nodeId,
+ field,
+ kind,
+}: UseConnectionStateProps) => {
+ const selectIsConnected = useMemo(
+ () =>
+ createSelector(stateSelector, ({ nodes }) =>
+ Boolean(
+ nodes.edges.filter((edge) => {
+ return (
+ (kind === 'input' ? edge.target : edge.source) === nodeId &&
+ (kind === 'input' ? edge.targetHandle : edge.sourceHandle) ===
+ field.name
+ );
+ }).length
+ )
+ ),
+ [field.name, kind, nodeId]
+ );
+
+ const selectConnectionError = useMemo(
+ () =>
+ makeConnectionErrorSelector(
+ nodeId,
+ field.name,
+ kind === 'input' ? 'target' : 'source',
+ field.type
+ ),
+ [nodeId, field.name, field.type, kind]
+ );
+
+ const selectIsConnectionStartField = useMemo(
+ () =>
+ createSelector(stateSelector, ({ nodes }) =>
+ Boolean(
+ nodes.connectionStartParams?.nodeId === nodeId &&
+ nodes.connectionStartParams?.handleId === field.name &&
+ nodes.connectionStartParams?.handleType ===
+ { input: 'target', output: 'source' }[kind]
+ )
+ ),
+ [field.name, kind, nodeId]
+ );
+
+ const isConnected = useAppSelector(selectIsConnected);
+ const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
+ const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
+ const connectionError = useAppSelector(selectConnectionError);
+
+ const shouldDim = useMemo(
+ () =>
+ Boolean(
+ isConnectionInProgress && connectionError && !isConnectionStartField
+ ),
+ [connectionError, isConnectionInProgress, isConnectionStartField]
+ );
+
+ return {
+ isConnected,
+ isConnectionInProgress,
+ isConnectionStartField,
+ connectionError,
+ shouldDim,
+ };
+};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
index e5bfa0a627..3a63d75bb0 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
@@ -1,94 +1,102 @@
// TODO: enable this at some point
-// import graphlib from '@dagrejs/graphlib';
-// import { useCallback } from 'react';
-// import { Connection, Node, useReactFlow } from 'reactflow';
-// import { InvocationValue } from '../types/types';
+import graphlib from '@dagrejs/graphlib';
+import { useAppSelector } from 'app/store/storeHooks';
+import { useCallback } from 'react';
+import { Connection, Edge, Node, useReactFlow } from 'reactflow';
+import { COLLECTION_TYPES } from '../types/constants';
+import { InvocationNodeData } from '../types/types';
-// export const useIsValidConnection = () => {
-// const flow = useReactFlow();
+export const useIsValidConnection = () => {
+ const flow = useReactFlow();
+ const shouldValidateGraph = useAppSelector(
+ (state) => state.nodes.shouldValidateGraph
+ );
+ const isValidConnection = useCallback(
+ ({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
+ if (!shouldValidateGraph) {
+ // manual override!
+ return true;
+ }
-// // Check if an in-progress connection is valid
-// const isValidConnection = useCallback(
-// ({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
-// const edges = flow.getEdges();
-// const nodes = flow.getNodes();
+ const edges = flow.getEdges();
+ const nodes = flow.getNodes();
+ // Connection must have valid targets
+ if (!(source && sourceHandle && target && targetHandle)) {
+ return false;
+ }
-// // Connection must have valid targets
-// if (!(source && sourceHandle && target && targetHandle)) {
-// return false;
-// }
+ // Find the source and target nodes
+ const sourceNode = flow.getNode(source) as Node;
+ const targetNode = flow.getNode(target) as Node;
-// // Connection is invalid if target already has a connection
-// if (
-// edges.find((edge) => {
-// return edge.target === target && edge.targetHandle === targetHandle;
-// })
-// ) {
-// return false;
-// }
+ // Conditional guards against undefined nodes/handles
+ if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
+ return false;
+ }
-// // Find the source and target nodes
-// const sourceNode = flow.getNode(source) as Node;
+ const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
+ const targetType = targetNode.data.inputs[targetHandle]?.type;
-// const targetNode = flow.getNode(target) as Node;
+ if (!sourceType || !targetType) {
+ // something has gone terribly awry
+ return false;
+ }
-// // Conditional guards against undefined nodes/handles
-// if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
-// return false;
-// }
+ // Connection is invalid if target already has a connection
+ if (
+ edges.find((edge) => {
+ return edge.target === target && edge.targetHandle === targetHandle;
+ }) &&
+ // except CollectionItem inputs can have multiples
+ targetType !== 'CollectionItem'
+ ) {
+ return false;
+ }
-// // Connection types must be the same for a connection
-// if (
-// sourceNode.data.outputs[sourceHandle].type !==
-// targetNode.data.inputs[targetHandle].type
-// ) {
-// return false;
-// }
+ // Connection types must be the same for a connection
+ if (
+ sourceType !== targetType &&
+ sourceType !== 'CollectionItem' &&
+ targetType !== 'CollectionItem'
+ ) {
+ if (
+ !(
+ COLLECTION_TYPES.includes(targetType) &&
+ COLLECTION_TYPES.includes(sourceType)
+ )
+ ) {
+ return false;
+ }
+ }
+ // Graphs much be acyclic (no loops!)
+ return getIsGraphAcyclic(source, target, nodes, edges);
+ },
+ [flow, shouldValidateGraph]
+ );
-// // Graphs much be acyclic (no loops!)
+ return isValidConnection;
+};
-// /**
-// * TODO: use `graphlib.alg.findCycles()` to identify strong connections
-// *
-// * this validation func only runs when the cursor hits the second handle of the connection,
-// * and only on that second handle - so it cannot tell us exhaustively which connections
-// * are valid.
-// *
-// * ideally, we check when the connection starts to calculate all invalid handles at once.
-// *
-// * requires making a new graphlib graph - and calling `findCycles()` - for each potential
-// * handle. instead of using the `isValidConnection` prop, it would use the `onConnectStart`
-// * prop.
-// *
-// * the strong connections should be stored in global state.
-// *
-// * then, `isValidConnection` would simple loop through the strong connections and if the
-// * source and target are in a single strong connection, return false.
-// *
-// * and also, we can use this knowledge to style every handle when a connection starts,
-// * which is otherwise not possible.
-// */
+export const getIsGraphAcyclic = (
+ source: string,
+ target: string,
+ nodes: Node[],
+ edges: Edge[]
+) => {
+ // construct graphlib graph from editor state
+ const g = new graphlib.Graph();
-// // build a graphlib graph
-// const g = new graphlib.Graph();
+ nodes.forEach((n) => {
+ g.setNode(n.id);
+ });
-// nodes.forEach((n) => {
-// g.setNode(n.id);
-// });
+ edges.forEach((e) => {
+ g.setEdge(e.source, e.target);
+ });
-// edges.forEach((e) => {
-// g.setEdge(e.source, e.target);
-// });
+ // add the candidate edge
+ g.setEdge(source, target);
-// // Add the candidate edge to the graph
-// g.setEdge(source, target);
-
-// return graphlib.alg.isAcyclic(g);
-// },
-// [flow]
-// );
-
-// return isValidConnection;
-// };
-
-export const useIsValidConnection = () => () => true;
+ // check if the graph is acyclic
+ return graphlib.alg.isAcyclic(g);
+};
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts
index 5cbc3c381d..60344abf37 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts
@@ -5,5 +5,9 @@ import { NodesState } from './types';
*/
export const nodesPersistDenylist: (keyof NodesState)[] = [
'schema',
- 'invocationTemplates',
+ 'nodeTemplates',
+ 'connectionStartParams',
+ 'currentConnectionFieldType',
+ 'selectedNodes',
+ 'selectedEdges',
];
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 2e41081e95..8878d24370 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -1,12 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
-import {
- ControlNetModelParam,
- LoRAModelParam,
- MainModelParam,
- VaeModelParam,
-} from 'features/parameters/types/parameterSchemas';
-import { cloneDeep, uniqBy } from 'lodash-es';
-import { RgbaColor } from 'react-colorful';
+import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
import {
addEdge,
applyEdgeChanges,
@@ -14,26 +7,103 @@ import {
Connection,
Edge,
EdgeChange,
+ EdgeRemoveChange,
+ getConnectedEdges,
+ getIncomers,
+ getOutgoers,
Node,
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
+import { sessionInvoked } from 'services/api/thunks/session';
import { ImageField } from 'services/api/types';
-import { InvocationTemplate, InvocationValue } from '../types/types';
+import {
+ appSocketGeneratorProgress,
+ appSocketInvocationComplete,
+ appSocketInvocationError,
+ appSocketInvocationStarted,
+} from 'services/events/actions';
+import { DRAG_HANDLE_CLASSNAME } from '../types/constants';
+import {
+ BooleanInputFieldValue,
+ ColorInputFieldValue,
+ ControlNetModelInputFieldValue,
+ CurrentImageNodeData,
+ EnumInputFieldValue,
+ FieldIdentifier,
+ FloatInputFieldValue,
+ ImageInputFieldValue,
+ InputFieldValue,
+ IntegerInputFieldValue,
+ InvocationNodeData,
+ InvocationTemplate,
+ isInvocationNode,
+ isNotesNode,
+ LoRAModelInputFieldValue,
+ MainModelInputFieldValue,
+ NodeStatus,
+ NotesNodeData,
+ SDXLRefinerModelInputFieldValue,
+ StringInputFieldValue,
+ VaeModelInputFieldValue,
+ Workflow,
+} from '../types/types';
import { NodesState } from './types';
export const initialNodesState: NodesState = {
nodes: [],
edges: [],
schema: null,
- invocationTemplates: {},
+ nodeTemplates: {},
connectionStartParams: null,
- shouldShowGraphOverlay: false,
+ currentConnectionFieldType: null,
shouldShowFieldTypeLegend: false,
shouldShowMinimapPanel: true,
- editorInstance: undefined,
- progressNodeSize: { width: 512, height: 512 },
+ shouldValidateGraph: true,
+ shouldAnimateEdges: true,
+ shouldSnapToGrid: true,
+ shouldColorEdges: true,
+ nodeOpacity: 1,
+ selectedNodes: [],
+ selectedEdges: [],
+ workflow: {
+ name: '',
+ author: '',
+ description: '',
+ notes: '',
+ tags: '',
+ contact: '',
+ version: '',
+ exposedFields: [],
+ },
+ nodeExecutionStates: {},
+ zoom: 1,
+};
+
+type FieldValueAction = PayloadAction<{
+ nodeId: string;
+ fieldName: string;
+ value: T['value'];
+}>;
+
+const fieldValueReducer = (
+ state: NodesState,
+ action: FieldValueAction
+) => {
+ const { nodeId, fieldName, value } = action.payload;
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ const input = node.data?.inputs[fieldName];
+ if (!input) {
+ return;
+ }
+ if (nodeIndex > -1) {
+ input.value = value;
+ }
};
const nodesSlice = createSlice({
@@ -43,49 +113,336 @@ const nodesSlice = createSlice({
nodesChanged: (state, action: PayloadAction) => {
state.nodes = applyNodeChanges(action.payload, state.nodes);
},
- nodeAdded: (state, action: PayloadAction>) => {
- state.nodes.push(action.payload);
+ nodeAdded: (
+ state,
+ action: PayloadAction<
+ Node
+ >
+ ) => {
+ const node = action.payload;
+ state.nodes.push(node);
+
+ if (!isInvocationNode(node)) {
+ return;
+ }
+
+ state.nodeExecutionStates[node.id] = {
+ status: NodeStatus.PENDING,
+ error: null,
+ progress: null,
+ progressImage: null,
+ };
},
edgesChanged: (state, action: PayloadAction) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction) => {
state.connectionStartParams = action.payload;
+ const { nodeId, handleId, handleType } = action.payload;
+ if (!nodeId || !handleId) {
+ return;
+ }
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ const field =
+ handleType === 'source'
+ ? node.data.outputs[handleId]
+ : node.data.inputs[handleId];
+ state.currentConnectionFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction) => {
- state.edges = addEdge(action.payload, state.edges);
+ const fieldType = state.currentConnectionFieldType;
+ if (!fieldType) {
+ return;
+ }
+ state.edges = addEdge(
+ { ...action.payload, type: 'default' },
+ state.edges
+ );
},
connectionEnded: (state) => {
state.connectionStartParams = null;
+ state.currentConnectionFieldType = null;
},
- fieldValueChanged: (
+ workflowExposedFieldAdded: (
+ state,
+ action: PayloadAction
+ ) => {
+ state.workflow.exposedFields = uniqBy(
+ state.workflow.exposedFields.concat(action.payload),
+ (field) => `${field.nodeId}-${field.fieldName}`
+ );
+ },
+ workflowExposedFieldRemoved: (
+ state,
+ action: PayloadAction
+ ) => {
+ state.workflow.exposedFields = state.workflow.exposedFields.filter(
+ (field) => !isEqual(field, action.payload)
+ );
+ },
+ fieldLabelChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
- value:
- | string
- | number
- | boolean
- | ImageField
- | RgbaColor
- | undefined
- | ImageField[]
- | MainModelParam
- | VaeModelParam
- | LoRAModelParam
- | ControlNetModelParam;
+ label: string;
}>
) => {
- const { nodeId, fieldName, value } = action.payload;
- const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
- const input = state.nodes?.[nodeIndex]?.data?.inputs[fieldName];
- if (!input) {
+ const { nodeId, fieldName, label } = action.payload;
+ const node = state.nodes.find((n) => n.id === nodeId);
+ if (!isInvocationNode(node)) {
return;
}
- if (nodeIndex > -1) {
- input.value = value;
+ const field = node.data.inputs[fieldName];
+ if (!field) {
+ return;
}
+ field.label = label;
+ },
+ nodeIsOpenChanged: (
+ state,
+ action: PayloadAction<{ nodeId: string; isOpen: boolean }>
+ ) => {
+ const { nodeId, isOpen } = action.payload;
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node) && !isNotesNode(node)) {
+ return;
+ }
+
+ node.data.isOpen = isOpen;
+
+ if (!isInvocationNode(node)) {
+ return;
+ }
+
+ // edges between two closed nodes should not be visible:
+ // - if the node was just opened, we need to make all its edges visible
+ // - if the edge was just closed, we need to check all its edges and hide them if both nodes are closed
+
+ const connectedEdges = getConnectedEdges([node], state.edges);
+
+ if (isOpen) {
+ // reset hidden status of all edges
+ connectedEdges.forEach((edge) => {
+ delete edge.hidden;
+ });
+ // delete dummy edges
+ connectedEdges.forEach((edge) => {
+ if (edge.type === 'collapsed') {
+ state.edges = state.edges.filter((e) => e.id !== edge.id);
+ }
+ });
+ } else {
+ const closedIncomers = getIncomers(
+ node,
+ state.nodes,
+ state.edges
+ ).filter(
+ (node) => isInvocationNode(node) && node.data.isOpen === false
+ );
+
+ const closedOutgoers = getOutgoers(
+ node,
+ state.nodes,
+ state.edges
+ ).filter(
+ (node) => isInvocationNode(node) && node.data.isOpen === false
+ );
+
+ const collapsedEdgesToCreate: Edge<{ count: number }>[] = [];
+
+ // hide all edges
+ connectedEdges.forEach((edge) => {
+ if (
+ edge.target === nodeId &&
+ closedIncomers.find((node) => node.id === edge.source)
+ ) {
+ edge.hidden = true;
+ const collapsedEdge = collapsedEdgesToCreate.find(
+ (e) => e.source === edge.source && e.target === edge.target
+ );
+ if (collapsedEdge) {
+ collapsedEdge.data = {
+ count: (collapsedEdge.data?.count ?? 0) + 1,
+ };
+ } else {
+ collapsedEdgesToCreate.push({
+ id: `${edge.source}-${edge.target}-collapsed`,
+ source: edge.source,
+ target: edge.target,
+ type: 'collapsed',
+ data: { count: 1 },
+ });
+ }
+ }
+ if (
+ edge.source === nodeId &&
+ closedOutgoers.find((node) => node.id === edge.target)
+ ) {
+ const collapsedEdge = collapsedEdgesToCreate.find(
+ (e) => e.source === edge.source && e.target === edge.target
+ );
+ edge.hidden = true;
+ if (collapsedEdge) {
+ collapsedEdge.data = {
+ count: (collapsedEdge.data?.count ?? 0) + 1,
+ };
+ } else {
+ collapsedEdgesToCreate.push({
+ id: `${edge.source}-${edge.target}-collapsed`,
+ source: edge.source,
+ target: edge.target,
+ type: 'collapsed',
+ data: { count: 1 },
+ });
+ }
+ }
+ });
+ if (collapsedEdgesToCreate.length) {
+ state.edges = applyEdgeChanges(
+ collapsedEdgesToCreate.map((edge) => ({ type: 'add', item: edge })),
+ state.edges
+ );
+ }
+ }
+ },
+ edgesDeleted: (state, action: PayloadAction) => {
+ const edges = action.payload;
+ const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
+
+ // if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
+ if (collapsedEdges.length) {
+ const edgeChanges: EdgeRemoveChange[] = [];
+ collapsedEdges.forEach((collapsedEdge) => {
+ state.edges.forEach((edge) => {
+ if (
+ edge.source === collapsedEdge.source &&
+ edge.target === collapsedEdge.target
+ ) {
+ edgeChanges.push({ id: edge.id, type: 'remove' });
+ }
+ });
+ });
+ state.edges = applyEdgeChanges(edgeChanges, state.edges);
+ }
+ },
+ nodesDeleted: (
+ state,
+ action: PayloadAction<
+ Node[]
+ >
+ ) => {
+ action.payload.forEach((node) => {
+ state.workflow.exposedFields = state.workflow.exposedFields.filter(
+ (f) => f.nodeId !== node.id
+ );
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ delete state.nodeExecutionStates[node.id];
+ });
+ },
+ nodeLabelChanged: (
+ state,
+ action: PayloadAction<{ nodeId: string; label: string }>
+ ) => {
+ const { nodeId, label } = action.payload;
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ node.data.label = label;
+ },
+ nodeNotesChanged: (
+ state,
+ action: PayloadAction<{ nodeId: string; notes: string }>
+ ) => {
+ const { nodeId, notes } = action.payload;
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isInvocationNode(node)) {
+ return;
+ }
+ node.data.notes = notes;
+ },
+ selectedNodesChanged: (state, action: PayloadAction) => {
+ state.selectedNodes = action.payload;
+ },
+ selectedEdgesChanged: (state, action: PayloadAction) => {
+ state.selectedEdges = action.payload;
+ },
+ fieldStringValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldNumberValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldBooleanValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldImageValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldColorValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldMainModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldRefinerModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldVaeModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldLoRAModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldControlNetModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
+ },
+ fieldEnumModelValueChanged: (
+ state,
+ action: FieldValueAction
+ ) => {
+ fieldValueReducer(state, action);
},
imageCollectionFieldValueChanged: (
state,
@@ -102,7 +459,13 @@ const nodesSlice = createSlice({
return;
}
- const input = state.nodes?.[nodeIndex]?.data?.inputs[fieldName];
+ const node = state.nodes?.[nodeIndex];
+
+ if (!isInvocationNode(node)) {
+ return;
+ }
+
+ const input = node.data?.inputs[fieldName];
if (!input) {
return;
}
@@ -119,8 +482,30 @@ const nodesSlice = createSlice({
'image_name'
);
},
- shouldShowGraphOverlayChanged: (state, action: PayloadAction) => {
- state.shouldShowGraphOverlay = action.payload;
+ nodeClicked: (
+ state,
+ action: PayloadAction<{ nodeId: string; ctrlOrMeta?: boolean }>
+ ) => {
+ const { nodeId, ctrlOrMeta } = action.payload;
+ state.nodes.forEach((node) => {
+ if (node.id === nodeId) {
+ node.selected = true;
+ } else if (!ctrlOrMeta) {
+ node.selected = false;
+ }
+ });
+ },
+ notesNodeValueChanged: (
+ state,
+ action: PayloadAction<{ nodeId: string; value: string }>
+ ) => {
+ const { nodeId, value } = action.payload;
+ const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
+ const node = state.nodes?.[nodeIndex];
+ if (!isNotesNode(node)) {
+ return;
+ }
+ node.data.notes = value;
},
shouldShowFieldTypeLegendChanged: (
state,
@@ -135,32 +520,126 @@ const nodesSlice = createSlice({
state,
action: PayloadAction>
) => {
- state.invocationTemplates = action.payload;
+ state.nodeTemplates = action.payload;
},
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
},
- setEditorInstance: (state, action) => {
- state.editorInstance = action.payload;
+ shouldValidateGraphChanged: (state, action: PayloadAction) => {
+ state.shouldValidateGraph = action.payload;
},
- loadFileNodes: (state, action: PayloadAction[]>) => {
+ shouldAnimateEdgesChanged: (state, action: PayloadAction) => {
+ state.shouldAnimateEdges = action.payload;
+ },
+ shouldSnapToGridChanged: (state, action: PayloadAction) => {
+ state.shouldSnapToGrid = action.payload;
+ },
+ shouldColorEdgesChanged: (state, action: PayloadAction) => {
+ state.shouldColorEdges = action.payload;
+ },
+ nodeOpacityChanged: (state, action: PayloadAction) => {
+ state.nodeOpacity = action.payload;
+ },
+ loadFileNodes: (
+ state,
+ action: PayloadAction[]>
+ ) => {
state.nodes = action.payload;
},
loadFileEdges: (state, action: PayloadAction) => {
state.edges = action.payload;
},
- setProgressNodeSize: (
- state,
- action: PayloadAction<{ width: number; height: number }>
- ) => {
- state.progressNodeSize = action.payload;
+ workflowNameChanged: (state, action: PayloadAction) => {
+ state.workflow.name = action.payload;
+ },
+ workflowDescriptionChanged: (state, action: PayloadAction) => {
+ state.workflow.description = action.payload;
+ },
+ workflowTagsChanged: (state, action: PayloadAction) => {
+ state.workflow.tags = action.payload;
+ },
+ workflowAuthorChanged: (state, action: PayloadAction) => {
+ state.workflow.author = action.payload;
+ },
+ workflowNotesChanged: (state, action: PayloadAction) => {
+ state.workflow.notes = action.payload;
+ },
+ workflowVersionChanged: (state, action: PayloadAction) => {
+ state.workflow.version = action.payload;
+ },
+ workflowContactChanged: (state, action: PayloadAction) => {
+ state.workflow.contact = action.payload;
+ },
+ workflowLoaded: (state, action: PayloadAction) => {
+ // TODO: validation
+ const { nodes, edges, ...workflow } = action.payload;
+ state.workflow = workflow;
+ state.nodes = applyNodeChanges(
+ nodes.map((node) => ({
+ item: { ...node, dragHandle: `.${DRAG_HANDLE_CLASSNAME}` },
+ type: 'add',
+ })),
+ []
+ );
+ state.edges = applyEdgeChanges(
+ edges.map((edge) => ({ item: edge, type: 'add' })),
+ []
+ );
+ },
+ zoomChanged: (state, action: PayloadAction) => {
+ state.zoom = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
});
+ builder.addCase(appSocketInvocationStarted, (state, action) => {
+ const { source_node_id } = action.payload.data;
+ const node = state.nodeExecutionStates[source_node_id];
+ if (node) {
+ node.status = NodeStatus.IN_PROGRESS;
+ }
+ });
+ builder.addCase(appSocketInvocationComplete, (state, action) => {
+ const { source_node_id } = action.payload.data;
+ const node = state.nodeExecutionStates[source_node_id];
+ if (node) {
+ node.status = NodeStatus.COMPLETED;
+ if (node.progress !== null) {
+ node.progress = 1;
+ }
+ }
+ });
+ builder.addCase(appSocketInvocationError, (state, action) => {
+ const { source_node_id } = action.payload.data;
+ const node = state.nodeExecutionStates[source_node_id];
+ if (node) {
+ node.status = NodeStatus.FAILED;
+ node.error = action.payload.data.error;
+ node.progress = null;
+ node.progressImage = null;
+ }
+ });
+ builder.addCase(appSocketGeneratorProgress, (state, action) => {
+ const { source_node_id, step, total_steps, progress_image } =
+ action.payload.data;
+ const node = state.nodeExecutionStates[source_node_id];
+ if (node) {
+ node.status = NodeStatus.IN_PROGRESS;
+ node.progress = (step + 1) / total_steps;
+ node.progressImage = progress_image ?? null;
+ }
+ });
+ builder.addCase(sessionInvoked.fulfilled, (state) => {
+ forEach(state.nodeExecutionStates, (nes) => {
+ nes.status = NodeStatus.PENDING;
+ nes.error = null;
+ nes.progress = null;
+ nes.progressImage = null;
+ });
+ });
},
});
@@ -168,20 +647,53 @@ export const {
nodesChanged,
edgesChanged,
nodeAdded,
- fieldValueChanged,
+ nodesDeleted,
connectionMade,
connectionStarted,
connectionEnded,
- shouldShowGraphOverlayChanged,
+ nodeClicked,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
nodeTemplatesBuilt,
nodeEditorReset,
imageCollectionFieldValueChanged,
- setEditorInstance,
loadFileNodes,
loadFileEdges,
- setProgressNodeSize,
+ fieldStringValueChanged,
+ fieldNumberValueChanged,
+ fieldBooleanValueChanged,
+ fieldImageValueChanged,
+ fieldColorValueChanged,
+ fieldMainModelValueChanged,
+ fieldVaeModelValueChanged,
+ fieldLoRAModelValueChanged,
+ fieldEnumModelValueChanged,
+ fieldControlNetModelValueChanged,
+ fieldRefinerModelValueChanged,
+ nodeIsOpenChanged,
+ nodeLabelChanged,
+ nodeNotesChanged,
+ edgesDeleted,
+ shouldValidateGraphChanged,
+ shouldAnimateEdgesChanged,
+ nodeOpacityChanged,
+ shouldSnapToGridChanged,
+ shouldColorEdgesChanged,
+ selectedNodesChanged,
+ selectedEdgesChanged,
+ workflowNameChanged,
+ workflowDescriptionChanged,
+ workflowTagsChanged,
+ workflowAuthorChanged,
+ workflowNotesChanged,
+ workflowVersionChanged,
+ workflowContactChanged,
+ workflowLoaded,
+ notesNodeValueChanged,
+ workflowExposedFieldAdded,
+ workflowExposedFieldRemoved,
+ fieldLabelChanged,
+ zoomChanged,
} = nodesSlice.actions;
export default nodesSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts
new file mode 100644
index 0000000000..41a608baa3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts
@@ -0,0 +1,92 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+// import { validateSeedWeights } from 'common/util/seedWeightPairs';
+import { every } from 'lodash-es';
+import { getConnectedEdges } from 'reactflow';
+import { isInvocationNode } from '../types/types';
+import { NodesState } from './types';
+
+export const selectIsReadyNodes = createSelector(
+ [stateSelector],
+ (state) => {
+ const { nodes, system } = state;
+ const { isProcessing, isConnected } = system;
+
+ if (isProcessing || !isConnected) {
+ // Cannot generate if already processing an image
+ return false;
+ }
+
+ if (!nodes.shouldValidateGraph) {
+ return true;
+ }
+
+ const isGraphReady = every(nodes.nodes, (node) => {
+ if (!isInvocationNode(node)) {
+ return true;
+ }
+
+ const nodeTemplate = nodes.nodeTemplates[node.data.type];
+
+ if (!nodeTemplate) {
+ // Node type not found
+ return false;
+ }
+
+ const connectedEdges = getConnectedEdges([node], nodes.edges);
+
+ const isNodeValid = every(node.data.inputs, (field) => {
+ const fieldTemplate = nodeTemplate.inputs[field.name];
+ const hasConnection = connectedEdges.some(
+ (edge) => edge.target === node.id && edge.targetHandle === field.name
+ );
+
+ if (!fieldTemplate) {
+ // Field type not found
+ return false;
+ }
+
+ if (fieldTemplate.required && !field.value && !hasConnection) {
+ // Required field is empty or does not have a connection
+ return false;
+ }
+
+ // ok
+ return true;
+ });
+
+ return isNodeValid;
+ });
+
+ return isGraphReady;
+ },
+ defaultSelectorOptions
+);
+
+export const getNodeAndTemplate = (nodeId: string, nodes: NodesState) => {
+ const node = nodes.nodes.find((node) => node.id === nodeId);
+ const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
+
+ return { node, nodeTemplate };
+};
+
+export const getInputFieldAndTemplate = (
+ nodeId: string,
+ fieldName: string,
+ nodes: NodesState
+) => {
+ const node = nodes.nodes
+ .filter(isInvocationNode)
+ .find((node) => node.id === nodeId);
+ const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
+
+ if (!node || !nodeTemplate) {
+ return;
+ }
+
+ const field = node.data.inputs[fieldName];
+ const fieldTemplate = nodeTemplate.inputs[fieldName];
+
+ return { field, fieldTemplate };
+};
diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts
index 14cb92006b..578adddcee 100644
--- a/invokeai/frontend/web/src/features/nodes/store/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/types.ts
@@ -1,16 +1,31 @@
import { OpenAPIV3 } from 'openapi-types';
-import { Edge, Node, OnConnectStartParams, ReactFlowInstance } from 'reactflow';
-import { InvocationTemplate, InvocationValue } from '../types/types';
+import { Edge, Node, OnConnectStartParams } from 'reactflow';
+import {
+ FieldType,
+ InvocationEdgeExtra,
+ InvocationTemplate,
+ NodeData,
+ NodeExecutionState,
+ Workflow,
+} from '../types/types';
export type NodesState = {
- nodes: Node[];
- edges: Edge[];
+ nodes: Node[];
+ edges: Edge[];
schema: OpenAPIV3.Document | null;
- invocationTemplates: Record;
+ nodeTemplates: Record;
connectionStartParams: OnConnectStartParams | null;
- shouldShowGraphOverlay: boolean;
+ currentConnectionFieldType: FieldType | null;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
- editorInstance: ReactFlowInstance | undefined;
- progressNodeSize: { width: number; height: number };
+ shouldValidateGraph: boolean;
+ shouldAnimateEdges: boolean;
+ nodeOpacity: number;
+ shouldSnapToGrid: boolean;
+ shouldColorEdges: boolean;
+ selectedNodes: string[];
+ selectedEdges: string[];
+ workflow: Omit;
+ nodeExecutionStates: Record;
+ zoom: number;
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts
new file mode 100644
index 0000000000..3cc3859ce0
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts
@@ -0,0 +1,92 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
+import { COLLECTION_TYPES } from 'features/nodes/types/constants';
+import { FieldType } from 'features/nodes/types/types';
+import { HandleType } from 'reactflow';
+
+export const makeConnectionErrorSelector = (
+ nodeId: string,
+ fieldName: string,
+ handleType: HandleType,
+ fieldType: FieldType
+) =>
+ createSelector(stateSelector, (state) => {
+ const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
+ state.nodes;
+
+ if (!state.nodes.shouldValidateGraph) {
+ // manual override!
+ return null;
+ }
+
+ if (!connectionStartParams || !currentConnectionFieldType) {
+ return 'No connection in progress';
+ }
+
+ const {
+ handleType: connectionHandleType,
+ nodeId: connectionNodeId,
+ handleId: connectionFieldName,
+ } = connectionStartParams;
+
+ if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
+ return 'No connection data';
+ }
+
+ const targetFieldType =
+ handleType === 'target' ? fieldType : currentConnectionFieldType;
+ const sourceFieldType =
+ handleType === 'source' ? fieldType : currentConnectionFieldType;
+
+ if (nodeId === connectionNodeId) {
+ return 'Cannot connect to self';
+ }
+
+ if (handleType === connectionHandleType) {
+ if (handleType === 'source') {
+ return 'Cannot connect output to output';
+ }
+ return 'Cannot connect input to input';
+ }
+
+ if (
+ fieldType !== currentConnectionFieldType &&
+ fieldType !== 'CollectionItem' &&
+ currentConnectionFieldType !== 'CollectionItem'
+ ) {
+ if (
+ !(
+ COLLECTION_TYPES.includes(targetFieldType) &&
+ COLLECTION_TYPES.includes(sourceFieldType)
+ )
+ ) {
+ // except for collection items, field types must match
+ return 'Field types must match';
+ }
+ }
+
+ if (
+ handleType === 'target' &&
+ edges.find((edge) => {
+ return edge.target === nodeId && edge.targetHandle === fieldName;
+ }) &&
+ // except CollectionItem inputs can have multiples
+ targetFieldType !== 'CollectionItem'
+ ) {
+ return 'Inputs may only have one connection';
+ }
+
+ const isGraphAcyclic = getIsGraphAcyclic(
+ connectionHandleType === 'source' ? connectionNodeId : nodeId,
+ connectionHandleType === 'source' ? nodeId : connectionNodeId,
+ nodes,
+ edges
+ );
+
+ if (!isGraphAcyclic) {
+ return 'Connection would create a cycle';
+ }
+
+ return null;
+ });
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeTemplateSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeTemplateSelector.ts
index 0a314a75de..2c4ec37f0b 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/makeTemplateSelector.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/makeTemplateSelector.ts
@@ -1,24 +1,11 @@
import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import { InvocationTemplate } from 'features/nodes/types/types';
+import { stateSelector } from 'app/store/store';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnyInvocationType } from 'services/events/types';
export const makeTemplateSelector = (type: AnyInvocationType) =>
createSelector(
- [(state: RootState) => state.nodes],
- (nodes) => {
- const template = nodes.invocationTemplates[type];
- if (!template) {
- return;
- }
- return template;
- },
- {
- memoizeOptions: {
- resultEqualityCheck: (
- a: InvocationTemplate | undefined,
- b: InvocationTemplate | undefined
- ) => a !== undefined && b !== undefined && a.type === b.type,
- },
- }
+ stateSelector,
+ ({ nodes }) => nodes.nodeTemplates[type],
+ defaultSelectorOptions
);
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index d83d240847..f6389d47dd 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -1,168 +1,172 @@
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
+export const COLOR_TOKEN_VALUE = 500;
+export const NODE_WIDTH = 320;
+export const NODE_MIN_WIDTH = 320;
+export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
-export const FIELD_TYPE_MAP: Record = {
- integer: 'integer',
- float: 'float',
- number: 'float',
- string: 'string',
- boolean: 'boolean',
- enum: 'enum',
- ImageField: 'image',
- image_collection: 'image_collection',
- LatentsField: 'latents',
- ConditioningField: 'conditioning',
- UNetField: 'unet',
- ClipField: 'clip',
- VaeField: 'vae',
- model: 'model',
- refiner_model: 'refiner_model',
- vae_model: 'vae_model',
- lora_model: 'lora_model',
- controlnet_model: 'controlnet_model',
- ControlNetModelField: 'controlnet_model',
- array: 'array',
- item: 'item',
- ColorField: 'color',
- ControlField: 'control',
- control: 'control',
- cfg_scale: 'float',
- control_weight: 'float',
-};
+export const COLLECTION_TYPES: FieldType[] = [
+ 'Collection',
+ 'IntegerCollection',
+ 'FloatCollection',
+ 'StringCollection',
+ 'BooleanCollection',
+ 'ImageCollection',
+];
-const COLOR_TOKEN_VALUE = 500;
-
-const getColorTokenCssVariable = (color: string) =>
- `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
+export const colorTokenToCssVar = (colorToken: string) =>
+ `var(--invokeai-colors-${colorToken.split('.').join('-')}`;
export const FIELDS: Record = {
integer: {
- color: 'red',
- colorCssVar: getColorTokenCssVariable('red'),
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
+ color: 'red.500',
},
float: {
- color: 'orange',
- colorCssVar: getColorTokenCssVariable('orange'),
title: 'Float',
description: 'Floats are numbers with a decimal point.',
+ color: 'orange.500',
},
string: {
- color: 'yellow',
- colorCssVar: getColorTokenCssVariable('yellow'),
title: 'String',
description: 'Strings are text.',
+ color: 'yellow.500',
},
boolean: {
- color: 'green',
- colorCssVar: getColorTokenCssVariable('green'),
title: 'Boolean',
+ color: 'green.500',
description: 'Booleans are true or false.',
},
enum: {
- color: 'blue',
- colorCssVar: getColorTokenCssVariable('blue'),
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
+ color: 'blue.500',
},
- image: {
- color: 'purple',
- colorCssVar: getColorTokenCssVariable('purple'),
+ ImageField: {
title: 'Image',
description: 'Images may be passed between nodes.',
+ color: 'purple.500',
},
- image_collection: {
- color: 'purple',
- colorCssVar: getColorTokenCssVariable('purple'),
- title: 'Image Collection',
- description: 'A collection of images.',
- },
- latents: {
- color: 'pink',
- colorCssVar: getColorTokenCssVariable('pink'),
+ LatentsField: {
title: 'Latents',
description: 'Latents may be passed between nodes.',
+ color: 'pink.500',
},
- conditioning: {
- color: 'cyan',
- colorCssVar: getColorTokenCssVariable('cyan'),
+ ConditioningField: {
+ color: 'cyan.500',
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
- unet: {
- color: 'red',
- colorCssVar: getColorTokenCssVariable('red'),
+ ImageCollection: {
+ title: 'Image Collection',
+ description: 'A collection of images.',
+ color: 'base.300',
+ },
+ UNetField: {
+ color: 'red.500',
title: 'UNet',
description: 'UNet submodel.',
},
- clip: {
- color: 'green',
- colorCssVar: getColorTokenCssVariable('green'),
+ ClipField: {
+ color: 'green.500',
title: 'Clip',
description: 'Tokenizer and text_encoder submodels.',
},
- vae: {
- color: 'blue',
- colorCssVar: getColorTokenCssVariable('blue'),
+ VaeField: {
+ color: 'blue.500',
title: 'Vae',
description: 'Vae submodel.',
},
- control: {
- color: 'cyan',
- colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left
+ ControlField: {
+ color: 'cyan.500',
title: 'Control',
description: 'Control info passed between nodes.',
},
- model: {
- color: 'teal',
- colorCssVar: getColorTokenCssVariable('teal'),
+ MainModelField: {
+ color: 'teal.500',
title: 'Model',
- description: 'Models are models.',
+ description: 'TODO',
},
- refiner_model: {
- color: 'teal',
- colorCssVar: getColorTokenCssVariable('teal'),
+ SDXLRefinerModelField: {
+ color: 'teal.500',
title: 'Refiner Model',
- description: 'Models are models.',
+ description: 'TODO',
},
- vae_model: {
- color: 'teal',
- colorCssVar: getColorTokenCssVariable('teal'),
+ VaeModelField: {
+ color: 'teal.500',
title: 'VAE',
- description: 'Models are models.',
+ description: 'TODO',
},
- lora_model: {
- color: 'teal',
- colorCssVar: getColorTokenCssVariable('teal'),
+ LoRAModelField: {
+ color: 'teal.500',
title: 'LoRA',
- description: 'Models are models.',
+ description: 'TODO',
},
- controlnet_model: {
- color: 'teal',
- colorCssVar: getColorTokenCssVariable('teal'),
+ ControlNetModelField: {
+ color: 'teal.500',
title: 'ControlNet',
- description: 'Models are models.',
+ description: 'TODO',
},
- array: {
- color: 'gray',
- colorCssVar: getColorTokenCssVariable('gray'),
- title: 'Array',
- description: 'TODO: Array type description.',
+ Collection: {
+ color: 'base.500',
+ title: 'Collection',
+ description: 'TODO',
},
- item: {
- color: 'gray',
- colorCssVar: getColorTokenCssVariable('gray'),
+ CollectionItem: {
+ color: 'base.500',
title: 'Collection Item',
- description: 'TODO: Collection Item type description.',
+ description: 'TODO',
},
- color: {
- color: 'gray',
- colorCssVar: getColorTokenCssVariable('gray'),
+ ColorField: {
title: 'Color',
description: 'A RGBA color.',
+ color: 'base.500',
+ },
+ BooleanCollection: {
+ title: 'Boolean Collection',
+ description: 'A collection of booleans.',
+ color: 'green.500',
+ },
+ IntegerCollection: {
+ title: 'Integer Collection',
+ description: 'A collection of integers.',
+ color: 'red.500',
+ },
+ FloatCollection: {
+ color: 'orange.500',
+ title: 'Float Collection',
+ description: 'A collection of floats.',
+ },
+ FilePath: {
+ color: 'base.500',
+ title: 'File Path',
+ description: 'A path to a file.',
+ },
+ LoRAField: {
+ color: 'base.500',
+ title: 'LoRA',
+ description: 'LoRA field.',
+ },
+ ONNXModelField: {
+ color: 'base.500',
+ title: 'ONNX Model',
+ description: 'ONNX model field.',
+ },
+ SDXLMainModelField: {
+ color: 'base.500',
+ title: 'SDXL Model',
+ description: 'SDXL model field.',
+ },
+ Seed: {
+ color: 'green.500',
+ title: 'Seed',
+ description: 'A seed for random number generation.',
+ },
+ StringCollection: {
+ color: 'yellow.500',
+ title: 'String Collection',
+ description: 'A collection of strings.',
},
};
-
-export const NODE_MIN_WIDTH = 250;
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 157b990b96..927153e744 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -2,23 +2,25 @@ import {
ControlNetModelParam,
LoRAModelParam,
MainModelParam,
+ OnnxModelParam,
VaeModelParam,
} from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
-import { Graph, ImageDTO, ImageField } from 'services/api/types';
-import { AnyInvocationType } from 'services/events/types';
+import { Edge, Node } from 'reactflow';
+import {
+ Graph,
+ ImageDTO,
+ ImageField,
+ _InputField,
+ _OutputField,
+} from 'services/api/types';
+import { AnyInvocationType, ProgressImage } from 'services/events/types';
import { O } from 'ts-toolbelt';
+import { z } from 'zod';
export type NonNullableGraph = O.Required;
-export type InvocationValue = {
- id: string;
- type: AnyInvocationType;
- inputs: Record;
- outputs: Record;
-};
-
export type InvocationTemplate = {
/**
* Unique type of the invocation
@@ -47,37 +49,49 @@ export type InvocationTemplate = {
};
export type FieldUIConfig = {
- color: string;
- colorCssVar: string;
title: string;
description: string;
+ color: string;
};
-/**
- * The valid invocation field types
- */
-export type FieldType =
- | 'integer'
- | 'float'
- | 'string'
- | 'boolean'
- | 'enum'
- | 'image'
- | 'latents'
- | 'conditioning'
- | 'unet'
- | 'clip'
- | 'vae'
- | 'control'
- | 'model'
- | 'refiner_model'
- | 'vae_model'
- | 'lora_model'
- | 'controlnet_model'
- | 'array'
- | 'item'
- | 'color'
- | 'image_collection';
+// TODO: Get this from the OpenAPI schema? may be tricky...
+export const zFieldType = z.enum([
+ 'integer',
+ 'float',
+ 'boolean',
+ 'string',
+ 'enum',
+ 'ImageField',
+ 'LatentsField',
+ 'ConditioningField',
+ 'ControlField',
+ 'MainModelField',
+ 'SDXLMainModelField',
+ 'SDXLRefinerModelField',
+ 'ONNXModelField',
+ 'VaeModelField',
+ 'LoRAModelField',
+ 'ControlNetModelField',
+ 'UNetField',
+ 'VaeField',
+ 'LoRAField',
+ 'ClipField',
+ 'ColorField',
+ 'ImageCollection',
+ 'IntegerCollection',
+ 'FloatCollection',
+ 'StringCollection',
+ 'BooleanCollection',
+ 'Seed',
+ 'FilePath',
+ 'Collection',
+ 'CollectionItem',
+]);
+
+export type FieldType = z.infer;
+
+export const isFieldType = (value: unknown): value is FieldType =>
+ zFieldType.safeParse(value).success;
/**
* An input field is persisted across reloads as part of the user's local state.
@@ -89,6 +103,7 @@ export type FieldType =
*/
export type InputFieldValue =
| IntegerInputFieldValue
+ | SeedInputFieldValue
| FloatInputFieldValue
| StringInputFieldValue
| BooleanInputFieldValue
@@ -101,12 +116,13 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| MainModelInputFieldValue
- | RefinerModelInputFieldValue
+ | SDXLMainModelInputFieldValue
+ | SDXLRefinerModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
- | ArrayInputFieldValue
- | ItemInputFieldValue
+ | CollectionInputFieldValue
+ | CollectionItemInputFieldValue
| ColorInputFieldValue
| ImageCollectionInputFieldValue;
@@ -118,6 +134,7 @@ export type InputFieldValue =
*/
export type InputFieldTemplate =
| IntegerInputFieldTemplate
+ | SeedInputFieldTemplate
| FloatInputFieldTemplate
| StringInputFieldTemplate
| BooleanInputFieldTemplate
@@ -129,13 +146,14 @@ export type InputFieldTemplate =
| VaeInputFieldTemplate
| ControlInputFieldTemplate
| EnumInputFieldTemplate
- | ModelInputFieldTemplate
- | RefinerModelInputFieldTemplate
+ | MainModelInputFieldTemplate
+ | SDXLMainModelInputFieldTemplate
+ | SDXLRefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
- | ArrayInputFieldTemplate
- | ItemInputFieldTemplate
+ | CollectionInputFieldTemplate
+ | CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate;
@@ -146,7 +164,7 @@ export type InputFieldTemplate =
* - `id` a unique identifier
* - `name` the name of the field, which comes from the python dataclass
*/
-export type OutputFieldValue = FieldValueBase;
+export type OutputFieldValue = FieldValueBase & { fieldKind: 'output' };
/**
* An output field template is generated on each page load from the OpenAPI schema.
@@ -154,17 +172,13 @@ export type OutputFieldValue = FieldValueBase;
* The template provides the output field's name, type, title, and description.
*/
export type OutputFieldTemplate = {
+ fieldKind: 'output';
name: string;
type: FieldType;
title: string;
description: string;
};
-/**
- * Indicates when/if this field needs an input.
- */
-export type InputRequirement = 'always' | 'never' | 'optional';
-
/**
* Indicates the kind of input(s) this field may have.
*/
@@ -176,108 +190,123 @@ export type FieldValueBase = {
type: FieldType;
};
-export type IntegerInputFieldValue = FieldValueBase & {
+export type InputFieldValueBase = FieldValueBase & {
+ fieldKind: 'input';
+ label: string;
+};
+
+export type IntegerInputFieldValue = InputFieldValueBase & {
type: 'integer';
value?: number;
};
-export type FloatInputFieldValue = FieldValueBase & {
+export type FloatInputFieldValue = InputFieldValueBase & {
type: 'float';
value?: number;
};
-export type StringInputFieldValue = FieldValueBase & {
+export type SeedInputFieldValue = InputFieldValueBase & {
+ type: 'Seed';
+ value?: number;
+};
+
+export type StringInputFieldValue = InputFieldValueBase & {
type: 'string';
value?: string;
};
-export type BooleanInputFieldValue = FieldValueBase & {
+export type BooleanInputFieldValue = InputFieldValueBase & {
type: 'boolean';
value?: boolean;
};
-export type EnumInputFieldValue = FieldValueBase & {
+export type EnumInputFieldValue = InputFieldValueBase & {
type: 'enum';
value?: number | string;
};
-export type LatentsInputFieldValue = FieldValueBase & {
- type: 'latents';
+export type LatentsInputFieldValue = InputFieldValueBase & {
+ type: 'LatentsField';
value?: undefined;
};
-export type ConditioningInputFieldValue = FieldValueBase & {
- type: 'conditioning';
+export type ConditioningInputFieldValue = InputFieldValueBase & {
+ type: 'ConditioningField';
value?: string;
};
-export type ControlInputFieldValue = FieldValueBase & {
- type: 'control';
+export type ControlInputFieldValue = InputFieldValueBase & {
+ type: 'ControlField';
value?: undefined;
};
-export type UNetInputFieldValue = FieldValueBase & {
- type: 'unet';
+export type UNetInputFieldValue = InputFieldValueBase & {
+ type: 'UNetField';
value?: undefined;
};
-export type ClipInputFieldValue = FieldValueBase & {
- type: 'clip';
+export type ClipInputFieldValue = InputFieldValueBase & {
+ type: 'ClipField';
value?: undefined;
};
-export type VaeInputFieldValue = FieldValueBase & {
- type: 'vae';
+export type VaeInputFieldValue = InputFieldValueBase & {
+ type: 'VaeField';
value?: undefined;
};
-export type ImageInputFieldValue = FieldValueBase & {
- type: 'image';
+export type ImageInputFieldValue = InputFieldValueBase & {
+ type: 'ImageField';
value?: ImageField;
};
-export type ImageCollectionInputFieldValue = FieldValueBase & {
- type: 'image_collection';
+export type ImageCollectionInputFieldValue = InputFieldValueBase & {
+ type: 'ImageCollection';
value?: ImageField[];
};
-export type MainModelInputFieldValue = FieldValueBase & {
- type: 'model';
- value?: MainModelParam;
+export type MainModelInputFieldValue = InputFieldValueBase & {
+ type: 'MainModelField';
+ value?: MainModelParam | OnnxModelParam;
};
-export type RefinerModelInputFieldValue = FieldValueBase & {
- type: 'refiner_model';
- value?: MainModelParam;
+export type SDXLMainModelInputFieldValue = InputFieldValueBase & {
+ type: 'SDXLMainModelField';
+ value?: MainModelParam | OnnxModelParam;
};
-export type VaeModelInputFieldValue = FieldValueBase & {
- type: 'vae_model';
+export type SDXLRefinerModelInputFieldValue = InputFieldValueBase & {
+ type: 'SDXLRefinerModelField';
+ value?: MainModelParam | OnnxModelParam;
+};
+
+export type VaeModelInputFieldValue = InputFieldValueBase & {
+ type: 'VaeModelField';
value?: VaeModelParam;
};
-export type LoRAModelInputFieldValue = FieldValueBase & {
- type: 'lora_model';
+export type LoRAModelInputFieldValue = InputFieldValueBase & {
+ type: 'LoRAModelField';
value?: LoRAModelParam;
};
-export type ControlNetModelInputFieldValue = FieldValueBase & {
- type: 'controlnet_model';
+export type ControlNetModelInputFieldValue = InputFieldValueBase & {
+ type: 'ControlNetModelField';
value?: ControlNetModelParam;
};
-export type ArrayInputFieldValue = FieldValueBase & {
- type: 'array';
+export type CollectionInputFieldValue = InputFieldValueBase & {
+ type: 'Collection';
value?: (string | number)[];
};
-export type ItemInputFieldValue = FieldValueBase & {
- type: 'item';
+export type CollectionItemInputFieldValue = InputFieldValueBase & {
+ type: 'CollectionItem';
value?: undefined;
};
-export type ColorInputFieldValue = FieldValueBase & {
- type: 'color';
+export type ColorInputFieldValue = InputFieldValueBase & {
+ type: 'ColorField';
value?: RgbaColor;
};
@@ -286,9 +315,9 @@ export type InputFieldTemplateBase = {
title: string;
description: string;
type: FieldType;
- inputRequirement: InputRequirement;
- inputKind: InputKind;
-};
+ required: boolean;
+ fieldKind: 'input';
+} & _InputField;
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
type: 'integer';
@@ -300,6 +329,16 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean;
};
+export type SeedInputFieldTemplate = InputFieldTemplateBase & {
+ type: 'Seed';
+ default: number;
+ multipleOf?: number;
+ maximum?: number;
+ exclusiveMaximum?: boolean;
+ minimum?: number;
+ exclusiveMinimum?: boolean;
+};
+
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
type: 'float';
default: number;
@@ -325,42 +364,42 @@ export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: ImageDTO;
- type: 'image';
+ type: 'ImageField';
};
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: ImageField[];
- type: 'image_collection';
+ type: 'ImageCollection';
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: string;
- type: 'latents';
+ type: 'LatentsField';
};
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'conditioning';
+ type: 'ConditioningField';
};
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'unet';
+ type: 'UNetField';
};
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'clip';
+ type: 'ClipField';
};
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'vae';
+ type: 'VaeField';
};
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'control';
+ type: 'ControlField';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
@@ -370,46 +409,59 @@ export type EnumInputFieldTemplate = InputFieldTemplateBase & {
options: Array;
};
-export type ModelInputFieldTemplate = InputFieldTemplateBase & {
- default: string;
- type: 'model';
+export type MainModelInputFieldTemplate = InputFieldTemplateBase & {
+ default: undefined;
+ type: 'MainModelField';
};
-export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
- default: string;
- type: 'refiner_model';
+export type SDXLMainModelInputFieldTemplate = InputFieldTemplateBase & {
+ default: undefined;
+ type: 'SDXLMainModelField';
+};
+
+export type SDXLRefinerModelInputFieldTemplate = InputFieldTemplateBase & {
+ default: undefined;
+ type: 'SDXLRefinerModelField';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
- type: 'vae_model';
+ type: 'VaeModelField';
};
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
- type: 'lora_model';
+ type: 'LoRAModelField';
};
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
- type: 'controlnet_model';
+ type: 'ControlNetModelField';
};
-export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
+export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
- type: 'array';
+ type: 'Collection';
};
-export type ItemInputFieldTemplate = InputFieldTemplateBase & {
+export type CollectionItemInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
- type: 'item';
+ type: 'CollectionItem';
};
export type ColorInputFieldTemplate = InputFieldTemplateBase & {
default: RgbaColor;
- type: 'color';
+ type: 'ColorField';
};
+export const isInputFieldValue = (
+ field: InputFieldValue | OutputFieldValue
+): field is InputFieldValue => field.fieldKind === 'input';
+
+export const isInputFieldTemplate = (
+ fieldTemplate: InputFieldTemplate | OutputFieldTemplate
+): fieldTemplate is InputFieldTemplate => fieldTemplate.fieldKind === 'input';
+
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/
@@ -422,12 +474,12 @@ export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
ui?: {
tags?: string[];
- type_hints?: TypeHints;
title?: string;
};
title: string;
properties: Omit<
- NonNullable,
+ NonNullable &
+ (_InputField | _OutputField),
'type'
> & {
type: Omit & {
@@ -446,6 +498,8 @@ export type InvocationBaseSchemaObject = Omit<
> &
InvocationSchemaExtra;
+export type InvocationFieldSchema = OpenAPIV3.SchemaObject & _InputField;
+
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
type: OpenAPIV3.ArraySchemaObjectType;
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
@@ -459,3 +513,153 @@ export type InvocationSchemaObject = ArraySchemaObject | NonArraySchemaObject;
export const isInvocationSchemaObject = (
obj: OpenAPIV3.ReferenceObject | InvocationSchemaObject
): obj is InvocationSchemaObject => !('$ref' in obj);
+
+export const isInvocationFieldSchema = (
+ obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
+): obj is InvocationFieldSchema => !('$ref' in obj);
+
+export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
+
+export const zInputFieldValue = z.object({
+ id: z.string().trim().min(1),
+ name: z.string().trim().min(1),
+ type: zFieldType,
+ label: z.string(),
+ isExposed: z.boolean(),
+});
+
+export const zInvocationNodeData = z.object({
+ id: z.string().trim().min(1),
+ type: z.string().trim().min(1),
+ inputs: z.record(z.any()),
+ outputs: z.record(z.any()),
+ label: z.string(),
+ isOpen: z.boolean(),
+ notes: z.string(),
+});
+
+export const zNotesNodeData = z.object({
+ id: z.string().trim().min(1),
+ type: z.literal('notes'),
+ label: z.string(),
+ isOpen: z.boolean(),
+ notes: z.string(),
+});
+
+export const zWorkflow = z.object({
+ name: z.string().trim().min(1),
+ author: z.string(),
+ description: z.string(),
+ version: z.string(),
+ contact: z.string(),
+ tags: z.string(),
+ notes: z.string(),
+ nodes: z.array(
+ z.object({
+ id: z.string().trim().min(1),
+ type: z.string().trim().min(1),
+ data: z.union([zInvocationNodeData, zNotesNodeData]),
+ width: z.number().gt(0),
+ height: z.number().gt(0),
+ position: z.object({
+ x: z.number(),
+ y: z.number(),
+ }),
+ })
+ ),
+ edges: z.array(
+ z.object({
+ source: z.string().trim().min(1),
+ sourceHandle: z.string().trim().min(1),
+ target: z.string().trim().min(1),
+ targetHandle: z.string().trim().min(1),
+ id: z.string().trim().min(1),
+ type: z.string().trim().min(1),
+ })
+ ),
+});
+
+export type Workflow = {
+ name: string;
+ author: string;
+ description: string;
+ version: string;
+ contact: string;
+ tags: string;
+ notes: string;
+ nodes: Pick<
+ Node,
+ 'id' | 'type' | 'data' | 'width' | 'height' | 'position'
+ >[];
+ edges: Pick<
+ Edge,
+ 'source' | 'sourceHandle' | 'target' | 'targetHandle' | 'id' | 'type'
+ >[];
+ exposedFields: FieldIdentifier[];
+};
+
+export type InvocationNodeData = {
+ id: string;
+ type: AnyInvocationType;
+ inputs: Record;
+ outputs: Record;
+ label: string;
+ isOpen: boolean;
+ notes: string;
+};
+
+export type NotesNodeData = {
+ id: string;
+ type: 'notes';
+ label: string;
+ notes: string;
+ isOpen: boolean;
+};
+
+export type CurrentImageNodeData = {
+ id: string;
+ type: 'current_image';
+ isOpen: boolean;
+ label: string;
+};
+
+export type NodeData =
+ | InvocationNodeData
+ | NotesNodeData
+ | CurrentImageNodeData;
+
+export const isInvocationNode = (
+ node?: Node
+): node is Node => node?.type === 'invocation';
+
+export const isInvocationNodeData = (
+ node?: NodeData
+): node is InvocationNodeData =>
+ !['notes', 'current_image'].includes(node?.type ?? '');
+
+export const isNotesNode = (
+ node?: Node
+): node is Node => node?.type === 'notes';
+
+export const isProgressImageNode = (
+ node?: Node
+): node is Node => node?.type === 'current_image';
+
+export enum NodeStatus {
+ PENDING,
+ IN_PROGRESS,
+ COMPLETED,
+ FAILED,
+}
+
+export type NodeExecutionState = {
+ status: NodeStatus;
+ progress: number | null;
+ progressImage: ProgressImage | null;
+ error: string | null;
+};
+
+export type FieldIdentifier = {
+ nodeId: string;
+ fieldName: string;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts
new file mode 100644
index 0000000000..da3aff7e1c
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts
@@ -0,0 +1,42 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { pick } from 'lodash-es';
+import { NodesState } from '../store/types';
+import { Workflow, isInvocationNode, isNotesNode } from '../types/types';
+
+export const buildWorkflow = (nodesState: NodesState): Workflow => {
+ const { workflow: workflowMeta, nodes, edges } = nodesState;
+ const workflow: Workflow = {
+ ...workflowMeta,
+ nodes: [],
+ edges: [],
+ };
+
+ nodes.forEach((node) => {
+ if (!isInvocationNode(node) && !isNotesNode(node)) {
+ return;
+ }
+ workflow.nodes.push(
+ pick(node, ['id', 'type', 'position', 'width', 'height', 'data'])
+ );
+ });
+
+ edges.forEach((edge) => {
+ workflow.edges.push(
+ pick(edge, [
+ 'source',
+ 'sourceHandle',
+ 'target',
+ 'targetHandle',
+ 'id',
+ 'type',
+ ])
+ );
+ });
+
+ return workflow;
+};
+
+export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
+ buildWorkflow(nodes)
+);
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
index de7d798c69..44712ec600 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
@@ -1,11 +1,13 @@
+import { logger } from 'app/logging/logger';
+import { parseify } from 'common/util/serialize';
import { reduce } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
-import { FIELD_TYPE_MAP } from '../types/constants';
import { isSchemaObject } from '../types/typeGuards';
import {
- ArrayInputFieldTemplate,
BooleanInputFieldTemplate,
ClipInputFieldTemplate,
+ CollectionInputFieldTemplate,
+ CollectionItemInputFieldTemplate,
ColorInputFieldTemplate,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
@@ -17,29 +19,28 @@ import {
ImageInputFieldTemplate,
InputFieldTemplateBase,
IntegerInputFieldTemplate,
- ItemInputFieldTemplate,
+ InvocationFieldSchema,
+ InvocationSchemaObject,
LatentsInputFieldTemplate,
LoRAModelInputFieldTemplate,
- ModelInputFieldTemplate,
+ MainModelInputFieldTemplate,
OutputFieldTemplate,
- RefinerModelInputFieldTemplate,
+ SDXLRefinerModelInputFieldTemplate,
+ SDXLMainModelInputFieldTemplate,
+ SeedInputFieldTemplate,
StringInputFieldTemplate,
- TypeHints,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
+ isFieldType,
+ isInvocationFieldSchema,
} from '../types/types';
-import { logger } from 'app/logging/logger';
-import { parseify } from 'common/util/serialize';
export type BaseFieldProperties = 'name' | 'title' | 'description';
export type BuildInputFieldArg = {
- schemaObject: OpenAPIV3.SchemaObject;
- baseField: Omit<
- InputFieldTemplateBase,
- 'type' | 'inputRequirement' | 'inputKind'
- >;
+ schemaObject: InvocationFieldSchema;
+ baseField: Omit;
};
/**
@@ -52,12 +53,12 @@ export type BuildInputFieldArg = {
*/
export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
-): keyof typeof FIELD_TYPE_MAP => {
+): FieldType => {
const name = refObject.$ref.split('/').slice(-1)[0];
if (!name) {
- return 'UNKNOWN FIELD TYPE';
+ throw `Unknown field type: ${name}`;
}
- return name;
+ return name as FieldType;
};
const buildIntegerInputFieldTemplate = ({
@@ -67,8 +68,39 @@ const buildIntegerInputFieldTemplate = ({
const template: IntegerInputFieldTemplate = {
...baseField,
type: 'integer',
- inputRequirement: 'always',
- inputKind: 'any',
+ default: schemaObject.default ?? 0,
+ };
+
+ if (schemaObject.multipleOf !== undefined) {
+ template.multipleOf = schemaObject.multipleOf;
+ }
+
+ if (schemaObject.maximum !== undefined) {
+ template.maximum = schemaObject.maximum;
+ }
+
+ if (schemaObject.exclusiveMaximum !== undefined) {
+ template.exclusiveMaximum = schemaObject.exclusiveMaximum;
+ }
+
+ if (schemaObject.minimum !== undefined) {
+ template.minimum = schemaObject.minimum;
+ }
+
+ if (schemaObject.exclusiveMinimum !== undefined) {
+ template.exclusiveMinimum = schemaObject.exclusiveMinimum;
+ }
+
+ return template;
+};
+
+const buildSeedInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): SeedInputFieldTemplate => {
+ const template: SeedInputFieldTemplate = {
+ ...baseField,
+ type: 'Seed',
default: schemaObject.default ?? 0,
};
@@ -102,8 +134,6 @@ const buildFloatInputFieldTemplate = ({
const template: FloatInputFieldTemplate = {
...baseField,
type: 'float',
- inputRequirement: 'always',
- inputKind: 'any',
default: schemaObject.default ?? 0,
};
@@ -137,8 +167,6 @@ const buildStringInputFieldTemplate = ({
const template: StringInputFieldTemplate = {
...baseField,
type: 'string',
- inputRequirement: 'always',
- inputKind: 'any',
default: schemaObject.default ?? '',
};
@@ -164,23 +192,32 @@ const buildBooleanInputFieldTemplate = ({
const template: BooleanInputFieldTemplate = {
...baseField,
type: 'boolean',
- inputRequirement: 'always',
- inputKind: 'any',
default: schemaObject.default ?? false,
};
return template;
};
-const buildModelInputFieldTemplate = ({
+const buildMainModelInputFieldTemplate = ({
schemaObject,
baseField,
-}: BuildInputFieldArg): ModelInputFieldTemplate => {
- const template: ModelInputFieldTemplate = {
+}: BuildInputFieldArg): MainModelInputFieldTemplate => {
+ const template: MainModelInputFieldTemplate = {
...baseField,
- type: 'model',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'MainModelField',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
+const buildSDXLMainModelInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): SDXLMainModelInputFieldTemplate => {
+ const template: SDXLMainModelInputFieldTemplate = {
+ ...baseField,
+ type: 'SDXLMainModelField',
default: schemaObject.default ?? undefined,
};
@@ -190,12 +227,10 @@ const buildModelInputFieldTemplate = ({
const buildRefinerModelInputFieldTemplate = ({
schemaObject,
baseField,
-}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
- const template: RefinerModelInputFieldTemplate = {
+}: BuildInputFieldArg): SDXLRefinerModelInputFieldTemplate => {
+ const template: SDXLRefinerModelInputFieldTemplate = {
...baseField,
- type: 'refiner_model',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'SDXLRefinerModelField',
default: schemaObject.default ?? undefined,
};
@@ -208,9 +243,7 @@ const buildVaeModelInputFieldTemplate = ({
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
const template: VaeModelInputFieldTemplate = {
...baseField,
- type: 'vae_model',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'VaeModelField',
default: schemaObject.default ?? undefined,
};
@@ -223,9 +256,7 @@ const buildLoRAModelInputFieldTemplate = ({
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
const template: LoRAModelInputFieldTemplate = {
...baseField,
- type: 'lora_model',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'LoRAModelField',
default: schemaObject.default ?? undefined,
};
@@ -238,9 +269,7 @@ const buildControlNetModelInputFieldTemplate = ({
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
- type: 'controlnet_model',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'ControlNetModelField',
default: schemaObject.default ?? undefined,
};
@@ -253,9 +282,7 @@ const buildImageInputFieldTemplate = ({
}: BuildInputFieldArg): ImageInputFieldTemplate => {
const template: ImageInputFieldTemplate = {
...baseField,
- type: 'image',
- inputRequirement: 'always',
- inputKind: 'any',
+ type: 'ImageField',
default: schemaObject.default ?? undefined,
};
@@ -268,9 +295,7 @@ const buildImageCollectionInputFieldTemplate = ({
}: BuildInputFieldArg): ImageCollectionInputFieldTemplate => {
const template: ImageCollectionInputFieldTemplate = {
...baseField,
- type: 'image_collection',
- inputRequirement: 'always',
- inputKind: 'any',
+ type: 'ImageCollection',
default: schemaObject.default ?? undefined,
};
@@ -283,9 +308,7 @@ const buildLatentsInputFieldTemplate = ({
}: BuildInputFieldArg): LatentsInputFieldTemplate => {
const template: LatentsInputFieldTemplate = {
...baseField,
- type: 'latents',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'LatentsField',
default: schemaObject.default ?? undefined,
};
@@ -298,9 +321,7 @@ const buildConditioningInputFieldTemplate = ({
}: BuildInputFieldArg): ConditioningInputFieldTemplate => {
const template: ConditioningInputFieldTemplate = {
...baseField,
- type: 'conditioning',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'ConditioningField',
default: schemaObject.default ?? undefined,
};
@@ -313,9 +334,8 @@ const buildUNetInputFieldTemplate = ({
}: BuildInputFieldArg): UNetInputFieldTemplate => {
const template: UNetInputFieldTemplate = {
...baseField,
- type: 'unet',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'UNetField',
+
default: schemaObject.default ?? undefined,
};
@@ -328,9 +348,7 @@ const buildClipInputFieldTemplate = ({
}: BuildInputFieldArg): ClipInputFieldTemplate => {
const template: ClipInputFieldTemplate = {
...baseField,
- type: 'clip',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'ClipField',
default: schemaObject.default ?? undefined,
};
@@ -343,9 +361,7 @@ const buildVaeInputFieldTemplate = ({
}: BuildInputFieldArg): VaeInputFieldTemplate => {
const template: VaeInputFieldTemplate = {
...baseField,
- type: 'vae',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'VaeField',
default: schemaObject.default ?? undefined,
};
@@ -358,9 +374,7 @@ const buildControlInputFieldTemplate = ({
}: BuildInputFieldArg): ControlInputFieldTemplate => {
const template: ControlInputFieldTemplate = {
...baseField,
- type: 'control',
- inputRequirement: 'always',
- inputKind: 'connection',
+ type: 'ControlField',
default: schemaObject.default ?? undefined,
};
@@ -377,36 +391,30 @@ const buildEnumInputFieldTemplate = ({
type: 'enum',
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
options: options,
- inputRequirement: 'always',
- inputKind: 'direct',
default: schemaObject.default ?? options[0],
};
return template;
};
-const buildArrayInputFieldTemplate = ({
+const buildCollectionInputFieldTemplate = ({
baseField,
-}: BuildInputFieldArg): ArrayInputFieldTemplate => {
- const template: ArrayInputFieldTemplate = {
+}: BuildInputFieldArg): CollectionInputFieldTemplate => {
+ const template: CollectionInputFieldTemplate = {
...baseField,
- type: 'array',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'Collection',
default: [],
};
return template;
};
-const buildItemInputFieldTemplate = ({
+const buildCollectionItemInputFieldTemplate = ({
baseField,
-}: BuildInputFieldArg): ItemInputFieldTemplate => {
- const template: ItemInputFieldTemplate = {
+}: BuildInputFieldArg): CollectionItemInputFieldTemplate => {
+ const template: CollectionItemInputFieldTemplate = {
...baseField,
- type: 'item',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'CollectionItem',
default: undefined,
};
@@ -419,9 +427,7 @@ const buildColorInputFieldTemplate = ({
}: BuildInputFieldArg): ColorInputFieldTemplate => {
const template: ColorInputFieldTemplate = {
...baseField,
- type: 'color',
- inputRequirement: 'always',
- inputKind: 'direct',
+ type: 'ColorField',
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
@@ -429,42 +435,45 @@ const buildColorInputFieldTemplate = ({
};
export const getFieldType = (
- schemaObject: OpenAPIV3.SchemaObject,
- name: string,
- typeHints?: TypeHints
+ schemaObject: InvocationFieldSchema
): FieldType => {
- let rawFieldType = '';
+ let fieldType = '';
- if (typeHints && name in typeHints) {
- rawFieldType = typeHints[name] ?? 'UNKNOWN FIELD TYPE';
+ const { ui_type_hint } = schemaObject;
+ if (ui_type_hint) {
+ fieldType = ui_type_hint;
} else if (!schemaObject.type) {
+ // console.log('refObject', schemaObject);
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
- rawFieldType = refObjectToFieldType(
+ fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.anyOf) {
- rawFieldType = refObjectToFieldType(
+ fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
- rawFieldType = refObjectToFieldType(
+ fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (schemaObject.enum) {
- rawFieldType = 'enum';
+ fieldType = 'enum';
} else if (schemaObject.type) {
- rawFieldType = schemaObject.type;
+ if (schemaObject.type === 'number') {
+ // floats are "number" in OpenAPI, while ints are "integer"
+ fieldType = 'float';
+ } else {
+ fieldType = schemaObject.type;
+ }
}
- const fieldType = FIELD_TYPE_MAP[rawFieldType];
-
- if (!fieldType) {
- throw `Field type "${rawFieldType}" is unknown!`;
+ if (!isFieldType(fieldType)) {
+ throw `Field type "${fieldType}" is unknown!`;
}
return fieldType;
@@ -472,93 +481,171 @@ export const getFieldType = (
/**
* Builds an input field from an invocation schema property.
- * @param schemaObject The schema object
+ * @param fieldSchema The schema object
* @returns An input field
*/
export const buildInputFieldTemplate = (
- schemaObject: OpenAPIV3.SchemaObject,
- name: string,
- typeHints?: TypeHints
+ nodeSchema: InvocationSchemaObject,
+ fieldSchema: InvocationFieldSchema,
+ name: string
) => {
- const fieldType = getFieldType(schemaObject, name, typeHints);
+ // console.log('input', schemaObject);
+ const fieldType = getFieldType(fieldSchema);
+ // console.log('input fieldType', fieldType);
+
+ const { input, ui_hidden, ui_component, ui_type_hint } = fieldSchema;
+
+ const extra = {
+ input,
+ ui_hidden,
+ ui_component,
+ ui_type_hint,
+ required: nodeSchema.required?.includes(name) ?? false,
+ };
const baseField = {
name,
- title: schemaObject.title ?? '',
- description: schemaObject.description ?? '',
+ title: fieldSchema.title ?? '',
+ description: fieldSchema.description ?? '',
+ fieldKind: 'input' as const,
+ ...extra,
};
- if (['image'].includes(fieldType)) {
- return buildImageInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ImageField') {
+ return buildImageInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
-
- if (['image_collection'].includes(fieldType)) {
- return buildImageCollectionInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ImageCollection') {
+ return buildImageCollectionInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['latents'].includes(fieldType)) {
- return buildLatentsInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'LatentsField') {
+ return buildLatentsInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['conditioning'].includes(fieldType)) {
- return buildConditioningInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ConditioningField') {
+ return buildConditioningInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['unet'].includes(fieldType)) {
- return buildUNetInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'UNetField') {
+ return buildUNetInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['clip'].includes(fieldType)) {
- return buildClipInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ClipField') {
+ return buildClipInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['vae'].includes(fieldType)) {
- return buildVaeInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'VaeField') {
+ return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
}
- if (['control'].includes(fieldType)) {
- return buildControlInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ControlField') {
+ return buildControlInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['model'].includes(fieldType)) {
- return buildModelInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'MainModelField') {
+ return buildMainModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['refiner_model'].includes(fieldType)) {
- return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'SDXLRefinerModelField') {
+ return buildRefinerModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['vae_model'].includes(fieldType)) {
- return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'SDXLMainModelField') {
+ return buildSDXLMainModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['lora_model'].includes(fieldType)) {
- return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'VaeModelField') {
+ return buildVaeModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['controlnet_model'].includes(fieldType)) {
- return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'LoRAModelField') {
+ return buildLoRAModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['enum'].includes(fieldType)) {
- return buildEnumInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ControlNetModelField') {
+ return buildControlNetModelInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['integer'].includes(fieldType)) {
- return buildIntegerInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'enum') {
+ return buildEnumInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['number', 'float'].includes(fieldType)) {
- return buildFloatInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'integer') {
+ return buildIntegerInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['string'].includes(fieldType)) {
- return buildStringInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'float') {
+ return buildFloatInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['boolean'].includes(fieldType)) {
- return buildBooleanInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'string') {
+ return buildStringInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['array'].includes(fieldType)) {
- return buildArrayInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'boolean') {
+ return buildBooleanInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['item'].includes(fieldType)) {
- return buildItemInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'Seed') {
+ return buildSeedInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['color'].includes(fieldType)) {
- return buildColorInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'Collection') {
+ return buildCollectionInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['array'].includes(fieldType)) {
- return buildArrayInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'CollectionItem') {
+ return buildCollectionItemInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
- if (['item'].includes(fieldType)) {
- return buildItemInputFieldTemplate({ schemaObject, baseField });
+ if (fieldType === 'ColorField') {
+ return buildColorInputFieldTemplate({
+ schemaObject: fieldSchema,
+ baseField,
+ });
}
-
return;
};
@@ -570,8 +657,7 @@ export const buildInputFieldTemplate = (
*/
export const buildOutputFieldTemplates = (
refObject: OpenAPIV3.ReferenceObject,
- openAPI: OpenAPIV3.Document,
- typeHints?: TypeHints
+ openAPI: OpenAPIV3.Document
): Record => {
// extract output schema name from ref
const outputSchemaName = refObject.$ref.split('/').slice(-1)[0];
@@ -585,31 +671,34 @@ export const buildOutputFieldTemplates = (
}
// get the output schema itself
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
- const outputSchema = openAPI.components!.schemas![outputSchemaName];
-
+ const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
if (!outputSchema) {
logger('nodes').error({ outputSchemaName }, 'Output schema not found');
throw 'Output schema not found';
}
+ // console.log('output', outputSchema);
if (isSchemaObject(outputSchema)) {
+ // console.log('isSchemaObject');
const outputFields = reduce(
outputSchema.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (
!['type', 'id'].includes(propertyName) &&
!['object'].includes(property.type) && // TODO: handle objects?
- isSchemaObject(property)
+ isInvocationFieldSchema(property)
) {
- const fieldType = getFieldType(property, propertyName, typeHints);
-
+ const fieldType = getFieldType(property);
+ // console.log('output fieldType', fieldType);
outputsAccumulator[propertyName] = {
+ fieldKind: 'output',
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
+ } else {
+ // console.warn('Unhandled OUTPUT property', property);
}
return outputsAccumulator;
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
index 3c6850a88a..473dc83bb6 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
@@ -8,89 +8,89 @@ export const buildInputFieldValue = (
id,
name: template.name,
type: template.type,
+ label: '',
+ fieldKind: 'input',
};
- if (template.inputRequirement !== 'never') {
- if (template.type === 'string') {
+ if (template.type === 'string') {
+ fieldValue.value = template.default ?? '';
+ }
+
+ if (template.type === 'integer') {
+ fieldValue.value = template.default ?? 0;
+ }
+
+ if (template.type === 'float') {
+ fieldValue.value = template.default ?? 0;
+ }
+
+ if (template.type === 'boolean') {
+ fieldValue.value = template.default ?? false;
+ }
+
+ if (template.type === 'enum') {
+ if (template.enumType === 'number') {
+ fieldValue.value = template.default ?? 0;
+ }
+ if (template.enumType === 'string') {
fieldValue.value = template.default ?? '';
}
+ }
- if (template.type === 'integer') {
- fieldValue.value = template.default ?? 0;
- }
+ if (template.type === 'Collection') {
+ fieldValue.value = template.default ?? 1;
+ }
- if (template.type === 'float') {
- fieldValue.value = template.default ?? 0;
- }
+ if (template.type === 'ImageField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'boolean') {
- fieldValue.value = template.default ?? false;
- }
+ if (template.type === 'ImageCollection') {
+ fieldValue.value = [];
+ }
- if (template.type === 'enum') {
- if (template.enumType === 'number') {
- fieldValue.value = template.default ?? 0;
- }
- if (template.enumType === 'string') {
- fieldValue.value = template.default ?? '';
- }
- }
+ if (template.type === 'LatentsField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'array') {
- fieldValue.value = template.default ?? 1;
- }
+ if (template.type === 'ConditioningField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'image') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'UNetField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'image_collection') {
- fieldValue.value = [];
- }
+ if (template.type === 'ClipField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'latents') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'VaeField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'conditioning') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'ControlField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'unet') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'MainModelField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'clip') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'SDXLRefinerModelField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'vae') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'VaeModelField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'control') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'LoRAModelField') {
+ fieldValue.value = undefined;
+ }
- if (template.type === 'model') {
- fieldValue.value = undefined;
- }
-
- if (template.type === 'refiner_model') {
- fieldValue.value = undefined;
- }
-
- if (template.type === 'vae_model') {
- fieldValue.value = undefined;
- }
-
- if (template.type === 'lora_model') {
- fieldValue.value = undefined;
- }
-
- if (template.type === 'controlnet_model') {
- fieldValue.value = undefined;
- }
+ if (template.type === 'ControlNetModelField') {
+ fieldValue.value = undefined;
}
return fieldValue;
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
index 578c4371f2..491c6547ba 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts
@@ -81,9 +81,9 @@ export const addControlNetToLinearGraph = (
return;
}
- graph.nodes[controlNetNode.id] = controlNetNode;
+ graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
- if (metadataAccumulator) {
+ if (metadataAccumulator?.controlnets) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const controlField = omit(controlNetNode, [
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts
index cdd91d6e4f..01158d1cf0 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts
@@ -67,7 +67,7 @@ export const addLoRAsToGraph = (
};
// add the lora to the metadata accumulator
- if (metadataAccumulator) {
+ if (metadataAccumulator?.loras) {
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts
index 35e7f3ac38..3291348d0a 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts
@@ -43,7 +43,7 @@ export const addNSFWCheckerToGraph = (
is_intermediate,
};
- graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
+ graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts
index c0f7f7ca82..190816f21f 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts
@@ -70,6 +70,9 @@ export const addSDXLLoRAsToGraph = (
// add the lora to the metadata accumulator
if (metadataAccumulator) {
+ if (!metadataAccumulator.loras) {
+ metadataAccumulator.loras = [];
+ }
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts
index adce34adf5..2a6ef8e80c 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts
@@ -41,9 +41,9 @@ export const addSDXLRefinerToGraph = (
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
- metadataAccumulator.refiner_positive_aesthetic_score =
+ metadataAccumulator.refiner_positive_aesthetic_store =
refinerPositiveAestheticScore;
- metadataAccumulator.refiner_negative_aesthetic_score =
+ metadataAccumulator.refiner_negative_aesthetic_store =
refinerNegativeAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
index 23f6acb539..97688bd154 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
@@ -5,7 +5,7 @@ import {
ImageBlurInvocation,
ImageDTO,
ImageToLatentsInvocation,
- InfillPatchmatchInvocation,
+ InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
RandomIntInvocation,
@@ -539,7 +539,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as
| InfillTileInvocation
- | InfillPatchmatchInvocation),
+ | InfillPatchMatchInvocation),
image: canvasInitImage,
};
graph.nodes[NOISE] = {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
index 50a773bf50..6bffb83b6e 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
@@ -5,7 +5,7 @@ import {
ImageBlurInvocation,
ImageDTO,
ImageToLatentsInvocation,
- InfillPatchmatchInvocation,
+ InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
RandomIntInvocation,
@@ -553,7 +553,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as
| InfillTileInvocation
- | InfillPatchmatchInvocation),
+ | InfillPatchMatchInvocation),
image: canvasInitImage,
};
graph.nodes[NOISE] = {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
index bae29ee3f5..4304737ba9 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
@@ -1,5 +1,5 @@
-import { RootState } from 'app/store/store';
-import { InputFieldValue } from 'features/nodes/types/types';
+import { NodesState } from 'features/nodes/store/types';
+import { InputFieldValue, isInvocationNode } from 'features/nodes/types/types';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
@@ -9,7 +9,7 @@ import { v4 as uuidv4 } from 'uuid';
* We need to do special handling for some fields
*/
export const parseFieldValue = (field: InputFieldValue) => {
- if (field.type === 'color') {
+ if (field.type === 'ColorField') {
if (field.value) {
const clonedValue = cloneDeep(field.value);
@@ -30,10 +30,10 @@ export const parseFieldValue = (field: InputFieldValue) => {
/**
* Builds a graph from the node editor state.
*/
-export const buildNodesGraph = (state: RootState): Graph => {
- const { nodes, edges } = state.nodes;
+export const buildNodesGraph = (nodesState: NodesState): Graph => {
+ const { nodes, edges } = nodesState;
- const filteredNodes = nodes.filter((n) => n.type !== 'progress_image');
+ const filteredNodes = nodes.filter(isInvocationNode);
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce>(
@@ -70,8 +70,11 @@ export const buildNodesGraph = (state: RootState): Graph => {
{}
);
+ // skip out the "dummy" edges between collapsed nodes
+ const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
+
// Reduce the node editor edges into invocation graph edges
- const parsedEdges = edges.reduce>(
+ const parsedEdges = filteredEdges.reduce>(
(edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
index 3a9cf233b5..19201b23bb 100644
--- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
@@ -1,11 +1,10 @@
import { filter, reduce } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
-import { isSchemaObject } from '../types/typeGuards';
import {
InputFieldTemplate,
InvocationSchemaObject,
InvocationTemplate,
- OutputFieldTemplate,
+ isInvocationFieldSchema,
isInvocationSchemaObject,
} from '../types/types';
import {
@@ -13,12 +12,7 @@ import {
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
-const getReservedFieldNames = (type: string): string[] => {
- if (type === 'l2i') {
- return ['id', 'type', 'metadata'];
- }
- return ['id', 'type', 'is_intermediate', 'metadata'];
-};
+const RESERVED_FIELD_NAMES = ['id', 'type', 'metadata'];
const invocationDenylist = [
'Graph',
@@ -42,83 +36,41 @@ export const parseSchema = (
>((acc, schema) => {
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
- const RESERVED_FIELD_NAMES = getReservedFieldNames(type);
-
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
- const typeHints = schema.ui?.type_hints;
+ const tags = schema.ui?.tags ?? [];
+ const description = schema.description ?? '';
- const inputs: Record = {};
+ const inputs = reduce(
+ schema.properties,
+ (inputsAccumulator, property, propertyName) => {
+ if (
+ !RESERVED_FIELD_NAMES.includes(propertyName) &&
+ isInvocationFieldSchema(property) &&
+ !property.ui_hidden
+ ) {
+ const field = buildInputFieldTemplate(
+ schema,
+ property,
+ propertyName
+ );
- if (type === 'collect') {
- const itemProperty = schema.properties.item as InvocationSchemaObject;
- inputs.item = {
- type: 'item',
- name: 'item',
- description: itemProperty.description ?? '',
- title: 'Collection Item',
- inputKind: 'connection',
- inputRequirement: 'always',
- default: undefined,
- };
- } else if (type === 'iterate') {
- const itemProperty = schema.properties
- .collection as InvocationSchemaObject;
- inputs.collection = {
- type: 'array',
- name: 'collection',
- title: itemProperty.title ?? '',
- default: [],
- description: itemProperty.description ?? '',
- inputRequirement: 'always',
- inputKind: 'connection',
- };
- } else {
- reduce(
- schema.properties,
- (inputsAccumulator, property, propertyName) => {
- if (
- !RESERVED_FIELD_NAMES.includes(propertyName) &&
- isSchemaObject(property)
- ) {
- const field = buildInputFieldTemplate(
- property,
- propertyName,
- typeHints
- );
- if (field) {
- inputsAccumulator[propertyName] = field;
- }
+ if (field) {
+ inputsAccumulator[propertyName] = field;
}
- return inputsAccumulator;
- },
- inputs
- );
- }
+ }
+ return inputsAccumulator;
+ },
+ {} as Record
+ );
const rawOutput = (schema as InvocationSchemaObject).output;
- let outputs: Record;
-
- if (type === 'iterate') {
- const iterationOutput = openAPI.components?.schemas?.[
- 'IterateInvocationOutput'
- ] as OpenAPIV3.SchemaObject;
- outputs = {
- item: {
- name: 'item',
- title: iterationOutput?.title ?? '',
- description: iterationOutput?.description ?? '',
- type: 'array',
- },
- };
- } else {
- outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
- }
+ const outputs = buildOutputFieldTemplates(rawOutput, openAPI);
const invocation: InvocationTemplate = {
title,
type,
- tags: schema.ui?.tags ?? [],
- description: schema.description ?? '',
+ tags,
+ description,
inputs,
outputs,
};
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx
index 0c5b2c68d0..e8a629e2ac 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx
@@ -1,14 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
-import {
- TypesafeDraggableData,
- TypesafeDroppableData,
-} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
+import {
+ TypesafeDraggableData,
+ TypesafeDroppableData,
+} from 'features/dnd/types';
import { useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
index af810f7836..04c71be20f 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
@@ -9,8 +9,8 @@ import {
ModalHeader,
ModalOverlay,
Text,
- useDisclosure,
useColorMode,
+ useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { VALID_LOG_LEVELS } from 'app/logging/logger';
@@ -23,7 +23,6 @@ import { setShouldShowAdvancedOptions } from 'features/parameters/store/generati
import {
consoleLogLevelChanged,
setEnableImageDebugging,
- setIsNodesEnabled,
setShouldConfirmOnDelete,
shouldAntialiasProgressImageChanged,
shouldLogToConsoleChanged,
@@ -46,14 +45,14 @@ import {
import { useTranslation } from 'react-i18next';
import { LogLevelName } from 'roarr';
import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo';
+import { useFeatureStatus } from '../../hooks/useFeatureStatus';
+import { LANGUAGES } from '../../store/constants';
+import { languageSelector } from '../../store/systemSelectors';
+import { languageChanged } from '../../store/systemSlice';
import SettingSwitch from './SettingSwitch';
import SettingsClearIntermediates from './SettingsClearIntermediates';
import SettingsSchedulers from './SettingsSchedulers';
import StyledFlex from './StyledFlex';
-import { useFeatureStatus } from '../../hooks/useFeatureStatus';
-import { LANGUAGES } from '../../store/constants';
-import { languageChanged } from '../../store/systemSlice';
-import { languageSelector } from '../../store/systemSelectors';
const selector = createSelector(
[stateSelector],
@@ -64,7 +63,6 @@ const selector = createSelector(
consoleLogLevel,
shouldLogToConsole,
shouldAntialiasProgressImage,
- isNodesEnabled,
shouldUseNSFWChecker,
shouldUseWatermarker,
} = system;
@@ -87,7 +85,6 @@ const selector = createSelector(
shouldLogToConsole,
shouldAntialiasProgressImage,
shouldShowAdvancedOptions,
- isNodesEnabled,
shouldUseNSFWChecker,
shouldUseWatermarker,
};
@@ -103,7 +100,6 @@ type ConfigOptions = {
shouldShowBetaLayout: boolean;
shouldShowAdvancedOptionsSettings: boolean;
shouldShowClearIntermediates: boolean;
- shouldShowNodesToggle: boolean;
shouldShowLocalizationToggle: boolean;
};
@@ -125,7 +121,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
config?.shouldShowAdvancedOptionsSettings ?? true;
const shouldShowClearIntermediates =
config?.shouldShowClearIntermediates ?? true;
- const shouldShowNodesToggle = config?.shouldShowNodesToggle ?? true;
const shouldShowLocalizationToggle =
config?.shouldShowLocalizationToggle ?? true;
@@ -167,7 +162,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
shouldLogToConsole,
shouldAntialiasProgressImage,
shouldShowAdvancedOptions,
- isNodesEnabled,
shouldUseNSFWChecker,
shouldUseWatermarker,
} = useAppSelector(selector);
@@ -207,13 +201,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
[dispatch]
);
- const handleToggleNodes = useCallback(
- (e: ChangeEvent) => {
- dispatch(setIsNodesEnabled(e.target.checked));
- },
- [dispatch]
- );
-
const { colorMode, toggleColorMode } = useColorMode();
const isLocalizationEnabled =
@@ -320,14 +307,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
}
/>
)}
- {shouldShowNodesToggle && (
-
- )}
{shouldShowLocalizationToggle && (
) {
state.progressImage = action.payload;
},
- setIsNodesEnabled(state, action: PayloadAction) {
- state.isNodesEnabled = action.payload;
- },
shouldUseNSFWCheckerChanged(state, action: PayloadAction) {
state.shouldUseNSFWChecker = action.payload;
},
@@ -425,7 +420,6 @@ export const {
shouldAntialiasProgressImageChanged,
languageChanged,
progressImageSet,
- setIsNodesEnabled,
shouldUseNSFWCheckerChanged,
shouldUseWatermarkerChanged,
} = systemSlice.actions;
diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx
index 87ce801ead..2dd9220288 100644
--- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx
@@ -11,11 +11,10 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
-import { RootState } from 'app/store/store';
+import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
-import { configSelector } from 'features/system/store/configSelectors';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { ResourceKey } from 'i18next';
@@ -37,7 +36,6 @@ import NodesTab from './tabs/Nodes/NodesTab';
import ResizeHandle from './tabs/ResizeHandle';
import TextToImageTab from './tabs/TextToImage/TextToImageTab';
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
-import { systemSelector } from '../../system/store/systemSelectors';
export interface InvokeTabInfo {
id: InvokeTabName;
@@ -77,27 +75,13 @@ const tabs: InvokeTabInfo[] = [
icon: ,
content: ,
},
- // {
- // id: 'batch',
- // icon: ,
- // content: ,
- // },
];
const enabledTabsSelector = createSelector(
- [configSelector, systemSelector],
- (config, system) => {
+ [stateSelector],
+ ({ config }) => {
const { disabledTabs } = config;
- const { isNodesEnabled } = system;
-
- const enabledTabs = tabs.filter((tab) => {
- if (tab.id === 'nodes') {
- return isNodesEnabled && !disabledTabs.includes(tab.id);
- } else {
- return !disabledTabs.includes(tab.id);
- }
- });
-
+ const enabledTabs = tabs.filter((tab) => !disabledTabs.includes(tab.id));
return enabledTabs;
},
{
@@ -185,6 +169,7 @@ const InvokeTabs = () => {
return (
{
- return ;
+ return (
+
+
+
+ );
};
export default memo(NodesTab);
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx
index 57f2e89ef0..b8f99bfa29 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ResizeHandle.tsx
@@ -1,33 +1,47 @@
-import { Box, Flex, FlexProps, useColorMode } from '@chakra-ui/react';
+import { Box, Flex, FlexProps, useColorModeValue } from '@chakra-ui/react';
import { memo } from 'react';
import { PanelResizeHandle } from 'react-resizable-panels';
-import { mode } from 'theme/util/mode';
type ResizeHandleProps = Omit & {
direction?: 'horizontal' | 'vertical';
+ collapsedDirection?: 'top' | 'bottom' | 'left' | 'right';
};
const ResizeHandle = (props: ResizeHandleProps) => {
- const { direction = 'horizontal', ...rest } = props;
- const { colorMode } = useColorMode();
+ const { direction = 'horizontal', collapsedDirection, ...rest } = props;
+ const bg = useColorModeValue('base.100', 'base.850');
+ const hoverBg = useColorModeValue('base.300', 'base.700');
if (direction === 'horizontal') {
return (
@@ -38,19 +52,32 @@ const ResizeHandle = (props: ResizeHandleProps) => {
return (
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx
index f48daa678a..91bf4732c3 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx
@@ -2,19 +2,16 @@ import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIDropOverlay from 'common/components/IAIDropOverlay';
import IAICanvas from 'features/canvas/components/IAICanvas';
import IAICanvasResizer from 'features/canvas/components/IAICanvasResizer';
import IAICanvasToolbar from 'features/canvas/components/IAICanvasToolbar/IAICanvasToolbar';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
+import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
+import { CanvasInitialImageDropData } from 'features/dnd/types';
+import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { uiSelector } from 'features/ui/store/uiSelectors';
-
-import {
- CanvasInitialImageDropData,
- isValidDrop,
- useDroppable,
-} from 'app/components/ImageDnd/typesafeDnd';
-import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { memo, useLayoutEffect } from 'react';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
@@ -47,7 +44,7 @@ const UnifiedCanvasContent = () => {
isOver,
setNodeRef: setDroppableRef,
active,
- } = useDroppable({
+ } = useDroppableTypesafe({
id: 'unifiedCanvas',
data: droppableData,
});
diff --git a/invokeai/frontend/web/src/features/ui/hooks/useMinimumPanelSize.ts b/invokeai/frontend/web/src/features/ui/hooks/useMinimumPanelSize.ts
index b382b27329..3b2d9a068a 100644
--- a/invokeai/frontend/web/src/features/ui/hooks/useMinimumPanelSize.ts
+++ b/invokeai/frontend/web/src/features/ui/hooks/useMinimumPanelSize.ts
@@ -14,7 +14,10 @@ export const useMinimumPanelSize = (
defaultSizePct: number,
groupId: string,
orientation: 'horizontal' | 'vertical' = 'horizontal'
-): { ref: RefObject; minSizePct: number } => {
+): {
+ ref: RefObject;
+ minSizePct: number;
+} => {
const ref = useRef(null);
const [minSizePct, setMinSizePct] = useState(defaultSizePct);
@@ -31,7 +34,9 @@ export const useMinimumPanelSize = (
`[data-panel-group-id="${groupId}"]`
);
const resizeHandles = document.querySelectorAll(
- '[data-panel-resize-handle-id]'
+ orientation === 'horizontal'
+ ? '.resize-handle-horizontal'
+ : '.resize-handle-vertical'
);
if (!panelGroup) {
diff --git a/invokeai/frontend/web/src/features/ui/store/hotkeysSlice.ts b/invokeai/frontend/web/src/features/ui/store/hotkeysSlice.ts
index 2c16d6b5f4..2b59300ddd 100644
--- a/invokeai/frontend/web/src/features/ui/store/hotkeysSlice.ts
+++ b/invokeai/frontend/web/src/features/ui/store/hotkeysSlice.ts
@@ -3,10 +3,14 @@ import { createSlice } from '@reduxjs/toolkit';
type HotkeysState = {
shift: boolean;
+ ctrl: boolean;
+ meta: boolean;
};
export const initialHotkeysState: HotkeysState = {
shift: false,
+ ctrl: false,
+ meta: false,
};
export const hotkeysSlice = createSlice({
@@ -16,9 +20,16 @@ export const hotkeysSlice = createSlice({
shiftKeyPressed: (state, action: PayloadAction) => {
state.shift = action.payload;
},
+ ctrlKeyPressed: (state, action: PayloadAction) => {
+ state.ctrl = action.payload;
+ },
+ metaKeyPressed: (state, action: PayloadAction) => {
+ state.meta = action.payload;
+ },
},
});
-export const { shiftKeyPressed } = hotkeysSlice.actions;
+export const { shiftKeyPressed, ctrlKeyPressed, metaKeyPressed } =
+ hotkeysSlice.actions;
export default hotkeysSlice.reducer;
diff --git a/invokeai/frontend/web/src/services/api/constants.ts b/invokeai/frontend/web/src/services/api/constants.ts
index 8bf35d0198..9db6334965 100644
--- a/invokeai/frontend/web/src/services/api/constants.ts
+++ b/invokeai/frontend/web/src/services/api/constants.ts
@@ -13,4 +13,7 @@ export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
'sdxl',
];
+export const SDXL_MAIN_MODELS: BaseModelType[] = ['sdxl'];
+export const NON_SDXL_MAIN_MODELS: BaseModelType[] = ['sd-1', 'sd-2'];
+
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts
index 0bfa7c334f..4348a28fd0 100644
--- a/invokeai/frontend/web/src/services/api/schema.d.ts
+++ b/invokeai/frontend/web/src/services/api/schema.d.ts
@@ -312,7 +312,7 @@ export type components = {
added_image_names: (string)[];
};
/**
- * AddInvocation
+ * Add Integers
* @description Adds two numbers
*/
AddInvocation: {
@@ -332,7 +332,7 @@ export type components = {
* @default add
* @enum {string}
*/
- type?: "add";
+ type: "add";
/**
* A
* @description The first number
@@ -549,7 +549,7 @@ export type components = {
file: Blob;
};
/**
- * CannyImageProcessorInvocation
+ * Canny Processor
* @description Canny edge detection for ControlNet
*/
CannyImageProcessorInvocation: {
@@ -569,7 +569,7 @@ export type components = {
* @default canny_image_processor
* @enum {string}
*/
- type?: "canny_image_processor";
+ type: "canny_image_processor";
/**
* Image
* @description The image to process
@@ -612,7 +612,7 @@ export type components = {
loras: (components["schemas"]["LoraInfo"])[];
};
/**
- * ClipSkipInvocation
+ * CLIP Skip
* @description Skip layers in clip text_encoder model.
*/
ClipSkipInvocation: {
@@ -632,15 +632,15 @@ export type components = {
* @default clip_skip
* @enum {string}
*/
- type?: "clip_skip";
+ type: "clip_skip";
/**
- * Clip
- * @description Clip to use
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
* Skipped Layers
- * @description Number of layers to skip in text_encoder
+ * @description Number of layers to skip in text encoder
* @default 0
*/
skipped_layers?: number;
@@ -657,8 +657,8 @@ export type components = {
*/
type?: "clip_skip_output";
/**
- * Clip
- * @description Clip with skipped layers
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
@@ -683,9 +683,9 @@ export type components = {
* @default collect
* @enum {string}
*/
- type?: "collect";
+ type: "collect";
/**
- * Item
+ * Collection Item
* @description The item to collect (all inputs must be of the same type)
*/
item?: unknown;
@@ -705,7 +705,7 @@ export type components = {
* @default collect_output
* @enum {string}
*/
- type: "collect_output";
+ type?: "collect_output";
/**
* Collection
* @description The collection of input items
@@ -713,7 +713,7 @@ export type components = {
collection: (unknown)[];
};
/**
- * ColorCorrectInvocation
+ * Color Correct
* @description Shifts the colors of a target image to match the reference image, optionally
* using a mask to only color-correct certain regions of the target image.
*/
@@ -734,7 +734,7 @@ export type components = {
* @default color_correct
* @enum {string}
*/
- type?: "color_correct";
+ type: "color_correct";
/**
* Image
* @description The image to color-correct
@@ -781,7 +781,7 @@ export type components = {
a: number;
};
/**
- * CompelInvocation
+ * Compel Prompt
* @description Parse prompt using compel package to conditioning.
*/
CompelInvocation: {
@@ -801,16 +801,16 @@ export type components = {
* @default compel
* @enum {string}
*/
- type?: "compel";
+ type: "compel";
/**
* Prompt
- * @description Prompt
+ * @description Prompt to be parsed by Compel to create a conditioning tensor
* @default
*/
prompt?: string;
/**
- * Clip
- * @description Clip to use
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
@@ -827,9 +827,9 @@ export type components = {
type?: "compel_output";
/**
* Conditioning
- * @description Conditioning
+ * @description Conditioning tensor
*/
- conditioning?: components["schemas"]["ConditioningField"];
+ conditioning: components["schemas"]["ConditioningField"];
};
/** ConditioningField */
ConditioningField: {
@@ -840,7 +840,7 @@ export type components = {
conditioning_name: string;
};
/**
- * ContentShuffleImageProcessorInvocation
+ * Content Shuffle Processor
* @description Applies content shuffle processing to image
*/
ContentShuffleImageProcessorInvocation: {
@@ -860,7 +860,7 @@ export type components = {
* @default content_shuffle_image_processor
* @enum {string}
*/
- type?: "content_shuffle_image_processor";
+ type: "content_shuffle_image_processor";
/**
* Image
* @description The image to process
@@ -868,13 +868,13 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
@@ -914,19 +914,19 @@ export type components = {
* @description The weight given to the ControlNet
* @default 1
*/
- control_weight: number | (number)[];
+ control_weight?: number | (number)[];
/**
* Begin Step Percent
* @description When the ControlNet is first applied (% of total steps)
* @default 0
*/
- begin_step_percent: number;
+ begin_step_percent?: number;
/**
* End Step Percent
* @description When the ControlNet is last applied (% of total steps)
* @default 1
*/
- end_step_percent: number;
+ end_step_percent?: number;
/**
* Control Mode
* @description The control mode to use
@@ -943,7 +943,7 @@ export type components = {
resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple";
};
/**
- * ControlNetInvocation
+ * ControlNet
* @description Collects ControlNet info to pass to other nodes
*/
ControlNetInvocation: {
@@ -963,7 +963,7 @@ export type components = {
* @default controlnet
* @enum {string}
*/
- type?: "controlnet";
+ type: "controlnet";
/**
* Image
* @description The control image
@@ -971,7 +971,7 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Control Model
- * @description control model used
+ * @description ControlNet model to load
* @default lllyasviel/sd-controlnet-canny
*/
control_model?: components["schemas"]["ControlNetModelField"];
@@ -1078,9 +1078,9 @@ export type components = {
type?: "control_output";
/**
* Control
- * @description The control info
+ * @description ControlNet(s) to apply
*/
- control?: components["schemas"]["ControlField"];
+ control: components["schemas"]["ControlField"];
};
/**
* CoreMetadata
@@ -1090,7 +1090,7 @@ export type components = {
/**
* App Version
* @description The version of InvokeAI used to generate this image
- * @default 3.0.2
+ * @default 3.0.2post1
*/
app_version?: string;
/**
@@ -1225,7 +1225,7 @@ export type components = {
refiner_start?: number;
};
/**
- * CvInpaintInvocation
+ * OpenCV Inpaint
* @description Simple inpaint using opencv.
*/
CvInpaintInvocation: {
@@ -1245,7 +1245,7 @@ export type components = {
* @default cv_inpaint
* @enum {string}
*/
- type?: "cv_inpaint";
+ type: "cv_inpaint";
/**
* Image
* @description The image to inpaint
@@ -1281,7 +1281,7 @@ export type components = {
deleted_images: (string)[];
};
/**
- * DenoiseLatentsInvocation
+ * Denoise Latents
* @description Denoises noisy latents to decodable images
*/
DenoiseLatentsInvocation: {
@@ -1301,74 +1301,76 @@ export type components = {
* @default denoise_latents
* @enum {string}
*/
- type?: "denoise_latents";
- /**
- * Positive Conditioning
- * @description Positive conditioning for generation
- */
- positive_conditioning?: components["schemas"]["ConditioningField"];
- /**
- * Negative Conditioning
- * @description Negative conditioning for generation
- */
- negative_conditioning?: components["schemas"]["ConditioningField"];
+ type: "denoise_latents";
/**
* Noise
- * @description The noise to use
+ * @description Noise tensor
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
- * @description The number of steps to use to generate the image
+ * @description Number of steps to run
* @default 10
*/
steps?: number;
/**
* Cfg Scale
- * @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
+ * @description Classifier-Free Guidance scale
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Denoising Start
+ * @description When to start denoising, expressed a percentage of total steps
* @default 0
*/
denoising_start?: number;
/**
* Denoising End
+ * @description When to stop denoising, expressed a percentage of total steps
* @default 1
*/
denoising_end?: number;
/**
* Scheduler
- * @description The scheduler to use
+ * @description Scheduler to use during inference
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
- /**
- * Unet
- * @description UNet submodel
- */
- unet?: components["schemas"]["UNetField"];
/**
* Control
- * @description The control to use
+ * @description ControlNet(s) to apply
*/
control?: components["schemas"]["ControlField"] | (components["schemas"]["ControlField"])[];
/**
* Latents
- * @description The latents to use as a base image
+ * @description Latents tensor
*/
latents?: components["schemas"]["LatentsField"];
/**
* Mask
- * @description Mask
+ * @description The mask to use for the operation
*/
mask?: components["schemas"]["ImageField"];
+ /**
+ * Positive Conditioning
+ * @description Positive conditioning tensor
+ */
+ positive_conditioning?: components["schemas"]["ConditioningField"];
+ /**
+ * Negative Conditioning
+ * @description Negative conditioning tensor
+ */
+ negative_conditioning?: components["schemas"]["ConditioningField"];
+ /**
+ * Unet
+ * @description UNet (scheduler, LoRAs)
+ */
+ unet?: components["schemas"]["UNetField"];
};
/**
- * DivideInvocation
+ * Divide Integers
* @description Divides two numbers
*/
DivideInvocation: {
@@ -1388,7 +1390,7 @@ export type components = {
* @default div
* @enum {string}
*/
- type?: "div";
+ type: "div";
/**
* A
* @description The first number
@@ -1403,7 +1405,7 @@ export type components = {
b?: number;
};
/**
- * DynamicPromptInvocation
+ * Dynamic Prompt
* @description Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
DynamicPromptInvocation: {
@@ -1423,12 +1425,12 @@ export type components = {
* @default dynamic_prompt
* @enum {string}
*/
- type?: "dynamic_prompt";
+ type: "dynamic_prompt";
/**
* Prompt
* @description The prompt to parse with dynamicprompts
*/
- prompt: string;
+ prompt?: string;
/**
* Max Prompts
* @description The number of prompts to generate
@@ -1443,7 +1445,7 @@ export type components = {
combinatorial?: boolean;
};
/**
- * ESRGANInvocation
+ * Upscale (RealESRGAN)
* @description Upscales an image using RealESRGAN.
*/
ESRGANInvocation: {
@@ -1463,7 +1465,7 @@ export type components = {
* @default esrgan
* @enum {string}
*/
- type?: "esrgan";
+ type: "esrgan";
/**
* Image
* @description The input image
@@ -1510,10 +1512,10 @@ export type components = {
FloatCollectionOutput: {
/**
* Type
- * @default float_collection
+ * @default float_collection_output
* @enum {string}
*/
- type?: "float_collection";
+ type?: "float_collection_output";
/**
* Collection
* @description The float collection
@@ -1522,7 +1524,7 @@ export type components = {
collection?: (number)[];
};
/**
- * FloatLinearRangeInvocation
+ * Float Range
* @description Creates a range
*/
FloatLinearRangeInvocation: {
@@ -1542,7 +1544,7 @@ export type components = {
* @default float_range
* @enum {string}
*/
- type?: "float_range";
+ type: "float_range";
/**
* Start
* @description The first value of the range
@@ -1574,10 +1576,10 @@ export type components = {
*/
type?: "float_output";
/**
- * Param
+ * A
* @description The output float
*/
- param?: number;
+ a?: number;
};
/** Graph */
Graph: {
@@ -1591,7 +1593,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
- [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"]) | undefined;
+ [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"]) | undefined;
};
/**
* Edges
@@ -1634,7 +1636,7 @@ export type components = {
* @description The results of node executions
*/
results: {
- [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
+ [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@@ -1679,7 +1681,7 @@ export type components = {
* @default graph
* @enum {string}
*/
- type?: "graph";
+ type: "graph";
/**
* Graph
* @description The graph to run
@@ -1704,7 +1706,7 @@ export type components = {
detail?: (components["schemas"]["ValidationError"])[];
};
/**
- * HedImageProcessorInvocation
+ * HED (softedge) Processor
* @description Applies HED edge detection to image
*/
HedImageProcessorInvocation: {
@@ -1724,7 +1726,7 @@ export type components = {
* @default hed_image_processor
* @enum {string}
*/
- type?: "hed_image_processor";
+ type: "hed_image_processor";
/**
* Image
* @description The image to process
@@ -1732,25 +1734,25 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
/**
* Scribble
- * @description Whether to use scribble mode
+ * @description Whether or not to use scribble mode
* @default false
*/
scribble?: boolean;
};
/**
- * ImageBlurInvocation
+ * Blur Image
* @description Blurs an image
*/
ImageBlurInvocation: {
@@ -1770,7 +1772,7 @@ export type components = {
* @default img_blur
* @enum {string}
*/
- type?: "img_blur";
+ type: "img_blur";
/**
* Image
* @description The image to blur
@@ -1803,7 +1805,7 @@ export type components = {
*/
ImageCategory: "general" | "mask" | "control" | "user" | "other";
/**
- * ImageChannelInvocation
+ * Extract Image Channel
* @description Gets a channel from an image.
*/
ImageChannelInvocation: {
@@ -1823,7 +1825,7 @@ export type components = {
* @default img_chan
* @enum {string}
*/
- type?: "img_chan";
+ type: "img_chan";
/**
* Image
* @description The image to get the channel from
@@ -1838,7 +1840,7 @@ export type components = {
channel?: "A" | "R" | "G" | "B";
};
/**
- * ImageCollectionInvocation
+ * Image Collection
* @description Load a collection of images and provide it as output.
*/
ImageCollectionInvocation: {
@@ -1858,7 +1860,7 @@ export type components = {
* @default image_collection
* @enum {string}
*/
- type?: "image_collection";
+ type: "image_collection";
/**
* Images
* @description The image collection to load
@@ -1873,19 +1875,19 @@ export type components = {
ImageCollectionOutput: {
/**
* Type
- * @default image_collection
+ * @default image_collection_output
* @enum {string}
*/
- type: "image_collection";
+ type?: "image_collection_output";
/**
* Collection
* @description The output images
* @default []
*/
- collection: (components["schemas"]["ImageField"])[];
+ collection?: (components["schemas"]["ImageField"])[];
};
/**
- * ImageConvertInvocation
+ * Convert Image Mode
* @description Converts an image to a different mode.
*/
ImageConvertInvocation: {
@@ -1905,7 +1907,7 @@ export type components = {
* @default img_conv
* @enum {string}
*/
- type?: "img_conv";
+ type: "img_conv";
/**
* Image
* @description The image to convert
@@ -1920,7 +1922,7 @@ export type components = {
mode?: "L" | "RGB" | "RGBA" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F";
};
/**
- * ImageCropInvocation
+ * Crop Image
* @description Crops an image to a specified box. The box can be outside of the image.
*/
ImageCropInvocation: {
@@ -1940,7 +1942,7 @@ export type components = {
* @default img_crop
* @enum {string}
*/
- type?: "img_crop";
+ type: "img_crop";
/**
* Image
* @description The image to crop
@@ -2053,7 +2055,7 @@ export type components = {
image_name: string;
};
/**
- * ImageHueAdjustmentInvocation
+ * Image Hue Adjustment
* @description Adjusts the Hue of an image.
*/
ImageHueAdjustmentInvocation: {
@@ -2073,7 +2075,7 @@ export type components = {
* @default img_hue_adjust
* @enum {string}
*/
- type?: "img_hue_adjust";
+ type: "img_hue_adjust";
/**
* Image
* @description The image to adjust
@@ -2087,7 +2089,7 @@ export type components = {
hue?: number;
};
/**
- * ImageInverseLerpInvocation
+ * Inverse Lerp Image
* @description Inverse linear interpolation of all pixels of an image
*/
ImageInverseLerpInvocation: {
@@ -2107,7 +2109,7 @@ export type components = {
* @default img_ilerp
* @enum {string}
*/
- type?: "img_ilerp";
+ type: "img_ilerp";
/**
* Image
* @description The image to lerp
@@ -2127,7 +2129,7 @@ export type components = {
max?: number;
};
/**
- * ImageLerpInvocation
+ * Lerp Image
* @description Linear interpolation of all pixels of an image
*/
ImageLerpInvocation: {
@@ -2147,7 +2149,7 @@ export type components = {
* @default img_lerp
* @enum {string}
*/
- type?: "img_lerp";
+ type: "img_lerp";
/**
* Image
* @description The image to lerp
@@ -2167,7 +2169,7 @@ export type components = {
max?: number;
};
/**
- * ImageLuminosityAdjustmentInvocation
+ * Image Luminosity Adjustment
* @description Adjusts the Luminosity (Value) of an image.
*/
ImageLuminosityAdjustmentInvocation: {
@@ -2187,7 +2189,7 @@ export type components = {
* @default img_luminosity_adjust
* @enum {string}
*/
- type?: "img_luminosity_adjust";
+ type: "img_luminosity_adjust";
/**
* Image
* @description The image to adjust
@@ -2217,7 +2219,7 @@ export type components = {
graph?: Record;
};
/**
- * ImageMultiplyInvocation
+ * Multiply Images
* @description Multiplies two images together using `PIL.ImageChops.multiply()`.
*/
ImageMultiplyInvocation: {
@@ -2237,7 +2239,7 @@ export type components = {
* @default img_mul
* @enum {string}
*/
- type?: "img_mul";
+ type: "img_mul";
/**
* Image1
* @description The first image to multiply
@@ -2250,7 +2252,7 @@ export type components = {
image2?: components["schemas"]["ImageField"];
};
/**
- * ImageNSFWBlurInvocation
+ * Blur NSFW Image
* @description Add blur to NSFW-flagged images
*/
ImageNSFWBlurInvocation: {
@@ -2270,17 +2272,17 @@ export type components = {
* @default img_nsfw
* @enum {string}
*/
- type?: "img_nsfw";
+ type: "img_nsfw";
+ /**
+ * Metadata
+ * @description Optional core metadata to be written to image
+ */
+ metadata?: components["schemas"]["CoreMetadata"];
/**
* Image
* @description The image to check
*/
image?: components["schemas"]["ImageField"];
- /**
- * Metadata
- * @description Optional core metadata to be written to the image
- */
- metadata?: components["schemas"]["CoreMetadata"];
};
/**
* ImageOutput
@@ -2310,7 +2312,7 @@ export type components = {
height: number;
};
/**
- * ImagePasteInvocation
+ * Paste Image
* @description Pastes an image into another image.
*/
ImagePasteInvocation: {
@@ -2330,7 +2332,7 @@ export type components = {
* @default img_paste
* @enum {string}
*/
- type?: "img_paste";
+ type: "img_paste";
/**
* Base Image
* @description The base image
@@ -2380,7 +2382,7 @@ export type components = {
* @default image_processor
* @enum {string}
*/
- type?: "image_processor";
+ type: "image_processor";
/**
* Image
* @description The image to process
@@ -2411,7 +2413,7 @@ export type components = {
is_intermediate?: boolean;
};
/**
- * ImageResizeInvocation
+ * Resize Image
* @description Resizes an image to specific dimensions
*/
ImageResizeInvocation: {
@@ -2431,7 +2433,7 @@ export type components = {
* @default img_resize
* @enum {string}
*/
- type?: "img_resize";
+ type: "img_resize";
/**
* Image
* @description The image to resize
@@ -2439,12 +2441,14 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Width
- * @description The width to resize to (px)
+ * @description The width to resize to (px)
+ * @default 512
*/
width?: number;
/**
* Height
- * @description The height to resize to (px)
+ * @description The height to resize to (px)
+ * @default 512
*/
height?: number;
/**
@@ -2456,7 +2460,7 @@ export type components = {
resample_mode?: "nearest" | "box" | "bilinear" | "hamming" | "bicubic" | "lanczos";
};
/**
- * ImageSaturationAdjustmentInvocation
+ * Image Saturation Adjustment
* @description Adjusts the Saturation of an image.
*/
ImageSaturationAdjustmentInvocation: {
@@ -2476,7 +2480,7 @@ export type components = {
* @default img_saturation_adjust
* @enum {string}
*/
- type?: "img_saturation_adjust";
+ type: "img_saturation_adjust";
/**
* Image
* @description The image to adjust
@@ -2490,7 +2494,7 @@ export type components = {
saturation?: number;
};
/**
- * ImageScaleInvocation
+ * Scale Image
* @description Scales an image by a factor
*/
ImageScaleInvocation: {
@@ -2510,7 +2514,7 @@ export type components = {
* @default img_scale
* @enum {string}
*/
- type?: "img_scale";
+ type: "img_scale";
/**
* Image
* @description The image to scale
@@ -2531,7 +2535,7 @@ export type components = {
resample_mode?: "nearest" | "box" | "bilinear" | "hamming" | "bicubic" | "lanczos";
};
/**
- * ImageToLatentsInvocation
+ * Image to Latents
* @description Encodes an image into latents.
*/
ImageToLatentsInvocation: {
@@ -2551,7 +2555,7 @@ export type components = {
* @default i2l
* @enum {string}
*/
- type?: "i2l";
+ type: "i2l";
/**
* Image
* @description The image to encode
@@ -2559,18 +2563,18 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Vae
- * @description Vae submodel
+ * @description VAE
*/
vae?: components["schemas"]["VaeField"];
/**
* Tiled
- * @description Encode latents by overlaping tiles(less memory consumption)
+ * @description Processing using overlapping tiles (reduce memory consumption)
* @default false
*/
tiled?: boolean;
/**
* Fp32
- * @description Decode in full precision
+ * @description Whether or not to use full float32 precision
* @default false
*/
fp32?: boolean;
@@ -2597,7 +2601,7 @@ export type components = {
thumbnail_url: string;
};
/**
- * ImageWatermarkInvocation
+ * Add Invisible Watermark
* @description Add an invisible watermark to an image
*/
ImageWatermarkInvocation: {
@@ -2617,7 +2621,7 @@ export type components = {
* @default img_watermark
* @enum {string}
*/
- type?: "img_watermark";
+ type: "img_watermark";
/**
* Image
* @description The image to check
@@ -2631,12 +2635,12 @@ export type components = {
text?: string;
/**
* Metadata
- * @description Optional core metadata to be written to the image
+ * @description Optional core metadata to be written to image
*/
metadata?: components["schemas"]["CoreMetadata"];
};
/**
- * InfillColorInvocation
+ * Solid Color Infill
* @description Infills transparent areas of an image with a solid color
*/
InfillColorInvocation: {
@@ -2656,7 +2660,7 @@ export type components = {
* @default infill_rgba
* @enum {string}
*/
- type?: "infill_rgba";
+ type: "infill_rgba";
/**
* Image
* @description The image to infill
@@ -2675,7 +2679,7 @@ export type components = {
color?: components["schemas"]["ColorField"];
};
/**
- * InfillPatchMatchInvocation
+ * PatchMatch Infill
* @description Infills transparent areas of an image using the PatchMatch algorithm
*/
InfillPatchMatchInvocation: {
@@ -2695,7 +2699,7 @@ export type components = {
* @default infill_patchmatch
* @enum {string}
*/
- type?: "infill_patchmatch";
+ type: "infill_patchmatch";
/**
* Image
* @description The image to infill
@@ -2703,7 +2707,7 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
- * InfillTileInvocation
+ * Tile Infill
* @description Infills transparent areas of an image with tiles of the image
*/
InfillTileInvocation: {
@@ -2723,7 +2727,7 @@ export type components = {
* @default infill_tile
* @enum {string}
*/
- type?: "infill_tile";
+ type: "infill_tile";
/**
* Image
* @description The image to infill
@@ -2748,10 +2752,10 @@ export type components = {
IntCollectionOutput: {
/**
* Type
- * @default int_collection
+ * @default int_collection_output
* @enum {string}
*/
- type?: "int_collection";
+ type?: "int_collection_output";
/**
* Collection
* @description The int collection
@@ -2797,7 +2801,7 @@ export type components = {
* @default iterate
* @enum {string}
*/
- type?: "iterate";
+ type: "iterate";
/**
* Collection
* @description The list of items to iterate over
@@ -2820,12 +2824,12 @@ export type components = {
* @default iterate_output
* @enum {string}
*/
- type: "iterate_output";
+ type?: "iterate_output";
/**
- * Item
+ * Collection Item
* @description The item being iterated over
*/
- item: unknown;
+ item?: unknown;
};
/**
* LatentsField
@@ -2856,22 +2860,22 @@ export type components = {
type?: "latents_output";
/**
* Latents
- * @description The output latents
+ * @description Latents tensor
*/
- latents?: components["schemas"]["LatentsField"];
+ latents: components["schemas"]["LatentsField"];
/**
* Width
- * @description The width of the latents in pixels
+ * @description Width of output (px)
*/
width: number;
/**
* Height
- * @description The height of the latents in pixels
+ * @description Height of output (px)
*/
height: number;
};
/**
- * LatentsToImageInvocation
+ * Latents to Image
* @description Generates an image from latents.
*/
LatentsToImageInvocation: {
@@ -2891,37 +2895,37 @@ export type components = {
* @default l2i
* @enum {string}
*/
- type?: "l2i";
- /**
- * Latents
- * @description The latents to generate an image from
- */
- latents?: components["schemas"]["LatentsField"];
- /**
- * Vae
- * @description Vae submodel
- */
- vae?: components["schemas"]["VaeField"];
+ type: "l2i";
/**
* Tiled
- * @description Decode latents by overlaping tiles (less memory consumption)
+ * @description Processing using overlapping tiles (reduce memory consumption)
* @default false
*/
tiled?: boolean;
/**
* Fp32
- * @description Decode in full precision
+ * @description Whether or not to use full float32 precision
* @default false
*/
fp32?: boolean;
/**
* Metadata
- * @description Optional core metadata to be written to the image
+ * @description Optional core metadata to be written to image
*/
metadata?: components["schemas"]["CoreMetadata"];
+ /**
+ * Latents
+ * @description Latents tensor
+ */
+ latents?: components["schemas"]["LatentsField"];
+ /**
+ * Vae
+ * @description VAE
+ */
+ vae?: components["schemas"]["VaeField"];
};
/**
- * LeresImageProcessorInvocation
+ * Leres (Depth) Processor
* @description Applies leres processing to image
*/
LeresImageProcessorInvocation: {
@@ -2941,7 +2945,7 @@ export type components = {
* @default leres_image_processor
* @enum {string}
*/
- type?: "leres_image_processor";
+ type: "leres_image_processor";
/**
* Image
* @description The image to process
@@ -2967,19 +2971,19 @@ export type components = {
boost?: boolean;
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
};
/**
- * LineartAnimeImageProcessorInvocation
+ * Lineart Anime Processor
* @description Applies line art anime processing to image
*/
LineartAnimeImageProcessorInvocation: {
@@ -2999,7 +3003,7 @@ export type components = {
* @default lineart_anime_image_processor
* @enum {string}
*/
- type?: "lineart_anime_image_processor";
+ type: "lineart_anime_image_processor";
/**
* Image
* @description The image to process
@@ -3007,19 +3011,19 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
};
/**
- * LineartImageProcessorInvocation
+ * Lineart Processor
* @description Applies line art processing to image
*/
LineartImageProcessorInvocation: {
@@ -3039,7 +3043,7 @@ export type components = {
* @default lineart_image_processor
* @enum {string}
*/
- type?: "lineart_image_processor";
+ type: "lineart_image_processor";
/**
* Image
* @description The image to process
@@ -3047,13 +3051,13 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
@@ -3117,7 +3121,7 @@ export type components = {
*/
LoRAModelFormat: "lycoris" | "diffusers";
/**
- * LoadImageInvocation
+ * Load Image
* @description Load an image and provide it as output.
*/
LoadImageInvocation: {
@@ -3137,7 +3141,7 @@ export type components = {
* @default load_image
* @enum {string}
*/
- type?: "load_image";
+ type: "load_image";
/**
* Image
* @description The image to load
@@ -3170,7 +3174,7 @@ export type components = {
weight: number;
};
/**
- * LoraLoaderInvocation
+ * LoRA Loader
* @description Apply selected lora to unet and text_encoder.
*/
LoraLoaderInvocation: {
@@ -3190,26 +3194,26 @@ export type components = {
* @default lora_loader
* @enum {string}
*/
- type?: "lora_loader";
+ type: "lora_loader";
/**
- * Lora
- * @description Lora model name
+ * LoRA
+ * @description LoRA model to load
*/
- lora?: components["schemas"]["LoRAModelField"];
+ lora: components["schemas"]["LoRAModelField"];
/**
* Weight
- * @description With what weight to apply lora
+ * @description The weight at which the LoRA is applied to each model
* @default 0.75
*/
weight?: number;
/**
- * Unet
- * @description UNet model for applying lora
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Clip model for applying lora
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
@@ -3225,13 +3229,13 @@ export type components = {
*/
type?: "lora_loader_output";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Tokenizer and text_encoder submodels
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
@@ -3251,7 +3255,7 @@ export type components = {
model_type: components["schemas"]["ModelType"];
};
/**
- * MainModelLoaderInvocation
+ * Main Model Loader
* @description Loads a main model, outputting its submodels.
*/
MainModelLoaderInvocation: {
@@ -3271,15 +3275,15 @@ export type components = {
* @default main_model_loader
* @enum {string}
*/
- type?: "main_model_loader";
+ type: "main_model_loader";
/**
* Model
- * @description The model to load
+ * @description Main model (UNet, VAE, CLIP) to load
*/
model: components["schemas"]["MainModelField"];
};
/**
- * MaskCombineInvocation
+ * Combine Mask
* @description Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.
*/
MaskCombineInvocation: {
@@ -3299,7 +3303,7 @@ export type components = {
* @default mask_combine
* @enum {string}
*/
- type?: "mask_combine";
+ type: "mask_combine";
/**
* Mask1
* @description The first mask to combine
@@ -3312,7 +3316,7 @@ export type components = {
mask2?: components["schemas"]["ImageField"];
};
/**
- * MaskEdgeInvocation
+ * Mask Edge
* @description Applies an edge mask to an image
*/
MaskEdgeInvocation: {
@@ -3332,7 +3336,7 @@ export type components = {
* @default mask_edge
* @enum {string}
*/
- type?: "mask_edge";
+ type: "mask_edge";
/**
* Image
* @description The image to apply the mask to
@@ -3342,25 +3346,25 @@ export type components = {
* Edge Size
* @description The size of the edge
*/
- edge_size: number;
+ edge_size?: number;
/**
* Edge Blur
* @description The amount of blur on the edge
*/
- edge_blur: number;
+ edge_blur?: number;
/**
* Low Threshold
* @description First threshold for the hysteresis procedure in Canny edge detection
*/
- low_threshold: number;
+ low_threshold?: number;
/**
* High Threshold
* @description Second threshold for the hysteresis procedure in Canny edge detection
*/
- high_threshold: number;
+ high_threshold?: number;
};
/**
- * MaskFromAlphaInvocation
+ * Mask from Alpha
* @description Extracts the alpha channel of an image as a mask.
*/
MaskFromAlphaInvocation: {
@@ -3380,7 +3384,7 @@ export type components = {
* @default tomask
* @enum {string}
*/
- type?: "tomask";
+ type: "tomask";
/**
* Image
* @description The image to create the mask from
@@ -3421,7 +3425,7 @@ export type components = {
height?: number;
};
/**
- * MediapipeFaceProcessorInvocation
+ * Mediapipe Face Processor
* @description Applies mediapipe face processing to image
*/
MediapipeFaceProcessorInvocation: {
@@ -3441,7 +3445,7 @@ export type components = {
* @default mediapipe_face_processor
* @enum {string}
*/
- type?: "mediapipe_face_processor";
+ type: "mediapipe_face_processor";
/**
* Image
* @description The image to process
@@ -3467,7 +3471,7 @@ export type components = {
*/
MergeInterpolationMethod: "weighted_sum" | "sigmoid" | "inv_sigmoid" | "add_difference";
/**
- * MetadataAccumulatorInvocation
+ * Metadata Accumulator
* @description Outputs a Core Metadata Object
*/
MetadataAccumulatorInvocation: {
@@ -3487,77 +3491,77 @@ export type components = {
* @default metadata_accumulator
* @enum {string}
*/
- type?: "metadata_accumulator";
+ type: "metadata_accumulator";
/**
* Generation Mode
* @description The generation mode that output this image
*/
- generation_mode: string;
+ generation_mode?: string;
/**
* Positive Prompt
* @description The positive prompt parameter
*/
- positive_prompt: string;
+ positive_prompt?: string;
/**
* Negative Prompt
* @description The negative prompt parameter
*/
- negative_prompt: string;
+ negative_prompt?: string;
/**
* Width
* @description The width parameter
*/
- width: number;
+ width?: number;
/**
* Height
* @description The height parameter
*/
- height: number;
+ height?: number;
/**
* Seed
* @description The seed used for noise generation
*/
- seed: number;
+ seed?: number;
/**
* Rand Device
* @description The device used for random number generation
*/
- rand_device: string;
+ rand_device?: string;
/**
* Cfg Scale
* @description The classifier-free guidance scale parameter
*/
- cfg_scale: number;
+ cfg_scale?: number;
/**
* Steps
* @description The number of steps used for inference
*/
- steps: number;
+ steps?: number;
/**
* Scheduler
* @description The scheduler used for inference
*/
- scheduler: string;
+ scheduler?: string;
/**
* Clip Skip
* @description The number of skipped CLIP layers
*/
- clip_skip: number;
+ clip_skip?: number;
/**
* Model
* @description The main model used for inference
*/
- model: components["schemas"]["MainModelField"];
+ model?: components["schemas"]["MainModelField"];
/**
* Controlnets
* @description The ControlNets used for inference
*/
- controlnets: (components["schemas"]["ControlField"])[];
+ controlnets?: (components["schemas"]["ControlField"])[];
/**
* Loras
* @description The LoRAs used for inference
*/
- loras: (components["schemas"]["LoRAMetadataField"])[];
+ loras?: (components["schemas"]["LoRAMetadataField"])[];
/**
* Strength
* @description The strength used for latents-to-latents
@@ -3604,15 +3608,15 @@ export type components = {
*/
refiner_scheduler?: string;
/**
- * Refiner Positive Aesthetic Score
+ * Refiner Positive Aesthetic Store
* @description The aesthetic score used for the refiner
*/
- refiner_positive_aesthetic_score?: number;
+ refiner_positive_aesthetic_store?: number;
/**
- * Refiner Negative Aesthetic Score
+ * Refiner Negative Aesthetic Store
* @description The aesthetic score used for the refiner
*/
- refiner_negative_aesthetic_score?: number;
+ refiner_negative_aesthetic_store?: number;
/**
* Refiner Start
* @description The start value used for refiner denoising
@@ -3637,7 +3641,7 @@ export type components = {
metadata: components["schemas"]["CoreMetadata"];
};
/**
- * MidasDepthImageProcessorInvocation
+ * Midas (Depth) Processor
* @description Applies Midas depth processing to image
*/
MidasDepthImageProcessorInvocation: {
@@ -3657,7 +3661,7 @@ export type components = {
* @default midas_depth_image_processor
* @enum {string}
*/
- type?: "midas_depth_image_processor";
+ type: "midas_depth_image_processor";
/**
* Image
* @description The image to process
@@ -3677,7 +3681,7 @@ export type components = {
bg_th?: number;
};
/**
- * MlsdImageProcessorInvocation
+ * MLSD Processor
* @description Applies MLSD processing to image
*/
MlsdImageProcessorInvocation: {
@@ -3697,7 +3701,7 @@ export type components = {
* @default mlsd_image_processor
* @enum {string}
*/
- type?: "mlsd_image_processor";
+ type: "mlsd_image_processor";
/**
* Image
* @description The image to process
@@ -3705,13 +3709,13 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
@@ -3760,20 +3764,20 @@ export type components = {
*/
type?: "model_loader_output";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
- unet?: components["schemas"]["UNetField"];
+ unet: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Tokenizer and text_encoder submodels
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
- clip?: components["schemas"]["ClipField"];
+ clip: components["schemas"]["ClipField"];
/**
- * Vae
- * @description Vae submodel
+ * VAE
+ * @description VAE
*/
- vae?: components["schemas"]["VaeField"];
+ vae: components["schemas"]["VaeField"];
};
/**
* ModelType
@@ -3793,7 +3797,7 @@ export type components = {
models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
};
/**
- * MultiplyInvocation
+ * Multiply Integers
* @description Multiplies two numbers
*/
MultiplyInvocation: {
@@ -3813,7 +3817,7 @@ export type components = {
* @default mul
* @enum {string}
*/
- type?: "mul";
+ type: "mul";
/**
* A
* @description The first number
@@ -3828,7 +3832,7 @@ export type components = {
b?: number;
};
/**
- * NoiseInvocation
+ * Noise
* @description Generates latent noise.
*/
NoiseInvocation: {
@@ -3848,21 +3852,21 @@ export type components = {
* @default noise
* @enum {string}
*/
- type?: "noise";
+ type: "noise";
/**
* Seed
- * @description The seed to use
+ * @description Seed for random number generation
*/
seed?: number;
/**
* Width
- * @description The width of the resulting noise
+ * @description Width of output (px)
* @default 512
*/
width?: number;
/**
* Height
- * @description The height of the resulting noise
+ * @description Height of output (px)
* @default 512
*/
height?: number;
@@ -3886,22 +3890,22 @@ export type components = {
type?: "noise_output";
/**
* Noise
- * @description The output noise
+ * @description Noise tensor
*/
noise?: components["schemas"]["LatentsField"];
/**
* Width
- * @description The width of the noise in pixels
+ * @description Width of output (px)
*/
width: number;
/**
* Height
- * @description The height of the noise in pixels
+ * @description Height of output (px)
*/
height: number;
};
/**
- * NormalbaeImageProcessorInvocation
+ * Normal BAE Processor
* @description Applies NormalBae processing to image
*/
NormalbaeImageProcessorInvocation: {
@@ -3921,7 +3925,7 @@ export type components = {
* @default normalbae_image_processor
* @enum {string}
*/
- type?: "normalbae_image_processor";
+ type: "normalbae_image_processor";
/**
* Image
* @description The image to process
@@ -3929,19 +3933,19 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
};
/**
- * ONNXLatentsToImageInvocation
+ * ONNX Latents to Image
* @description Generates an image from latents.
*/
ONNXLatentsToImageInvocation: {
@@ -3961,20 +3965,20 @@ export type components = {
* @default l2i_onnx
* @enum {string}
*/
- type?: "l2i_onnx";
+ type: "l2i_onnx";
/**
* Latents
- * @description The latents to generate an image from
+ * @description Denoised latents tensor
*/
latents?: components["schemas"]["LatentsField"];
/**
* Vae
- * @description Vae submodel
+ * @description VAE
*/
vae?: components["schemas"]["VaeField"];
/**
* Metadata
- * @description Optional core metadata to be written to the image
+ * @description Optional core metadata to be written to image
*/
metadata?: components["schemas"]["CoreMetadata"];
};
@@ -3990,28 +3994,28 @@ export type components = {
*/
type?: "model_loader_output_onnx";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Tokenizer and text_encoder submodels
+ * CLIP
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
- * Vae Decoder
- * @description Vae submodel
+ * VAE Decoder
+ * @description VAE
*/
vae_decoder?: components["schemas"]["VaeField"];
/**
- * Vae Encoder
- * @description Vae submodel
+ * VAE Encoder
+ * @description VAE
*/
vae_encoder?: components["schemas"]["VaeField"];
};
/**
- * ONNXPromptInvocation
+ * ONNX Prompt (Raw)
* @description A node to process inputs and produce outputs.
* May use dependency injection in __init__ to receive providers.
*/
@@ -4032,48 +4036,19 @@ export type components = {
* @default prompt_onnx
* @enum {string}
*/
- type?: "prompt_onnx";
+ type: "prompt_onnx";
/**
* Prompt
- * @description Prompt
+ * @description Raw prompt text (no parsing)
* @default
*/
prompt?: string;
/**
* Clip
- * @description Clip to use
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
- /**
- * ONNXSD1ModelLoaderInvocation
- * @description Loading submodels of selected model.
- */
- ONNXSD1ModelLoaderInvocation: {
- /**
- * Id
- * @description The id of this node. Must be unique among all nodes.
- */
- id: string;
- /**
- * Is Intermediate
- * @description Whether or not this node is an intermediate node.
- * @default false
- */
- is_intermediate?: boolean;
- /**
- * Type
- * @default sd1_model_loader_onnx
- * @enum {string}
- */
- type?: "sd1_model_loader_onnx";
- /**
- * Model Name
- * @description Model to load
- * @default
- */
- model_name?: string;
- };
/** ONNXStableDiffusion1ModelConfig */
ONNXStableDiffusion1ModelConfig: {
/** Model Name */
@@ -4122,7 +4097,7 @@ export type components = {
upcast_attention: boolean;
};
/**
- * ONNXTextToLatentsInvocation
+ * ONNX Text to Latents
* @description Generates latents from conditionings.
*/
ONNXTextToLatentsInvocation: {
@@ -4142,56 +4117,56 @@ export type components = {
* @default t2l_onnx
* @enum {string}
*/
- type?: "t2l_onnx";
+ type: "t2l_onnx";
/**
* Positive Conditioning
- * @description Positive conditioning for generation
+ * @description Positive conditioning tensor
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
- * @description Negative conditioning for generation
+ * @description Negative conditioning tensor
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
- * @description The noise to use
+ * @description Noise tensor
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
- * @description The number of steps to use to generate the image
+ * @description Number of steps to run
* @default 10
*/
steps?: number;
/**
* Cfg Scale
- * @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
+ * @description Classifier-Free Guidance scale
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
- * @description The scheduler to use
+ * @description Scheduler to use during inference
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Precision
- * @description The precision to use when generating latents
+ * @description Precision to use
* @default tensor(float16)
* @enum {string}
*/
precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)";
/**
* Unet
- * @description UNet submodel
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
* Control
- * @description The control to use
+ * @description ControlNet(s) to apply
*/
control?: components["schemas"]["ControlField"] | (components["schemas"]["ControlField"])[];
};
@@ -4263,7 +4238,7 @@ export type components = {
model_type: components["schemas"]["ModelType"];
};
/**
- * OnnxModelLoaderInvocation
+ * ONNX Model Loader
* @description Loads a main model, outputting its submodels.
*/
OnnxModelLoaderInvocation: {
@@ -4283,15 +4258,15 @@ export type components = {
* @default onnx_model_loader
* @enum {string}
*/
- type?: "onnx_model_loader";
+ type: "onnx_model_loader";
/**
* Model
- * @description The model to load
+ * @description ONNX Main model (UNet, VAE, CLIP) to load
*/
model: components["schemas"]["OnnxModelField"];
};
/**
- * OpenposeImageProcessorInvocation
+ * Openpose Processor
* @description Applies Openpose processing to image
*/
OpenposeImageProcessorInvocation: {
@@ -4311,7 +4286,7 @@ export type components = {
* @default openpose_image_processor
* @enum {string}
*/
- type?: "openpose_image_processor";
+ type: "openpose_image_processor";
/**
* Image
* @description The image to process
@@ -4325,13 +4300,13 @@ export type components = {
hand_and_face?: boolean;
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
@@ -4368,7 +4343,7 @@ export type components = {
total: number;
};
/**
- * ParamFloatInvocation
+ * Float Parameter
* @description A float parameter
*/
ParamFloatInvocation: {
@@ -4388,7 +4363,7 @@ export type components = {
* @default param_float
* @enum {string}
*/
- type?: "param_float";
+ type: "param_float";
/**
* Param
* @description The float value
@@ -4397,7 +4372,7 @@ export type components = {
param?: number;
};
/**
- * ParamIntInvocation
+ * Integer Parameter
* @description An integer parameter
*/
ParamIntInvocation: {
@@ -4417,7 +4392,7 @@ export type components = {
* @default param_int
* @enum {string}
*/
- type?: "param_int";
+ type: "param_int";
/**
* A
* @description The integer value
@@ -4426,7 +4401,7 @@ export type components = {
a?: number;
};
/**
- * ParamPromptInvocation
+ * Prompt Parameter
* @description A prompt input parameter
*/
ParamPromptInvocation: {
@@ -4446,7 +4421,7 @@ export type components = {
* @default param_prompt
* @enum {string}
*/
- type?: "param_prompt";
+ type: "param_prompt";
/**
* Prompt
* @description The prompt value
@@ -4455,7 +4430,7 @@ export type components = {
prompt?: string;
};
/**
- * ParamStringInvocation
+ * String Parameter
* @description A string parameter
*/
ParamStringInvocation: {
@@ -4475,7 +4450,7 @@ export type components = {
* @default param_string
* @enum {string}
*/
- type?: "param_string";
+ type: "param_string";
/**
* Text
* @description The string value
@@ -4484,7 +4459,7 @@ export type components = {
text?: string;
};
/**
- * PidiImageProcessorInvocation
+ * PIDI Processor
* @description Applies PIDI processing to image
*/
PidiImageProcessorInvocation: {
@@ -4504,7 +4479,7 @@ export type components = {
* @default pidi_image_processor
* @enum {string}
*/
- type?: "pidi_image_processor";
+ type: "pidi_image_processor";
/**
* Image
* @description The image to process
@@ -4512,25 +4487,25 @@ export type components = {
image?: components["schemas"]["ImageField"];
/**
* Detect Resolution
- * @description The pixel resolution for detection
+ * @description Pixel resolution for detection
* @default 512
*/
detect_resolution?: number;
/**
* Image Resolution
- * @description The pixel resolution for the output image
+ * @description Pixel resolution for output image
* @default 512
*/
image_resolution?: number;
/**
* Safe
- * @description Whether to use safe mode
+ * @description Whether or not to use safe mode
* @default false
*/
safe?: boolean;
/**
* Scribble
- * @description Whether to use scribble mode
+ * @description Whether or not to use scribble mode
* @default false
*/
scribble?: boolean;
@@ -4545,7 +4520,7 @@ export type components = {
* @default prompt_collection_output
* @enum {string}
*/
- type: "prompt_collection_output";
+ type?: "prompt_collection_output";
/**
* Prompt Collection
* @description The output prompt collection
@@ -4567,7 +4542,7 @@ export type components = {
* @default prompt
* @enum {string}
*/
- type: "prompt";
+ type?: "prompt";
/**
* Prompt
* @description The output prompt
@@ -4575,7 +4550,7 @@ export type components = {
prompt: string;
};
/**
- * PromptsFromFileInvocation
+ * Prompts from File
* @description Loads prompts from a text file
*/
PromptsFromFileInvocation: {
@@ -4595,12 +4570,12 @@ export type components = {
* @default prompt_from_file
* @enum {string}
*/
- type?: "prompt_from_file";
+ type: "prompt_from_file";
/**
* File Path
* @description Path to prompt text file
*/
- file_path: string;
+ file_path?: string;
/**
* Pre Prompt
* @description String to prepend to each prompt
@@ -4625,7 +4600,7 @@ export type components = {
max_prompts?: number;
};
/**
- * RandomIntInvocation
+ * Random Integer
* @description Outputs a single random integer.
*/
RandomIntInvocation: {
@@ -4645,7 +4620,7 @@ export type components = {
* @default rand_int
* @enum {string}
*/
- type?: "rand_int";
+ type: "rand_int";
/**
* Low
* @description The inclusive low value
@@ -4660,7 +4635,7 @@ export type components = {
high?: number;
};
/**
- * RandomRangeInvocation
+ * Random Range
* @description Creates a collection of random numbers
*/
RandomRangeInvocation: {
@@ -4680,7 +4655,7 @@ export type components = {
* @default random_range
* @enum {string}
*/
- type?: "random_range";
+ type: "random_range";
/**
* Low
* @description The inclusive low value
@@ -4706,7 +4681,7 @@ export type components = {
seed?: number;
};
/**
- * RangeInvocation
+ * Integer Range
* @description Creates a range of numbers from start to stop with step
*/
RangeInvocation: {
@@ -4726,7 +4701,7 @@ export type components = {
* @default range
* @enum {string}
*/
- type?: "range";
+ type: "range";
/**
* Start
* @description The start of the range
@@ -4747,7 +4722,7 @@ export type components = {
step?: number;
};
/**
- * RangeOfSizeInvocation
+ * Integer Range of Size
* @description Creates a range from start to start + size with step
*/
RangeOfSizeInvocation: {
@@ -4767,7 +4742,7 @@ export type components = {
* @default range_of_size
* @enum {string}
*/
- type?: "range_of_size";
+ type: "range_of_size";
/**
* Start
* @description The start of the range
@@ -4796,7 +4771,7 @@ export type components = {
removed_image_names: (string)[];
};
/**
- * ResizeLatentsInvocation
+ * Resize Latents
* @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.
*/
ResizeLatentsInvocation: {
@@ -4816,34 +4791,32 @@ export type components = {
* @default lresize
* @enum {string}
*/
- type?: "lresize";
+ type: "lresize";
/**
* Latents
- * @description The latents to resize
+ * @description Latents tensor
*/
latents?: components["schemas"]["LatentsField"];
/**
* Width
- * @description The width to resize to (px)
- * @default 512
+ * @description Width of output (px)
*/
width?: number;
/**
* Height
- * @description The height to resize to (px)
- * @default 512
+ * @description Width of output (px)
*/
height?: number;
/**
* Mode
- * @description The interpolation mode
+ * @description Interpolation mode
* @default bilinear
* @enum {string}
*/
mode?: "nearest" | "linear" | "bilinear" | "bicubic" | "trilinear" | "area" | "nearest-exact";
/**
* Antialias
- * @description Whether or not to antialias (applied in bilinear and bicubic modes only)
+ * @description Whether or not to apply antialiasing (bilinear or bicubic only)
* @default false
*/
antialias?: boolean;
@@ -4859,7 +4832,7 @@ export type components = {
*/
ResourceOrigin: "internal" | "external";
/**
- * SDXLCompelPromptInvocation
+ * SDXL Compel Prompt
* @description Parse prompt using compel package to conditioning.
*/
SDXLCompelPromptInvocation: {
@@ -4879,16 +4852,16 @@ export type components = {
* @default sdxl_compel_prompt
* @enum {string}
*/
- type?: "sdxl_compel_prompt";
+ type: "sdxl_compel_prompt";
/**
* Prompt
- * @description Prompt
+ * @description Prompt to be parsed by Compel to create a conditioning tensor
* @default
*/
prompt?: string;
/**
* Style
- * @description Style prompt
+ * @description Prompt to be parsed by Compel to create a conditioning tensor
* @default
*/
style?: string;
@@ -4924,17 +4897,17 @@ export type components = {
target_height?: number;
/**
* Clip
- * @description Clip to use
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
- * @description Clip2 to use
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip2?: components["schemas"]["ClipField"];
};
/**
- * SDXLLoraLoaderInvocation
+ * SDXL LoRA Loader
* @description Apply selected lora to unet and text_encoder.
*/
SDXLLoraLoaderInvocation: {
@@ -4954,37 +4927,37 @@ export type components = {
* @default sdxl_lora_loader
* @enum {string}
*/
- type?: "sdxl_lora_loader";
+ type: "sdxl_lora_loader";
/**
- * Lora
- * @description Lora model name
+ * LoRA
+ * @description LoRA model to load
*/
- lora?: components["schemas"]["LoRAModelField"];
+ lora: components["schemas"]["LoRAModelField"];
/**
* Weight
- * @description With what weight to apply lora
+ * @description The weight at which the LoRA is applied to each model
* @default 0.75
*/
weight?: number;
/**
- * Unet
- * @description UNet model for applying lora
+ * UNET
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Clip model for applying lora
+ * CLIP 1
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
- * Clip2
- * @description Clip2 model for applying lora
+ * CLIP 2
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLoraLoaderOutput
- * @description Model loader output
+ * @description SDXL LoRA Loader Output
*/
SDXLLoraLoaderOutput: {
/**
@@ -4994,23 +4967,23 @@ export type components = {
*/
type?: "sdxl_lora_loader_output";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
unet?: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Tokenizer and text_encoder submodels
+ * CLIP 1
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
- * Clip2
- * @description Tokenizer2 and text_encoder2 submodels
+ * CLIP 2
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip2?: components["schemas"]["ClipField"];
};
/**
- * SDXLModelLoaderInvocation
+ * SDXL Main Model Loader
* @description Loads an sdxl base model, outputting its submodels.
*/
SDXLModelLoaderInvocation: {
@@ -5030,10 +5003,10 @@ export type components = {
* @default sdxl_model_loader
* @enum {string}
*/
- type?: "sdxl_model_loader";
+ type: "sdxl_model_loader";
/**
* Model
- * @description The model to load
+ * @description SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load
*/
model: components["schemas"]["MainModelField"];
};
@@ -5049,28 +5022,28 @@ export type components = {
*/
type?: "sdxl_model_loader_output";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
- unet?: components["schemas"]["UNetField"];
+ unet: components["schemas"]["UNetField"];
/**
- * Clip
- * @description Tokenizer and text_encoder submodels
+ * CLIP 1
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
- clip?: components["schemas"]["ClipField"];
+ clip: components["schemas"]["ClipField"];
/**
- * Clip2
- * @description Tokenizer and text_encoder submodels
+ * CLIP 2
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
- clip2?: components["schemas"]["ClipField"];
+ clip2: components["schemas"]["ClipField"];
/**
- * Vae
- * @description Vae submodel
+ * VAE
+ * @description VAE
*/
- vae?: components["schemas"]["VaeField"];
+ vae: components["schemas"]["VaeField"];
};
/**
- * SDXLRefinerCompelPromptInvocation
+ * SDXL Refiner Compel Prompt
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerCompelPromptInvocation: {
@@ -5090,10 +5063,10 @@ export type components = {
* @default sdxl_refiner_compel_prompt
* @enum {string}
*/
- type?: "sdxl_refiner_compel_prompt";
+ type: "sdxl_refiner_compel_prompt";
/**
* Style
- * @description Style prompt
+ * @description Prompt to be parsed by Compel to create a conditioning tensor
* @default
*/
style?: string;
@@ -5119,17 +5092,18 @@ export type components = {
crop_left?: number;
/**
* Aesthetic Score
+ * @description The aesthetic score to apply to the conditioning tensor
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
- * @description Clip to use
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip2?: components["schemas"]["ClipField"];
};
/**
- * SDXLRefinerModelLoaderInvocation
+ * SDXL Refiner Model Loader
* @description Loads an sdxl refiner model, outputting its submodels.
*/
SDXLRefinerModelLoaderInvocation: {
@@ -5149,10 +5123,10 @@ export type components = {
* @default sdxl_refiner_model_loader
* @enum {string}
*/
- type?: "sdxl_refiner_model_loader";
+ type: "sdxl_refiner_model_loader";
/**
* Model
- * @description The model to load
+ * @description SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load
*/
model: components["schemas"]["MainModelField"];
};
@@ -5168,23 +5142,23 @@ export type components = {
*/
type?: "sdxl_refiner_model_loader_output";
/**
- * Unet
- * @description UNet submodel
+ * UNet
+ * @description UNet (scheduler, LoRAs)
*/
- unet?: components["schemas"]["UNetField"];
+ unet: components["schemas"]["UNetField"];
/**
- * Clip2
- * @description Tokenizer and text_encoder submodels
+ * CLIP 2
+ * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
- clip2?: components["schemas"]["ClipField"];
+ clip2: components["schemas"]["ClipField"];
/**
- * Vae
- * @description Vae submodel
+ * VAE
+ * @description VAE
*/
- vae?: components["schemas"]["VaeField"];
+ vae: components["schemas"]["VaeField"];
};
/**
- * ScaleLatentsInvocation
+ * Scale Latents
* @description Scales latents by a given factor.
*/
ScaleLatentsInvocation: {
@@ -5204,27 +5178,27 @@ export type components = {
* @default lscale
* @enum {string}
*/
- type?: "lscale";
+ type: "lscale";
/**
* Latents
- * @description The latents to scale
+ * @description Latents tensor
*/
latents?: components["schemas"]["LatentsField"];
/**
* Scale Factor
- * @description The factor by which to scale the latents
+ * @description The factor by which to scale
*/
- scale_factor: number;
+ scale_factor?: number;
/**
* Mode
- * @description The interpolation mode
+ * @description Interpolation mode
* @default bilinear
* @enum {string}
*/
mode?: "nearest" | "linear" | "bilinear" | "bicubic" | "trilinear" | "area" | "nearest-exact";
/**
* Antialias
- * @description Whether or not to antialias (applied in bilinear and bicubic modes only)
+ * @description Whether or not to apply antialiasing (bilinear or bicubic only)
* @default false
*/
antialias?: boolean;
@@ -5236,7 +5210,7 @@ export type components = {
*/
SchedulerPredictionType: "epsilon" | "v_prediction" | "sample";
/**
- * SegmentAnythingProcessorInvocation
+ * Segment Anything Processor
* @description Applies segment anything processing to image
*/
SegmentAnythingProcessorInvocation: {
@@ -5256,7 +5230,7 @@ export type components = {
* @default segment_anything_processor
* @enum {string}
*/
- type?: "segment_anything_processor";
+ type: "segment_anything_processor";
/**
* Image
* @description The image to process
@@ -5264,7 +5238,7 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
- * ShowImageInvocation
+ * Show Image
* @description Displays a provided image, and passes it forward in the pipeline.
*/
ShowImageInvocation: {
@@ -5284,7 +5258,7 @@ export type components = {
* @default show_image
* @enum {string}
*/
- type?: "show_image";
+ type: "show_image";
/**
* Image
* @description The image to show
@@ -5442,7 +5416,7 @@ export type components = {
variant: components["schemas"]["ModelVariantType"];
};
/**
- * StepParamEasingInvocation
+ * Step Param Easing
* @description Experimental per-step parameter easing for denoising steps
*/
StepParamEasingInvocation: {
@@ -5462,7 +5436,7 @@ export type components = {
* @default step_param_easing
* @enum {string}
*/
- type?: "step_param_easing";
+ type: "step_param_easing";
/**
* Easing
* @description The easing function to use
@@ -5523,6 +5497,24 @@ export type components = {
*/
show_easing_plot?: boolean;
};
+ /**
+ * StringCollectionOutput
+ * @description A collection of strings
+ */
+ StringCollectionOutput: {
+ /**
+ * Type
+ * @default string_collection_output
+ * @enum {string}
+ */
+ type?: "string_collection_output";
+ /**
+ * Collection
+ * @description The output strings
+ * @default []
+ */
+ collection?: (string)[];
+ };
/**
* StringOutput
* @description A string output
@@ -5538,7 +5530,7 @@ export type components = {
* Text
* @description The output string
*/
- text?: string;
+ text: string;
};
/**
* SubModelType
@@ -5547,7 +5539,7 @@ export type components = {
*/
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker";
/**
- * SubtractInvocation
+ * Subtract Integers
* @description Subtracts two numbers
*/
SubtractInvocation: {
@@ -5567,7 +5559,7 @@ export type components = {
* @default sub
* @enum {string}
*/
- type?: "sub";
+ type: "sub";
/**
* A
* @description The first number
@@ -5600,8 +5592,8 @@ export type components = {
error?: components["schemas"]["ModelError"];
};
/**
- * TileResamplerProcessorInvocation
- * @description Base class for invocations that preprocess images for ControlNet
+ * Tile Resample Processor
+ * @description Tile resampler processor
*/
TileResamplerProcessorInvocation: {
/**
@@ -5620,7 +5612,7 @@ export type components = {
* @default tile_image_processor
* @enum {string}
*/
- type?: "tile_image_processor";
+ type: "tile_image_processor";
/**
* Image
* @description The image to process
@@ -5686,7 +5678,7 @@ export type components = {
vae: components["schemas"]["ModelInfo"];
};
/**
- * VaeLoaderInvocation
+ * VAE Loader
* @description Loads a VAE model, outputting a VaeLoaderOutput
*/
VaeLoaderInvocation: {
@@ -5706,10 +5698,10 @@ export type components = {
* @default vae_loader
* @enum {string}
*/
- type?: "vae_loader";
+ type: "vae_loader";
/**
- * Vae Model
- * @description The VAE to load
+ * VAE
+ * @description VAE model to load
*/
vae_model: components["schemas"]["VAEModelField"];
};
@@ -5725,10 +5717,10 @@ export type components = {
*/
type?: "vae_loader_output";
/**
- * Vae
- * @description Vae model
+ * VAE
+ * @description VAE
*/
- vae?: components["schemas"]["VaeField"];
+ vae: components["schemas"]["VaeField"];
};
/** VaeModelConfig */
VaeModelConfig: {
@@ -5763,7 +5755,7 @@ export type components = {
type: string;
};
/**
- * ZoeDepthImageProcessorInvocation
+ * Zoe (Depth) Processor
* @description Applies Zoe depth processing to image
*/
ZoeDepthImageProcessorInvocation: {
@@ -5783,7 +5775,7 @@ export type components = {
* @default zoe_depth_image_processor
* @enum {string}
*/
- type?: "zoe_depth_image_processor";
+ type: "zoe_depth_image_processor";
/**
* Image
* @description The image to process
@@ -5791,17 +5783,71 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
- * ControlNetModelFormat
- * @description An enumeration.
- * @enum {string}
+ * UIConfigBase
+ * @description Provides additional node configuration to the UI.
+ * This is used internally by the @tags and @title decorator logic. You probably want to use those
+ * decorators, though you may add this class to a node definition to specify the title and tags.
*/
- ControlNetModelFormat: "checkpoint" | "diffusers";
+ UIConfigBase: {
+ /**
+ * Tags
+ * @description The tags to display in the UI
+ */
+ tags?: (string)[];
+ /**
+ * Title
+ * @description The display name of the node
+ */
+ title?: string;
+ };
/**
- * StableDiffusion2ModelFormat
- * @description An enumeration.
+ * Input
+ * @description The type of input a field accepts.
+ * - `Input.Direct`: The field must have its value provided directly, when the invocation and field are instantiated.
+ * - `Input.Connection`: The field must have its value provided by a connection.
+ * - `Input.Any`: The field may have its value provided either directly or by a connection.
* @enum {string}
*/
- StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
+ Input: "connection" | "direct" | "any";
+ /**
+ * UITypeHint
+ * @description Type hints for the UI.
+ * If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes.
+ * @enum {string}
+ */
+ UITypeHint: "integer" | "float" | "boolean" | "string" | "enum" | "array" | "ImageField" | "LatentsField" | "ConditioningField" | "ControlField" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "ColorField" | "ImageCollection" | "IntegerCollection" | "FloatCollection" | "StringCollection" | "BooleanCollection" | "Collection" | "CollectionItem" | "Seed" | "FilePath";
+ /**
+ * UIComponent
+ * @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type.
+ * @enum {string}
+ */
+ UIComponent: "none" | "textarea" | "slider";
+ /**
+ * _InputField
+ * @description *DO NOT USE*
+ * This helper class is used to tell the client about our custom field attributes via OpenAPI
+ * schema generation, and Typescript type generation from that schema. It serves no functional
+ * purpose in the backend.
+ */
+ _InputField: {
+ input: components["schemas"]["Input"];
+ /** Ui Hidden */
+ ui_hidden: boolean;
+ ui_type_hint?: components["schemas"]["UITypeHint"];
+ ui_component?: components["schemas"]["UIComponent"];
+ };
+ /**
+ * _OutputField
+ * @description *DO NOT USE*
+ * This helper class is used to tell the client about our custom field attributes via OpenAPI
+ * schema generation, and Typescript type generation from that schema. It serves no functional
+ * purpose in the backend.
+ */
+ _OutputField: {
+ /** Ui Hidden */
+ ui_hidden: boolean;
+ ui_type_hint?: components["schemas"]["UITypeHint"];
+ };
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@@ -5820,6 +5866,18 @@ export type components = {
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
+ /**
+ * ControlNetModelFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ ControlNetModelFormat: "checkpoint" | "diffusers";
+ /**
+ * StableDiffusion2ModelFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@@ -5930,7 +5988,7 @@ export type operations = {
};
requestBody: {
content: {
- "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
+ "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
};
};
responses: {
@@ -5967,7 +6025,7 @@ export type operations = {
};
requestBody: {
content: {
- "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
+ "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
};
};
responses: {
diff --git a/invokeai/frontend/web/src/services/api/thunks/schema.ts b/invokeai/frontend/web/src/services/api/thunks/schema.ts
index 9b5748ae41..d296d824b1 100644
--- a/invokeai/frontend/web/src/services/api/thunks/schema.ts
+++ b/invokeai/frontend/web/src/services/api/thunks/schema.ts
@@ -1,5 +1,4 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
-import { logger } from 'app/logging/logger';
function getCircularReplacer() {
const ancestors: Record[] = [];
@@ -23,14 +22,11 @@ function getCircularReplacer() {
export const receivedOpenAPISchema = createAsyncThunk(
'nodes/receivedOpenAPISchema',
- async (_, { dispatch, rejectWithValue }) => {
- const log = logger('system');
+ async (_, { rejectWithValue }) => {
try {
const response = await fetch(`openapi.json`);
const openAPISchema = await response.json();
- log.info({ openAPISchema }, 'Received OpenAPI schema');
-
const schemaJSON = JSON.parse(
JSON.stringify(openAPISchema, getCircularReplacer())
);
diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts
index 435b605489..338f8d2849 100644
--- a/invokeai/frontend/web/src/services/api/types.ts
+++ b/invokeai/frontend/web/src/services/api/types.ts
@@ -1,8 +1,9 @@
import { UseToastOptions } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
-import { O } from 'ts-toolbelt';
import { components, paths } from './schema';
+type s = components['schemas'];
+
export type ImageCache = EntityState;
export type ListImagesArgs = NonNullable<
@@ -25,70 +26,63 @@ export type UpdateBoardArg =
* This is an unsafe type; the object inside is not guaranteed to be valid.
*/
export type UnsafeImageMetadata = {
- metadata: components['schemas']['CoreMetadata'];
- graph: NonNullable;
+ metadata: s['CoreMetadata'];
+ graph: NonNullable;
};
-/**
- * Marks the `type` property as required. Use for nodes.
- */
-type TypeReq = O.Required;
-
-// Extracted types from API schema
+export type _InputField = s['_InputField'];
+export type _OutputField = s['_OutputField'];
// App Info
-export type AppVersion = components['schemas']['AppVersion'];
-export type AppConfig = components['schemas']['AppConfig'];
+export type AppVersion = s['AppVersion'];
+export type AppConfig = s['AppConfig'];
// Images
-export type ImageDTO = components['schemas']['ImageDTO'];
-export type BoardDTO = components['schemas']['BoardDTO'];
-export type BoardChanges = components['schemas']['BoardChanges'];
-export type ImageChanges = components['schemas']['ImageRecordChanges'];
-export type ImageCategory = components['schemas']['ImageCategory'];
-export type ResourceOrigin = components['schemas']['ResourceOrigin'];
-export type ImageField = components['schemas']['ImageField'];
-export type ImageMetadata = components['schemas']['ImageMetadata'];
+export type ImageDTO = s['ImageDTO'];
+export type BoardDTO = s['BoardDTO'];
+export type BoardChanges = s['BoardChanges'];
+export type ImageChanges = s['ImageRecordChanges'];
+export type ImageCategory = s['ImageCategory'];
+export type ResourceOrigin = s['ResourceOrigin'];
+export type ImageField = s['ImageField'];
+export type ImageMetadata = s['ImageMetadata'];
export type OffsetPaginatedResults_BoardDTO_ =
- components['schemas']['OffsetPaginatedResults_BoardDTO_'];
+ s['OffsetPaginatedResults_BoardDTO_'];
export type OffsetPaginatedResults_ImageDTO_ =
- components['schemas']['OffsetPaginatedResults_ImageDTO_'];
+ s['OffsetPaginatedResults_ImageDTO_'];
// Models
-export type ModelType = components['schemas']['ModelType'];
-export type SubModelType = components['schemas']['SubModelType'];
-export type BaseModelType = components['schemas']['BaseModelType'];
-export type MainModelField = components['schemas']['MainModelField'];
-export type OnnxModelField = components['schemas']['OnnxModelField'];
-export type VAEModelField = components['schemas']['VAEModelField'];
-export type LoRAModelField = components['schemas']['LoRAModelField'];
-export type ControlNetModelField =
- components['schemas']['ControlNetModelField'];
-export type ModelsList = components['schemas']['ModelsList'];
-export type ControlField = components['schemas']['ControlField'];
+export type ModelType = s['ModelType'];
+export type SubModelType = s['SubModelType'];
+export type BaseModelType = s['BaseModelType'];
+export type MainModelField = s['MainModelField'];
+export type OnnxModelField = s['OnnxModelField'];
+export type VAEModelField = s['VAEModelField'];
+export type LoRAModelField = s['LoRAModelField'];
+export type ControlNetModelField = s['ControlNetModelField'];
+export type ModelsList = s['ModelsList'];
+export type ControlField = s['ControlField'];
// Model Configs
-export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
-export type VaeModelConfig = components['schemas']['VaeModelConfig'];
+export type LoRAModelConfig = s['LoRAModelConfig'];
+export type VaeModelConfig = s['VaeModelConfig'];
export type ControlNetModelCheckpointConfig =
- components['schemas']['ControlNetModelCheckpointConfig'];
+ s['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig =
- components['schemas']['ControlNetModelDiffusersConfig'];
+ s['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig =
| ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig;
-export type TextualInversionModelConfig =
- components['schemas']['TextualInversionModelConfig'];
+export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
export type DiffusersModelConfig =
- | components['schemas']['StableDiffusion1ModelDiffusersConfig']
- | components['schemas']['StableDiffusion2ModelDiffusersConfig']
- | components['schemas']['StableDiffusionXLModelDiffusersConfig'];
+ | s['StableDiffusion1ModelDiffusersConfig']
+ | s['StableDiffusion2ModelDiffusersConfig']
+ | s['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
- | components['schemas']['StableDiffusion1ModelCheckpointConfig']
- | components['schemas']['StableDiffusion2ModelCheckpointConfig']
- | components['schemas']['StableDiffusionXLModelCheckpointConfig'];
-export type OnnxModelConfig =
- components['schemas']['ONNXStableDiffusion1ModelConfig'];
+ | s['StableDiffusion1ModelCheckpointConfig']
+ | s['StableDiffusion2ModelCheckpointConfig']
+ | s['StableDiffusionXLModelCheckpointConfig'];
+export type OnnxModelConfig = s['ONNXStableDiffusion1ModelConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
@@ -98,154 +92,75 @@ export type AnyModelConfig =
| MainModelConfig
| OnnxModelConfig;
-export type MergeModelConfig = components['schemas']['Body_merge_models'];
-export type ImportModelConfig = components['schemas']['Body_import_model'];
+export type MergeModelConfig = s['Body_merge_models'];
+export type ImportModelConfig = s['Body_import_model'];
// Graphs
-export type Graph = components['schemas']['Graph'];
-export type Edge = components['schemas']['Edge'];
-export type GraphExecutionState = components['schemas']['GraphExecutionState'];
+export type Graph = s['Graph'];
+export type Edge = s['Edge'];
+export type GraphExecutionState = s['GraphExecutionState'];
// General nodes
-export type CollectInvocation = TypeReq<
- components['schemas']['CollectInvocation']
->;
-export type IterateInvocation = TypeReq<
- components['schemas']['IterateInvocation']
->;
-export type RangeInvocation = TypeReq;
-export type RandomRangeInvocation = TypeReq<
- components['schemas']['RandomRangeInvocation']
->;
-export type RangeOfSizeInvocation = TypeReq<
- components['schemas']['RangeOfSizeInvocation']
->;
-export type ImageResizeInvocation = TypeReq<
- components['schemas']['ImageResizeInvocation']
->;
-export type ImageScaleInvocation = TypeReq<
- components['schemas']['ImageScaleInvocation']
->;
-export type RandomIntInvocation = TypeReq<
- components['schemas']['RandomIntInvocation']
->;
-export type CompelInvocation = TypeReq<
- components['schemas']['CompelInvocation']
->;
-export type DynamicPromptInvocation = TypeReq<
- components['schemas']['DynamicPromptInvocation']
->;
-export type NoiseInvocation = TypeReq;
-export type ONNXTextToLatentsInvocation = TypeReq<
- components['schemas']['ONNXTextToLatentsInvocation']
->;
-export type DenoiseLatentsInvocation = TypeReq<
- components['schemas']['DenoiseLatentsInvocation']
->;
-export type ImageToLatentsInvocation = TypeReq<
- components['schemas']['ImageToLatentsInvocation']
->;
-export type LatentsToImageInvocation = TypeReq<
- components['schemas']['LatentsToImageInvocation']
->;
-export type ImageCollectionInvocation = TypeReq<
- components['schemas']['ImageCollectionInvocation']
->;
-export type MainModelLoaderInvocation = TypeReq<
- components['schemas']['MainModelLoaderInvocation']
->;
-export type OnnxModelLoaderInvocation = TypeReq<
- components['schemas']['OnnxModelLoaderInvocation']
->;
-export type LoraLoaderInvocation = TypeReq<
- components['schemas']['LoraLoaderInvocation']
->;
-export type SDXLLoraLoaderInvocation = TypeReq<
- components['schemas']['SDXLLoraLoaderInvocation']
->;
-export type MetadataAccumulatorInvocation = TypeReq<
- components['schemas']['MetadataAccumulatorInvocation']
->;
-export type ESRGANInvocation = TypeReq<
- components['schemas']['ESRGANInvocation']
->;
-export type DivideInvocation = TypeReq<
- components['schemas']['DivideInvocation']
->;
-export type InfillTileInvocation = TypeReq<
- components['schemas']['InfillTileInvocation']
->;
-export type InfillPatchmatchInvocation = TypeReq<
- components['schemas']['InfillPatchMatchInvocation']
->;
-export type ImageNSFWBlurInvocation = TypeReq<
- components['schemas']['ImageNSFWBlurInvocation']
->;
-export type ImageWatermarkInvocation = TypeReq<
- components['schemas']['ImageWatermarkInvocation']
->;
-export type ImageBlurInvocation = TypeReq<
- components['schemas']['ImageBlurInvocation']
->;
-export type ColorCorrectInvocation = TypeReq<
- components['schemas']['ColorCorrectInvocation']
->;
-export type ImagePasteInvocation = TypeReq<
- components['schemas']['ImagePasteInvocation']
->;
+export type CollectInvocation = s['CollectInvocation'];
+export type IterateInvocation = s['IterateInvocation'];
+export type RangeInvocation = s['RangeInvocation'];
+export type RandomRangeInvocation = s['RandomRangeInvocation'];
+export type RangeOfSizeInvocation = s['RangeOfSizeInvocation'];
+export type ImageResizeInvocation = s['ImageResizeInvocation'];
+export type ImageBlurInvocation = s['ImageBlurInvocation'];
+export type ImageScaleInvocation = s['ImageScaleInvocation'];
+export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
+export type InfillTileInvocation = s['InfillTileInvocation'];
+export type RandomIntInvocation = s['RandomIntInvocation'];
+export type CompelInvocation = s['CompelInvocation'];
+export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
+export type NoiseInvocation = s['NoiseInvocation'];
+export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation'];
+export type ONNXTextToLatentsInvocation = s['ONNXTextToLatentsInvocation'];
+export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation'];
+export type ImageToLatentsInvocation = s['ImageToLatentsInvocation'];
+export type LatentsToImageInvocation = s['LatentsToImageInvocation'];
+export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
+export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
+export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
+export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
+export type MetadataAccumulatorInvocation = s['MetadataAccumulatorInvocation'];
+export type ESRGANInvocation = s['ESRGANInvocation'];
+export type DivideInvocation = s['DivideInvocation'];
+export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
+export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
// ControlNet Nodes
-export type ControlNetInvocation = TypeReq<
- components['schemas']['ControlNetInvocation']
->;
-export type CannyImageProcessorInvocation = TypeReq<
- components['schemas']['CannyImageProcessorInvocation']
->;
-export type ContentShuffleImageProcessorInvocation = TypeReq<
- components['schemas']['ContentShuffleImageProcessorInvocation']
->;
-export type HedImageProcessorInvocation = TypeReq<
- components['schemas']['HedImageProcessorInvocation']
->;
-export type LineartAnimeImageProcessorInvocation = TypeReq<
- components['schemas']['LineartAnimeImageProcessorInvocation']
->;
-export type LineartImageProcessorInvocation = TypeReq<
- components['schemas']['LineartImageProcessorInvocation']
->;
-export type MediapipeFaceProcessorInvocation = TypeReq<
- components['schemas']['MediapipeFaceProcessorInvocation']
->;
-export type MidasDepthImageProcessorInvocation = TypeReq<
- components['schemas']['MidasDepthImageProcessorInvocation']
->;
-export type MlsdImageProcessorInvocation = TypeReq<
- components['schemas']['MlsdImageProcessorInvocation']
->;
-export type NormalbaeImageProcessorInvocation = TypeReq<
- components['schemas']['NormalbaeImageProcessorInvocation']
->;
-export type OpenposeImageProcessorInvocation = TypeReq<
- components['schemas']['OpenposeImageProcessorInvocation']
->;
-export type PidiImageProcessorInvocation = TypeReq<
- components['schemas']['PidiImageProcessorInvocation']
->;
-export type ZoeDepthImageProcessorInvocation = TypeReq<
- components['schemas']['ZoeDepthImageProcessorInvocation']
->;
+export type ControlNetInvocation = s['ControlNetInvocation'];
+export type CannyImageProcessorInvocation = s['CannyImageProcessorInvocation'];
+export type ContentShuffleImageProcessorInvocation =
+ s['ContentShuffleImageProcessorInvocation'];
+export type HedImageProcessorInvocation = s['HedImageProcessorInvocation'];
+export type LineartAnimeImageProcessorInvocation =
+ s['LineartAnimeImageProcessorInvocation'];
+export type LineartImageProcessorInvocation =
+ s['LineartImageProcessorInvocation'];
+export type MediapipeFaceProcessorInvocation =
+ s['MediapipeFaceProcessorInvocation'];
+export type MidasDepthImageProcessorInvocation =
+ s['MidasDepthImageProcessorInvocation'];
+export type MlsdImageProcessorInvocation = s['MlsdImageProcessorInvocation'];
+export type NormalbaeImageProcessorInvocation =
+ s['NormalbaeImageProcessorInvocation'];
+export type OpenposeImageProcessorInvocation =
+ s['OpenposeImageProcessorInvocation'];
+export type PidiImageProcessorInvocation = s['PidiImageProcessorInvocation'];
+export type ZoeDepthImageProcessorInvocation =
+ s['ZoeDepthImageProcessorInvocation'];
// Node Outputs
-export type ImageOutput = components['schemas']['ImageOutput'];
-export type MaskOutput = components['schemas']['MaskOutput'];
-export type PromptOutput = components['schemas']['PromptOutput'];
-export type IterateInvocationOutput =
- components['schemas']['IterateInvocationOutput'];
-export type CollectInvocationOutput =
- components['schemas']['CollectInvocationOutput'];
-export type LatentsOutput = components['schemas']['LatentsOutput'];
-export type GraphInvocationOutput =
- components['schemas']['GraphInvocationOutput'];
+export type ImageOutput = s['ImageOutput'];
+export type MaskOutput = s['MaskOutput'];
+export type PromptOutput = s['PromptOutput'];
+export type IterateInvocationOutput = s['IterateInvocationOutput'];
+export type CollectInvocationOutput = s['CollectInvocationOutput'];
+export type LatentsOutput = s['LatentsOutput'];
+export type GraphInvocationOutput = s['GraphInvocationOutput'];
// Post-image upload actions, controls workflows when images are uploaded
@@ -284,3 +199,14 @@ export type PostUploadAction =
| CanvasInitialImageAction
| ToastAction
| AddToBatchAction;
+
+type TypeGuard = {
+ (input: unknown): input is T;
+};
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+export type TypeGuardFor> = T extends TypeGuard<
+ infer U
+>
+ ? U
+ : never;
diff --git a/invokeai/frontend/web/src/theme/components/checkbox.ts b/invokeai/frontend/web/src/theme/components/checkbox.ts
index 58871237e5..2a32bc2796 100644
--- a/invokeai/frontend/web/src/theme/components/checkbox.ts
+++ b/invokeai/frontend/web/src/theme/components/checkbox.ts
@@ -13,12 +13,12 @@ const invokeAIControl = defineStyle((props) => {
return {
bg: mode('base.200', 'base.700')(props),
- borderColor: mode('base.200', 'base.700')(props),
+ borderColor: mode('base.300', 'base.600')(props),
color: mode('base.900', 'base.100')(props),
_checked: {
- bg: mode(`${c}.300`, `${c}.600`)(props),
- borderColor: mode(`${c}.300`, `${c}.600`)(props),
+ bg: mode(`${c}.300`, `${c}.500`)(props),
+ borderColor: mode(`${c}.300`, `${c}.500`)(props),
color: mode(`${c}.900`, `${c}.100`)(props),
_hover: {
@@ -45,7 +45,8 @@ const invokeAIControl = defineStyle((props) => {
},
_focusVisible: {
- boxShadow: 'outline',
+ boxShadow: 'none',
+ outline: 'none',
},
_invalid: {
diff --git a/invokeai/frontend/web/src/theme/components/formLabel.ts b/invokeai/frontend/web/src/theme/components/formLabel.ts
index 866bb7beb1..922cbd555c 100644
--- a/invokeai/frontend/web/src/theme/components/formLabel.ts
+++ b/invokeai/frontend/web/src/theme/components/formLabel.ts
@@ -14,6 +14,9 @@ const invokeAI = defineStyle((props) => {
opacity: 0.4,
},
color: mode('base.700', 'base.300')(props),
+ _invalid: {
+ color: mode('error.500', 'error.300')(props),
+ },
};
});
diff --git a/invokeai/frontend/web/src/theme/components/tabs.ts b/invokeai/frontend/web/src/theme/components/tabs.ts
index adcce73bbc..141c0410a7 100644
--- a/invokeai/frontend/web/src/theme/components/tabs.ts
+++ b/invokeai/frontend/web/src/theme/components/tabs.ts
@@ -8,16 +8,16 @@ import { mode } from '@chakra-ui/theme-tools';
const { defineMultiStyleConfig, definePartsStyle } =
createMultiStyleConfigHelpers(parts.keys);
-const invokeAIRoot = defineStyle((_props) => {
+const appTabsRoot = defineStyle((_props) => {
return {
display: 'flex',
columnGap: 4,
};
});
-const invokeAITab = defineStyle((_props) => ({}));
+const appTabsTab = defineStyle((_props) => ({}));
-const invokeAITablist = defineStyle((props) => {
+const appTabsTablist = defineStyle((props) => {
const { colorScheme: c } = props;
return {
@@ -65,24 +65,49 @@ const invokeAITablist = defineStyle((props) => {
};
});
-const invokeAITabpanel = defineStyle((_props) => ({
+const appTabsTabpanel = defineStyle((_props) => ({
padding: 0,
height: '100%',
}));
-const invokeAI = definePartsStyle((props) => ({
- root: invokeAIRoot(props),
- tab: invokeAITab(props),
- tablist: invokeAITablist(props),
- tabpanel: invokeAITabpanel(props),
+const appTabs = definePartsStyle((props) => ({
+ root: appTabsRoot(props),
+ tab: appTabsTab(props),
+ tablist: appTabsTablist(props),
+ tabpanel: appTabsTabpanel(props),
+}));
+
+const line = definePartsStyle((props) => ({
+ tab: {
+ borderTopRadius: 'base',
+ px: 4,
+ py: 1,
+ fontSize: 'sm',
+ color: mode('base.600', 'base.400')(props),
+ fontWeight: 500,
+ _selected: {
+ color: mode('accent.600', 'accent.400')(props),
+ },
+ },
+ tabpanel: {
+ p: 0,
+ pt: 4,
+ w: 'full',
+ h: 'full',
+ },
+ tabpanels: {
+ w: 'full',
+ h: 'full',
+ },
}));
export const tabsTheme = defineMultiStyleConfig({
variants: {
- invokeAI,
+ line,
+ appTabs,
},
defaultProps: {
- variant: 'invokeAI',
+ variant: 'appTabs',
colorScheme: 'accent',
},
});
diff --git a/invokeai/frontend/web/src/theme/components/textarea.ts b/invokeai/frontend/web/src/theme/components/textarea.ts
index 8dd59c18e0..c1a3da271a 100644
--- a/invokeai/frontend/web/src/theme/components/textarea.ts
+++ b/invokeai/frontend/web/src/theme/components/textarea.ts
@@ -40,6 +40,7 @@ const invokeAI = defineStyle((props) => ({
},
},
},
+ p: 2,
}));
export const textareaTheme = defineStyleConfig({
diff --git a/invokeai/frontend/web/src/theme/custom/reactflow.ts b/invokeai/frontend/web/src/theme/custom/reactflow.ts
new file mode 100644
index 0000000000..bb8cc1f04c
--- /dev/null
+++ b/invokeai/frontend/web/src/theme/custom/reactflow.ts
@@ -0,0 +1,21 @@
+import { SystemStyleObject } from '@chakra-ui/styled-system';
+
+const selectionStyles: SystemStyleObject = {
+ backgroundColor: 'accentAlpha.150 !important',
+ borderColor: 'accentAlpha.700 !important',
+ borderRadius: 'base !important',
+ borderStyle: 'dashed !important',
+ _dark: {
+ borderColor: 'accent.400 !important',
+ },
+};
+
+export const reactflowStyles: SystemStyleObject = {
+ '.react-flow__nodesselection-rect': {
+ ...selectionStyles,
+ padding: '1rem !important',
+ boxSizing: 'content-box !important',
+ transform: 'translate(-1rem, -1rem) !important',
+ },
+ '.react-flow__selection': selectionStyles,
+};
diff --git a/invokeai/frontend/web/src/theme/theme.ts b/invokeai/frontend/web/src/theme/theme.ts
index afed8688ee..f602fcef1c 100644
--- a/invokeai/frontend/web/src/theme/theme.ts
+++ b/invokeai/frontend/web/src/theme/theme.ts
@@ -21,6 +21,7 @@ import { tabsTheme } from './components/tabs';
import { textTheme } from './components/text';
import { textareaTheme } from './components/textarea';
import { tooltipTheme } from './components/tooltip';
+import { reactflowStyles } from './custom/reactflow';
export const theme: ThemeOverride = {
config: {
@@ -44,11 +45,27 @@ export const theme: ThemeOverride = {
color: 'base.900',
'.chakra-ui-dark &': { bg: 'base.800', color: 'base.100' },
},
+ nodeBody: {
+ bg: 'base.100',
+ color: 'base.900',
+ '.chakra-ui-dark &': { bg: 'base.800', color: 'base.100' },
+ },
+ nodeHeader: {
+ bg: 'base.200',
+ color: 'base.900',
+ '.chakra-ui-dark &': { bg: 'base.700', color: 'base.100' },
+ },
+ nodeFooter: {
+ bg: 'base.200',
+ color: 'base.900',
+ '.chakra-ui-dark &': { bg: 'base.700', color: 'base.100' },
+ },
},
styles: {
global: () => ({
layerStyle: 'body',
'*': { ...no_scrollbar },
+ ...reactflowStyles,
}),
},
direction: 'ltr',
@@ -85,7 +102,10 @@ export const theme: ThemeOverride = {
'0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 3px var(--invokeai-colors-accent-500)',
dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 3px var(--invokeai-colors-accent-400)',
},
- nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-accent-450)`,
+ nodeSelectedOutline: {
+ light: `0 0 0 2px var(--invokeai-colors-accent-400)`,
+ dark: `0 0 0 2px var(--invokeai-colors-accent-500)`,
+ },
},
colors: InvokeAIColors,
components: {