Compare commits

..

1 Commits

121 changed files with 2445 additions and 4838 deletions

View File

@@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields.
##### Single and Collection Fields
##### Collection and Scalar Fields
Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation
@@ -173,7 +173,8 @@ Field types are represented as structured objects:
```ts
type FieldType = {
name: string;
cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
isCollection: boolean;
isCollectionOrScalar: boolean;
};
```
@@ -185,7 +186,7 @@ There are 4 general cases for field type parsing.
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
We create a field type name from this `type` string (e.g. `string` -> `StringField`).
##### Complex Types
@@ -199,13 +200,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
We use the item type for field type name. The cardinality is `"COLLECTION"`.
We use the item type for field type name, adding `isCollection: true` to the field type.
##### Single or Collection Types
##### Collection or Scalar Types
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type.
##### Optional Fields

View File

@@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
### Requirements
Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
Before you start, go through the [installation requirements].
### Installation Walkthrough
@@ -79,7 +79,7 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
!!! example "Install with an extra index URL"
@@ -116,4 +116,4 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
!!! warning
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@@ -6,12 +6,13 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, JsonValue
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@@ -41,17 +42,13 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
_metadata = None
_workflow = None
_graph = None
metadata = None
workflow = None
contents = await file.read()
try:
@@ -65,28 +62,22 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
metadata_raw = pil_image.info.get("invokeai_metadata", None)
if metadata_raw:
try:
metadata = MetadataFieldValidator.validate_json(metadata_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
pass
if workflow_raw is not None:
try:
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
try:
image_dto = ApiDependencies.invoker.services.images.create(
@@ -95,9 +86,8 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
board_id=board_id,
metadata=_metadata,
workflow=_workflow,
graph=_graph,
metadata=metadata,
workflow=workflow,
is_intermediate=is_intermediate,
)
@@ -195,21 +185,14 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> WorkflowAndGraphResponse:
) -> Optional[WorkflowWithoutID]:
try:
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
return ApiDependencies.invoker.services.images.get_workflow(image_name)
except Exception:
raise HTTPException(status_code=404)

View File

@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
UIType,
@@ -79,13 +80,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from typing import Literal, Optional, List, Union
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from transformers import AutoModelForCausalLM, AutoTokenizer
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
@@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
@@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation(
"auto_caption_image",
title="Automatically Caption Image",
tags=["image", "caption"],
category="image",
version="1.2.2",
)
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adds a caption to an image"""
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
model_id = "vikhyatk/moondream2"
model_revision = "2024-04-02"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
moondream_model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=model_revision
)
output: CaptionImageOutputs = CaptionImageOutputs()
try:
from PIL.Image import Image
images: List[Image] = []
image_fields = self.images if isinstance(self.images, list) else [self.images]
for image in image_fields:
images.append(context.images.get_pil(image.image_name))
answers: List[str] = moondream_model.batch_answer(
images=images,
prompts=[self.prompt] * len(images),
tokenizer=tokenizer,
)
assert isinstance(answers, list)
for i, answer in enumerate(answers):
output.images.append(CaptionImageOutput(
image=image_fields[i],
width=images[i].width,
height=images[i].height,
caption=answer
))
except:
raise
finally:
del moondream_model
del tokenizer
return output
@invocation(
"img_crop",
title="Crop Image",
@@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from")
image: List[ImageField] = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@@ -67,6 +67,7 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)

View File

@@ -11,7 +11,6 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -94,46 +93,19 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
return ModelIdentifierOutput(model=self.model)
@invocation(
"main_model_loader",
title="Main Model",
tags=["model"],
category="model",
version="1.0.3",
version="1.0.2",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@@ -162,12 +134,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -225,12 +197,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -301,13 +273,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.3",
version="1.0.2",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -442,12 +414,12 @@ class SDXLLoRACollectionLoader(BaseInvocation):
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from typing import Optional, List
import torch
@@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
)
@invocation_output("captioned_image_output")
class CaptionImageOutput(ImageOutput):
caption: str = OutputField(description="Caption for given image")
@invocation_output("captioned_image_outputs")
class CaptionImageOutputs(BaseInvocationOutput):
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
@invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""

View File

@@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
@@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.3",
version="1.0.2",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
@@ -55,6 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)

View File

@@ -122,8 +122,6 @@ class EventServiceBase:
source_node_id: str,
error_type: str,
error: str,
user_id: str | None,
project_id: str | None,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
@@ -137,8 +135,6 @@ class EventServiceBase:
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
"user_id": user_id,
"project_id": project_id,
},
)

View File

@@ -4,6 +4,9 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@@ -30,9 +33,8 @@ class ImageFileStorageBase(ABC):
self,
image: PILImageType,
image_name: str,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@@ -44,11 +46,6 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[str]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

View File

@@ -7,7 +7,9 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase
@@ -54,9 +56,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
self,
image: PILImageType,
image_name: str,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
try:
@@ -67,14 +68,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
info_dict = {}
if metadata is not None:
info_dict["invokeai_metadata"] = metadata
pnginfo.add_text("invokeai_metadata", metadata)
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None:
info_dict["invokeai_workflow"] = workflow
pnginfo.add_text("invokeai_workflow", workflow)
if graph is not None:
info_dict["invokeai_graph"] = graph
pnginfo.add_text("invokeai_graph", graph)
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
@@ -129,18 +129,11 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> str | None:
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if isinstance(workflow, str):
return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None
def __validate_storage_folders(self) -> None:

View File

@@ -80,7 +80,7 @@ class ImageRecordStorageBase(ABC):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
metadata: Optional[MetadataField] = None,
) -> datetime:
"""Saves an image record."""
pass

View File

@@ -328,9 +328,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
metadata: Optional[MetadataField] = None,
) -> datetime:
try:
metadata_json = metadata.model_dump_json() if metadata is not None else None
self._lock.acquire()
self._cursor.execute(
"""--sql
@@ -357,7 +358,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height,
node_id,
session_id,
metadata,
metadata_json,
is_intermediate,
starred,
has_workflow,

View File

@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC):
@@ -50,9 +51,8 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -87,12 +87,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass

View File

@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -41,9 +42,8 @@ class ImageService(ImageServiceABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@@ -64,7 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None or graph is not None,
has_workflow=workflow is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@@ -75,7 +75,7 @@ class ImageService(ImageServiceABC):
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
image_dto = self.get_dto(image_name)
@@ -157,7 +157,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image metadata")
raise e
def get_workflow(self, image_name: str) -> Optional[str]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
try:
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
@@ -167,16 +167,6 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image workflow")
raise
def get_graph(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_graph(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image graph")
raise
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))

View File

@@ -237,8 +237,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
user_id=None,
project_id=None,
)
pass

View File

@@ -180,9 +180,9 @@ class ImagesInterface(InvocationContextInterface):
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
metadata_ = metadata.model_dump_json()
elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
metadata_ = self._data.invocation.metadata.model_dump_json()
metadata_ = metadata
elif isinstance(self._data.invocation, WithMetadata):
metadata_ = self._data.invocation.metadata
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None
@@ -191,14 +191,6 @@ class ImagesInterface(InvocationContextInterface):
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id
workflow_ = None
if self._data.queue_item.workflow:
workflow_ = self._data.queue_item.workflow.model_dump_json()
graph_ = None
if self._data.queue_item.session.graph:
graph_ = self._data.queue_item.session.graph.model_dump_json()
return self._services.images.create(
image=image,
is_intermediate=self._data.invocation.is_intermediate,
@@ -206,8 +198,7 @@ class ImagesInterface(InvocationContextInterface):
board_id=board_id_,
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=workflow_,
graph=graph_,
workflow=self._data.queue_item.workflow,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
)

View File

@@ -775,14 +775,10 @@
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack",
"collection": "Collection",
"singleFieldType": "{{name}} (Single)",
"collectionFieldType": "{{name}} (Collection)",
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
"collectionFieldType": "{{name}} Collection",
"collectionOrScalarFieldType": "{{name}} Collection|Scalar",
"colorCodeEdges": "Color-Code Edges",
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
"connectionWouldCreateCycle": "Connection would create a cycle",
@@ -884,7 +880,6 @@
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"noGraph": "No Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",
@@ -952,7 +947,7 @@
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"controlAdapterNoImageSelected": "no Control Adapter image selected",
"controlAdapterImageNotProcessed": "Control Adapter image not processed",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64",
"ipAdapterNoModelSelected": "no IP adapter selected",
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",

View File

@@ -21,7 +21,6 @@ import i18n from 'i18n';
import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage';
@@ -47,7 +46,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
useSocketIO();
useGlobalModifiersInit();
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
@@ -44,9 +43,6 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
type: 'TOAST',
toastOptions: { title: t('toast.canvasSavedGallery') },
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
},
})
);
},

View File

@@ -16,7 +16,6 @@ import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
@@ -48,10 +47,8 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener
cancelActiveListeners();
@@ -60,33 +57,21 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// Delay before starting actual work
await delay(DEBOUNCE_MS);
// Double-check that we are still eligible for processing
const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
// If we have no image or there is no processor config, bail
if (!layer) {
return;
}
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
// The user has reset the image or config, so we should clear the processed image
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
@@ -96,8 +81,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -18,7 +18,6 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
},
});

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
@@ -31,12 +31,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
}
try {
const updatedNode = updateNode(node, template);
dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
{ type: 'add', item: updatedNode },
])
);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;

View File

@@ -4,49 +4,31 @@ import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
} else {
throw new Error('No workflow or graph provided');
}
};
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { data, asCopy } = action.payload;
const { workflow, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
delete workflow.id;
delete validatedWorkflow.id;
}
dispatch(workflowLoaded(workflow));
dispatch(workflowLoaded(validatedWorkflow));
if (!warnings.length) {
dispatch(
addToast(

View File

@@ -137,7 +137,7 @@ const createSelector = (templates: Templates) =>
if (l.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
}

View File

@@ -616,24 +616,12 @@ export const controlLayersSlice = createSlice({
iiLayerAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
// Retain opacity and denoising strength of existing initial image layer if exists
let opacity = 1;
let denoisingStrength = 0.75;
const iiLayer = state.layers.find((l) => l.id === layerId);
if (iiLayer) {
assert(isInitialImageLayer(iiLayer));
opacity = iiLayer.opacity;
denoisingStrength = iiLayer.denoisingStrength;
}
// Highlander! There can be only one!
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
const layer: InitialImageLayer = {
id: layerId,
type: 'initial_image_layer',
opacity,
opacity: 1,
x: 0,
y: 0,
bbox: null,
@@ -641,7 +629,7 @@ export const controlLayersSlice = createSlice({
isEnabled: true,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
isSelected: true,
denoisingStrength,
denoisingStrength: 0.75,
};
state.layers.push(layer);
exclusivelySelectLayer(state, layer.id);

View File

@@ -11,12 +11,10 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
import { memo, useCallback } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
@@ -50,7 +48,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('canvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();
const templates = useStore($templates);
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
useImageActions(imageDTO?.image_name);
@@ -136,7 +133,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
onClickCapture={handleLoadWorkflow}
isDisabled={!imageDTO.has_workflow || !size(templates)}
isDisabled={!imageDTO.has_workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>

View File

@@ -1,34 +0,0 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
type Props = {
image: ImageDTO;
};
const ImageMetadataGraphTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData } = useDebouncedImageWorkflow(image);
const graph = useMemo(() => {
if (currentData?.graph) {
try {
return JSON.parse(currentData.graph);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!graph) {
return <IAINoContentFallback label={t('nodes.noGraph')} />;
}
return <DataViewer data={graph} label={t('nodes.graph')} />;
};
export default memo(ImageMetadataGraphTabContent);

View File

@@ -1,7 +1,6 @@
import { ExternalLink, Flex, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import ImageMetadataGraphTabContent from 'features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent';
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
import { handlers } from 'features/metadata/util/handlers';
import { memo } from 'react';
@@ -53,7 +52,6 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<Tab>{t('metadata.metadata')}</Tab>
<Tab>{t('metadata.imageDetails')}</Tab>
<Tab>{t('metadata.workflow')}</Tab>
<Tab>{t('nodes.graph')}</Tab>
</TabList>
<TabPanels>
@@ -83,9 +81,6 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<TabPanel>
<ImageMetadataWorkflowTabContent image={image} />
</TabPanel>
<TabPanel>
<ImageMetadataGraphTabContent image={image} />
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

View File

@@ -1,5 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
@@ -12,17 +12,7 @@ type Props = {
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData } = useDebouncedImageWorkflow(image);
const workflow = useMemo(() => {
if (currentData?.workflow) {
try {
return JSON.parse(currentData.workflow);
} catch {
return null;
}
}
return null;
}, [currentData]);
const { workflow } = useDebouncedImageWorkflow(image);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;

View File

@@ -1,5 +1,4 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
@@ -13,14 +12,12 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
import { $templates } from 'features/nodes/store/nodesSlice';
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -51,7 +48,7 @@ const CurrentImageButtons = () => {
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
const templates = useStore($templates);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const { t } = useTranslation();
@@ -146,7 +143,7 @@ const CurrentImageButtons = () => {
icon={<PiFlowArrowBold />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!imageDTO?.has_workflow || !size(templates)}
isDisabled={!imageDTO?.has_workflow}
onClick={handleLoadWorkflow}
isLoading={getAndLoadEmbeddedWorkflowResult.isLoading}
/>

View File

@@ -9,29 +9,27 @@ import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
$cursorPos,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
edgesChanged,
nodesChanged,
connectionMade,
nodeAdded,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import type { EdgeChange, NodeChange } from 'reactflow';
import { assert } from 'tsafe';
const createRegex = memoize(
(inputValue: string) =>
@@ -71,19 +69,17 @@ const AddNodePopover = () => {
const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const templatesArray = map(templates);
if (!pendingConnection) {
return templatesArray;
return map(templates);
}
return filter(templates, (template) => {
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
return some(candidateFields, (field) => {
const sourceType =
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
const targetType =
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
return validateConnectionTypes(sourceType, targetType);
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
@@ -133,37 +129,11 @@ const AddNodePopover = () => {
});
return null;
}
// Find a cozy spot for the node
const cursorPos = $cursorPos.get();
const { nodes, edges } = store.getState().nodes.present;
node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
node.selected = true;
// Deselect all other nodes and edges
const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
const edgeChanges: EdgeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (selected) {
nodeChanges.push({ type: 'select', id, selected: false });
}
});
edges.forEach(({ id, selected }) => {
if (selected) {
edgeChanges.push({ type: 'select', id, selected: false });
}
});
// Onwards!
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
dispatch(nodeAdded({ node, cursorPos }));
return node;
},
[buildInvocation, store, dispatch, t, toaster]
[dispatch, buildInvocation, toaster, t]
);
const onChange = useCallback<ComboboxOnChange>(
@@ -175,28 +145,12 @@ const AddNodePopover = () => {
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const edgePendingUpdate = $edgePendingUpdate.get();
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
templates,
edgePendingUpdate
);
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
if (connection) {
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
dispatch(connectionMade(connection));
}
}
@@ -206,23 +160,24 @@ const AddNodePopover = () => {
);
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
if (!$isAddNodePopoverOpen.get()) {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}, []);
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
if ($isAddNodePopoverOpen.get()) {
closeAddNodePopover();
}
closeAddNodePopover();
}, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] });
useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
if (e.key === 'Escape') {
closeAddNodePopover();
}
}, []);
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
@@ -260,6 +215,7 @@ const AddNodePopover = () => {
filterOption={filterOption}
onChange={onChange}
onMenuClose={closeAddNodePopover}
onKeyDown={onKeyDown}
inputRef={inputRef}
closeMenuOnSelect={false}
/>

View File

@@ -1,6 +1,6 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
@@ -8,35 +8,38 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
$cursorPos,
$didUpdateEdge,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$lastEdgeUpdateMouseEvent,
$isUpdatingEdge,
$pendingConnection,
$viewport,
connectionMade,
edgeAdded,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
nodesDeleted,
redo,
selectedAll,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type {
EdgeChange,
NodeChange,
OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc,
OnInit,
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
ProOptions,
ReactFlowProps,
ReactFlowState,
} from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
@@ -45,6 +48,8 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
import NotesNode from './nodes/Notes/NotesNode';
const DELETE_KEYS = ['Delete', 'Backspace'];
const edgeTypes = {
collapsed: InvocationCollapsedEdge,
default: InvocationDefaultEdge,
@@ -76,8 +81,6 @@ export const Flow = memo(() => {
const flowWrapper = useRef<HTMLDivElement>(null);
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
const updateNodeInternals = useUpdateNodeInternals();
const store = useAppStore();
useWorkflowWatcher();
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
@@ -90,17 +93,29 @@ export const Flow = memo(() => {
);
const onNodesChange: OnNodesChange = useCallback(
(nodeChanges) => {
dispatch(nodesChanged(nodeChanges));
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
if (changes.length > 0) {
dispatch(edgesChanged(changes));
}
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
},
[dispatch]
);
@@ -142,50 +157,45 @@ export const Flow = memo(() => {
* where the edge is deleted if you click it accidentally).
*/
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, edge, _handleType) => {
$edgePendingUpdate.set(edge);
$didUpdateEdge.set(false);
$lastEdgeUpdateMouseEvent.set(e);
}, []);
// We have a ref for cursor position, but it is the *projected* cursor position.
// Easiest to just keep track of the last mouse event for this particular feature
const edgeUpdateMouseEvent = useRef<MouseEvent>();
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(true);
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
},
[dispatch]
);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(oldEdge, newConnection) => {
// This event is fired when an edge update is successful
$didUpdateEdge.set(true);
// When an edge update is successful, we need to delete the old edge and create a new one
const newEdge = connectionToEdge(newConnection);
dispatch(
edgesChanged([
{ type: 'remove', id: oldEdge.id },
{ type: 'add', item: newEdge },
])
);
// Because we shift the position of handles depending on whether a field is connected or not, we must use
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
(_oldEdge, newConnection) => {
// Because we deleted the edge when the update started, we must create a new edge from the connection
dispatch(connectionMade(newConnection));
},
[dispatch, updateNodeInternals]
[dispatch]
);
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
(e, edge, _handleType) => {
const didUpdateEdge = $didUpdateEdge.get();
// Fall back to a reasonable default event
const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 };
// We have to narrow this event down to MouseEvents - could be TouchEvent
const didMouseMove =
!('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5;
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
// the user probably intended to delete the edge
if (!didUpdateEdge && didMouseMove) {
dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
}
$edgePendingUpdate.set(null);
$didUpdateEdge.set(false);
$isUpdatingEdge.set(false);
$pendingConnection.set(null);
$lastEdgeUpdateMouseEvent.set(null);
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
// the edge update - we need to add it back
if (
// ignore touch events
!('touches' in e) &&
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
edgeUpdateMouseEvent.current?.clientY === e.clientY
) {
dispatch(edgeAdded(edge));
}
// reset mouse event
edgeUpdateMouseEvent.current = undefined;
},
[dispatch]
);
@@ -206,27 +216,9 @@ export const Flow = memo(() => {
const onSelectAllHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
const { nodes, edges } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (!selected) {
nodeChanges.push({ type: 'select', id, selected: true });
}
});
edges.forEach(({ id, selected }) => {
if (!selected) {
edgeChanges.push({ type: 'select', id, selected: true });
}
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
dispatch(selectedAll());
},
[dispatch, store]
[dispatch]
);
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
@@ -263,37 +255,12 @@ export const Flow = memo(() => {
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
const onEscapeHotkey = useCallback(() => {
if (!$edgePendingUpdate.get()) {
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
}
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
}, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey);
const onDeleteHotkey = useCallback(() => {
const { nodes, edges } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
nodes
.filter((n) => n.selected)
.forEach(({ id }) => {
nodeChanges.push({ type: 'remove', id });
});
edges
.filter((e) => e.selected)
.forEach(({ id }) => {
edgeChanges.push({ type: 'remove', id });
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
}, [dispatch, store]);
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
return (
<ReactFlow
id="workflow-editor"
@@ -307,9 +274,11 @@ export const Flow = memo(() => {
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
@@ -323,10 +292,9 @@ export const Flow = memo(() => {
proOptions={proOptions}
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
deleteKeyCode={DELETE_KEYS}
selectionMode={selectionMode}
elevateEdgesOnSelect
nodeDragThreshold={1}
>
<Background />
</ReactFlow>

View File

@@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector';
import { $templates } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector';
const InvocationCollapsedEdge = ({
sourceX,
sourceY,
@@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({
targetPosition,
markerEnd,
data,
selected = false,
selected,
source,
sourceHandleId,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
[templates, source, sourceHandleId, target, targetHandleId]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, selected, source, sourceHandleId, target, targetHandleId]
);
const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector);
const { isSelected, shouldAnimate } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
@@ -44,8 +44,14 @@ const InvocationCollapsedEdge = ({
const { base500 } = useChakraThemeTokens();
const edgeStyles = useMemo(
() => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected),
[areConnectedNodesSelected, base500, selected, shouldAnimateEdges]
() => ({
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}),
[base500, isSelected, shouldAnimate]
);
return (
@@ -54,15 +60,11 @@ const InvocationCollapsedEdge = ({
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
data-testid="asdfasdfasdf"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
className="nodrag nopan"
// Unfortunately edge labels do not get the same zIndex treatment as edges do, so we need to manage this ourselves
// See: https://github.com/xyflow/xyflow/issues/3658
zIndex={1001}
>
<Badge variant="solid" bg="base.500" opacity={selected ? 0.8 : 0.5} boxShadow="base">
<Badge variant="solid" bg="base.500" opacity={isSelected ? 0.8 : 0.5} boxShadow="base">
{data.count}
</Badge>
</Flex>

View File

@@ -1,8 +1,8 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
@@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({
sourcePosition,
targetPosition,
markerEnd,
selected = false,
selected,
source,
target,
sourceHandleId,
@@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({
}: EdgeProps) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
[templates, source, sourceHandleId, target, targetHandleId]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, source, sourceHandleId, target, targetHandleId, selected]
);
const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
const [edgePath, labelX, labelY] = getBezierPath({
@@ -41,9 +41,15 @@ const InvocationDefaultEdge = ({
targetPosition,
});
const edgeStyles = useMemo(
() => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected),
[areConnectedNodesSelected, stroke, selected, shouldAnimateEdges]
const edgeStyles = useMemo<CSSProperties>(
() => ({
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}),
[isSelected, shouldAnimate, stroke]
);
return (
@@ -59,13 +65,13 @@ const InvocationDefaultEdge = ({
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={selected ? 'undefined' : 'transparent'}
opacity={selected ? 1 : 0.5}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={selected ? 'base.100' : 'base.300'}>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>

View File

@@ -1,7 +1,6 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELD_COLORS } from 'features/nodes/types/constants';
import type { FieldType } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
export const getFieldColor = (fieldType: FieldType | null): string => {
if (!fieldType) {
@@ -11,16 +10,3 @@ export const getFieldColor = (fieldType: FieldType | null): string => {
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};
export const getEdgeStyles = (
stroke: string,
selected: boolean,
shouldAnimateEdges: boolean,
areConnectedNodesSelected: boolean
): CSSProperties => ({
strokeWidth: 3,
stroke,
opacity: selected ? 1 : 0.5,
animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none',
});

View File

@@ -1,6 +1,5 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { deepClone } from 'common/util/deepClone';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
@@ -9,8 +8,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
const defaultReturnValue = {
areConnectedNodesSelected: false,
shouldAnimateEdges: false,
isSelected: false,
shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'),
label: '',
};
@@ -20,27 +19,21 @@ export const makeEdgeSelector = (
source: string,
sourceHandleId: string | null | undefined,
target: string,
targetHandleId: string | null | undefined
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createMemoizedSelector(
selectNodesSlice,
selectWorkflowSettingsSlice,
(
nodes,
workflowSettings
): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => {
const { shouldAnimateEdges, shouldColorEdges } = workflowSettings;
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const returnValue = deepClone(defaultReturnValue);
returnValue.shouldAnimateEdges = shouldAnimateEdges;
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected);
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return returnValue;
return defaultReturnValue;
}
const sourceNodeTemplate = templates[sourceNode.data.type];
@@ -49,10 +42,16 @@ export const makeEdgeSelector = (
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const stroke =
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return returnValue;
return {
isSelected,
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
stroke,
label,
};
}
);

View File

@@ -1,20 +0,0 @@
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
type Props = PropsWithChildren<{
nodeId: string;
fieldName?: string;
}>;
export const MissingFallback = memo((props: Props) => {
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
if (!exists) {
return null;
}
return props.children;
});
MissingFallback.displayName = 'MissingFallback';

View File

@@ -25,11 +25,10 @@ interface Props {
kind: 'inputs' | 'outputs';
isMissingInput?: boolean;
withTooltip?: boolean;
shouldDim?: boolean;
}
const EditableFieldTitle = forwardRef((props: Props, ref) => {
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false, shouldDim = false } = props;
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false } = props;
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const { t } = useTranslation();
@@ -40,11 +39,13 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
const handleSubmit = useCallback(
async (newTitleRaw: string) => {
const newTitle = newTitleRaw.trim();
const finalTitle = newTitle || fieldTemplateTitle || t('nodes.unknownField');
setLocalTitle(finalTitle);
dispatch(fieldLabelChanged({ nodeId, fieldName, label: finalTitle }));
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || t('nodes.unknownField'));
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[fieldTemplateTitle, dispatch, nodeId, fieldName, t]
[label, fieldTemplateTitle, dispatch, nodeId, fieldName, t]
);
const handleChange = useCallback((newTitle: string) => {
@@ -79,7 +80,6 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
sx={editablePreviewStyles}
noOfLines={1}
color={isMissingInput ? 'error.300' : 'base.300'}
opacity={shouldDim ? 0.5 : 1}
/>
</Tooltip>
<EditableInput className="nodrag" sx={editableInputStyles} />

View File

@@ -2,12 +2,10 @@ import { Tooltip } from '@invoke-ai/ui-library';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { HandleType } from 'reactflow';
import { Handle, Position } from 'reactflow';
@@ -16,12 +14,11 @@ type FieldHandleProps = {
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
validationResult: ValidationResult;
connectionError?: string;
};
const FieldHandle = (props: FieldHandleProps) => {
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
const { t } = useTranslation();
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
const { name } = fieldTemplate;
const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type);
@@ -29,11 +26,11 @@ const FieldHandle = (props: FieldHandleProps) => {
const isModelType = MODEL_TYPES.some((t) => t === type.name);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color,
backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: !isSingle(type) ? 4 : 0,
borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
@@ -46,11 +43,11 @@ const FieldHandle = (props: FieldHandleProps) => {
s.insetInlineEnd = '-1rem';
}
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
s.filter = 'opacity(0.4) grayscale(0.7)';
}
if (isConnectionInProgress && !validationResult.isValid) {
if (isConnectionInProgress && connectionError) {
if (isConnectionStartField) {
s.cursor = 'grab';
} else {
@@ -61,14 +58,14 @@ const FieldHandle = (props: FieldHandleProps) => {
}
return s;
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && validationResult.messageTKey) {
return t(validationResult.messageTKey);
if (isConnectionInProgress && connectionError) {
return connectionError;
}
return fieldTypeName;
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
}, [connectionError, fieldTypeName, isConnectionInProgress]);
return (
<Tooltip

View File

@@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
const isMissingInput = useMemo(() => {
@@ -79,7 +79,6 @@ const InputField = ({ nodeId, fieldName }: Props) => {
kind="inputs"
isMissingInput={isMissingInput}
withTooltip
shouldDim
/>
</FormControl>
@@ -88,7 +87,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
validationResult={validationResult}
connectionError={connectionError}
/>
</InputFieldWrapper>
);
@@ -126,7 +125,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
validationResult={validationResult}
connectionError={connectionError}
/>
)}
</InputFieldWrapper>

View File

@@ -1,4 +1,3 @@
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -24,8 +23,6 @@ import {
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isModelIdentifierFieldInputInstance,
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
@@ -98,10 +95,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -3,7 +3,6 @@ import { CSS } from '@dnd-kit/utilities';
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
@@ -21,7 +20,7 @@ type Props = {
fieldName: string;
};
const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
@@ -100,12 +99,4 @@ const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
);
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(LinearViewField);

View File

@@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
if (!fieldTemplate) {
@@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
validationResult={validationResult}
connectionError={connectionError}
/>
</OutputFieldWrapper>
);

View File

@@ -1,66 +0,0 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const _onChange = useCallback(
(value: AnyModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldModelIdentifierValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const modelConfigs = useMemo(() => {
if (!data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
groupByType: true,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(ModelIdentifierFieldInputComponent);

View File

@@ -1,15 +1,14 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodesChanged } from 'features/nodes/store/nodesSlice';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { MouseEvent, PropsWithChildren } from 'react';
import { memo, useCallback } from 'react';
import type { NodeChange } from 'reactflow';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@@ -19,7 +18,6 @@ type NodeWrapperProps = PropsWithChildren & {
const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const store = useAppStore();
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
const executionState = useExecutionState(nodeId);
@@ -39,20 +37,11 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
const { nodes } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (selected !== (id === nodeId)) {
nodeChanges.push({ type: 'select', id, selected: id === nodeId });
}
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
dispatch(nodeExclusivelySelected(nodeId));
}
onCloseGlobal();
},
[onCloseGlobal, store, dispatch, nodeId]
[dispatch, onCloseGlobal, nodeId]
);
return (

View File

@@ -7,10 +7,8 @@ import WorkflowInfoTooltipContent from './viewMode/WorkflowInfoTooltipContent';
import { WorkflowWarning } from './viewMode/WorkflowWarning';
export const WorkflowName = () => {
const { name, isTouched, mode } = useAppSelector((s) => s.workflow);
const { t } = useTranslation();
const name = useAppSelector((s) => s.workflow.name);
const isTouched = useAppSelector((s) => s.workflow.isTouched);
const mode = useAppSelector((s) => s.workflow.mode);
return (
<Flex gap="1" alignItems="center">

View File

@@ -1,7 +1,6 @@
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
@@ -15,7 +14,7 @@ type Props = {
fieldName: string;
};
const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
const WorkflowField = ({ nodeId, fieldName }: Props) => {
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
@@ -51,12 +50,4 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
);
};
const WorkflowField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(WorkflowField);

View File

@@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DndSortable from 'features/dnd/components/DndSortable';
import type { DragEndEvent } from 'features/dnd/types';
import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@@ -40,18 +40,16 @@ const WorkflowLinearTab = () => {
[dispatch, fields]
);
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
return (
<Box position="relative" w="full" h="full">
<ScrollableContent>
<DndSortable onDragEnd={handleDragEnd} items={items}>
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{isLoading ? (
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
) : fields.length ? (
fields.map(({ nodeId, fieldName }) => (
<LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />

View File

@@ -1,6 +1,7 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@@ -8,20 +9,31 @@ import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return false;
}
return (
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
const _fieldNames = getSortedFilteredFieldNames(fields);
if (_fieldNames.length === 0) {
return EMPTY_ARRAY;
}
return _fieldNames;
}, [template.inputs]);
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
return fieldNames;
};

View File

@@ -2,69 +2,58 @@ import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import {
$didUpdateEdge,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$templates,
edgesChanged,
connectionMade,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const updateNodeInternals = useUpdateNodeInternals();
const onConnectStart = useCallback<OnConnectStart>(
(event, { nodeId, handleId, handleType }) => {
assert(nodeId && handleId && handleType, 'Invalid connection start event');
(event, params) => {
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
if (!node) {
return;
}
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
const template = templates[node.data.type];
if (!template) {
return;
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
const fieldTemplate = fieldTemplates[handleId];
if (!fieldTemplate) {
return;
}
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
},
[store, templates]
);
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
updateNodeInternals([newEdge.source, newEdge.target]);
dispatch(connectionMade(connection));
$pendingConnection.set(null);
},
[store, updateNodeInternals]
[store]
);
const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store;
const pendingConnection = $pendingConnection.get();
const edgePendingUpdate = $edgePendingUpdate.get();
const isUpdatingEdge = $isUpdatingEdge.get();
const mouseOverNodeId = $mouseOverNode.get();
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
// update logic can finish up
if (edgePendingUpdate && !mouseOverNodeId) {
if (isUpdatingEdge && !mouseOverNodeId) {
$pendingConnection.set(null);
return;
}
@@ -74,41 +63,30 @@ export const useConnection = () => {
}
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection(
source,
sourceHandle,
target,
targetHandle,
templates,
nodes,
edges,
templates,
edgePendingUpdate
pendingConnection,
candidateNode,
candidateTemplate
);
if (connection) {
const newEdge = connectionToEdge(connection);
const edgeChanges: EdgeChange[] = [{ type: 'add', item: newEdge }];
const nodesToUpdate = [newEdge.source, newEdge.target];
if (edgePendingUpdate) {
$didUpdateEdge.set(true);
edgeChanges.push({ type: 'remove', id: edgePendingUpdate.id });
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
}
dispatch(edgesChanged(edgeChanges));
updateNodeInternals(nodesToUpdate);
dispatch(connectionMade(connection));
}
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
}
}, [store, templates, updateNodeInternals]);
}, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
return api;

View File

@@ -1,6 +1,7 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@@ -8,22 +9,31 @@ import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
// get the visible fields
const fields = map(template.inputs).filter(
(field) =>
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return true;
}
return (
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
const _fieldNames = getSortedFilteredFieldNames(fields);
if (_fieldNames.length === 0) {
return EMPTY_ARRAY;
}
return _fieldNames;
}, [template.inputs]);
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
return fieldNames;
};

View File

@@ -1,10 +1,12 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { useMemo } from 'react';
import { useFieldType } from './useFieldType.ts';
type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
@@ -14,7 +16,7 @@ type UseConnectionStateProps = {
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const edgePendingUpdate = useStore($edgePendingUpdate);
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
() =>
@@ -31,9 +33,17 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
[fieldName, kind, nodeId]
);
const selectValidationResult = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
[templates, nodeId, fieldName, kind]
const selectConnectionError = useMemo(
() =>
makeConnectionErrorSelector(
templates,
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
),
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
);
const isConnected = useAppSelector(selectIsConnected);
@@ -43,23 +53,23 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
return false;
}
return (
pendingConnection.nodeId === nodeId &&
pendingConnection.handleId === fieldName &&
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
const connectionError = useAppSelector(selectConnectionError);
const shouldDim = useMemo(
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
[validationResult, isConnectionInProgress, isConnectionStartField]
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
[connectionError, isConnectionInProgress, isConnectionStartField]
);
return {
isConnected,
isConnectionInProgress,
isConnectionStartField,
validationResult,
connectionError,
shouldDim,
};
};

View File

@@ -5,13 +5,11 @@ import {
$copiedNodes,
$cursorPos,
$edgesToCopiedNodes,
edgesChanged,
nodesChanged,
selectionPasted,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { isEqual, uniqWith } from 'lodash-es';
import type { EdgeChange, NodeChange } from 'reactflow';
import { v4 as uuidv4 } from 'uuid';
const copySelection = () => {
@@ -28,7 +26,7 @@ const copySelection = () => {
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const { getState, dispatch } = getStore();
const { nodes, edges } = selectNodesSlice(getState());
const currentNodes = selectNodesSlice(getState()).nodes;
const cursorPos = $cursorPos.get();
const copiedNodes = deepClone($copiedNodes.get());
@@ -48,7 +46,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const offsetY = cursorPos ? cursorPos.y - minY : 50;
copiedNodes.forEach((node) => {
const { x, y } = findUnoccupiedPosition(nodes, node.position.x + offsetX, node.position.y + offsetY);
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
node.position.x = x;
node.position.y = y;
// Pasted nodes are selected
@@ -70,48 +68,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
node.data.id = id;
});
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
// Deselect existing nodes
nodes.forEach(({ id, selected }) => {
if (selected) {
nodeChanges.push({
type: 'select',
id,
selected: false,
});
}
});
// Add new nodes
copiedNodes.forEach((n) => {
nodeChanges.push({
type: 'add',
item: n,
});
});
// Deselect existing edges
edges.forEach(({ id, selected }) => {
if (selected) {
edgeChanges.push({
type: 'select',
id,
selected: false,
});
}
});
// Add new edges
copiedEdges.forEach((e) => {
edgeChanges.push({
type: 'add',
item: e,
});
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
};
const api = { copySelection, pasteSelection };

View File

@@ -1,20 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
const doesFieldExist = useAppSelector((s) => {
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
if (fieldName === undefined) {
return true;
}
if (!node.data.inputs[fieldName]) {
return false;
}
return true;
});
return doesFieldExist;
};

View File

@@ -0,0 +1,9 @@
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
return fieldType;
};

View File

@@ -1,10 +1,14 @@
// TODO: enable this at some point
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import type { Connection } from 'reactflow';
import type { Connection, Node } from 'reactflow';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
@@ -21,21 +25,75 @@ export const useIsValidConnection = () => {
if (!(source && sourceHandle && target && targetHandle)) {
return false;
}
const edgePendingUpdate = $edgePendingUpdate.get();
const { nodes, edges } = store.getState().nodes.present;
const validationResult = validateConnection(
{ source, sourceHandle, target, targetHandle },
nodes,
edges,
templates,
edgePendingUpdate,
shouldValidateGraph
);
if (source === target) {
// Don't allow nodes to connect to themselves, even if validation is disabled
return false;
}
return validationResult.isValid;
const state = store.getState();
const { nodes, edges } = state.nodes.present;
// Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
// Conditional guards against undefined nodes/handles
if (!(sourceFieldTemplate && targetFieldTemplate)) {
return false;
}
if (targetFieldTemplate.input === 'direct') {
return false;
}
if (!shouldValidateGraph) {
// manual override!
return true;
}
if (
edges.find((edge) => {
edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;
}
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
}
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return false;
}
// Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return false;
}
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},
[templates, shouldValidateGraph, store]
[shouldValidateGraph, templates, store]
);
return isValidConnection;

View File

@@ -1,4 +1,4 @@
import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field';
import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -10,13 +10,13 @@ export const useFieldTypeName = (fieldType?: FieldType): string => {
return '';
}
const { name } = fieldType;
if (isCollection(fieldType)) {
if (fieldType.isCollection) {
return t('nodes.collectionFieldType', { name });
}
if (isSingleOrCollection(fieldType)) {
if (fieldType.isCollectionOrScalar) {
return t('nodes.collectionOrScalarFieldType', { name });
}
return t('nodes.singleFieldType', { name });
return name;
}, [fieldType, t]);
return name;

View File

@@ -1,6 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { Graph, GraphAndWorkflowResponse } from 'services/api/types';
import type { Graph } from 'services/api/types';
const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
const imageToImageGraphBuilt = createAction<Graph>('nodes/imageToImageGraphBuilt');
@@ -15,7 +15,7 @@ export const isAnyGraphBuilt = isAnyOf(
);
export const workflowLoadRequested = createAction<{
data: GraphAndWorkflowResponse;
workflow: unknown;
asCopy: boolean;
}>('nodes/workflowLoadRequested');

View File

@@ -16,7 +16,6 @@ import type {
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
ModelIdentifierFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
StatefulFieldValue,
@@ -36,7 +35,6 @@ import {
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
zModelIdentifierFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
zStatefulFieldValue,
@@ -47,13 +45,13 @@ import {
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { MouseEvent } from 'react';
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import type { z } from 'zod';
import type { NodesState, PendingConnection, Templates } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodesState: NodesState = {
_version: 1,
@@ -92,47 +90,44 @@ export const nodesSlice = createSlice({
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
state.nodes = applyNodeChanges(action.payload, state.nodes);
// Remove edges that are no longer valid, due to a removed or otherwise changed node
const edgeChanges: EdgeChange[] = [];
state.edges.forEach((e) => {
const sourceExists = state.nodes.some((n) => n.id === e.source);
const targetExists = state.nodes.some((n) => n.id === e.target);
if (!(sourceExists && targetExists)) {
edgeChanges.push({ type: 'remove', id: e.id });
}
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
nodeReplaced: (state, action: PayloadAction<{ nodeId: string; node: Node }>) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === action.payload.nodeId);
if (nodeIndex < 0) {
return;
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
const { node, cursorPos } = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
cursorPos?.x ?? node.position.x,
cursorPos?.y ?? node.position.y
);
node.position = position;
node.selected = true;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })),
state.nodes
);
state.edges = applyEdgeChanges(
state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })),
state.edges
);
state.nodes.push(node);
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
const changes: EdgeChange[] = [];
// We may need to massage the edge changes or otherwise handle them
action.payload.forEach((change) => {
if (change.type === 'remove' || change.type === 'select') {
const edge = state.edges.find((e) => e.id === change.id);
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
if (edge && edge.type === 'collapsed') {
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
if (change.type === 'remove') {
hiddenEdges.forEach(({ id }) => {
changes.push({ type: 'remove', id });
});
}
if (change.type === 'select') {
hiddenEdges.forEach(({ id }) => {
changes.push({ type: 'select', id, selected: change.selected });
});
}
}
}
if (change.type === 'add') {
if (!change.item.type) {
// We must add the edge type!
change.item.type = 'default';
}
}
changes.push(change);
});
state.edges = applyEdgeChanges(changes, state.edges);
state.edges = applyEdgeChanges(action.payload, state.edges);
},
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
},
fieldLabelChanged: (
state,
@@ -237,7 +232,6 @@ export const nodesSlice = createSlice({
type: 'collapsed',
data: { count: 1 },
updatable: false,
selected: edge.selected,
});
}
}
@@ -258,7 +252,6 @@ export const nodesSlice = createSlice({
type: 'collapsed',
data: { count: 1 },
updatable: false,
selected: edge.selected,
});
}
}
@@ -271,6 +264,33 @@ export const nodesSlice = createSlice({
}
}
},
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
if (collapsedEdges.length) {
const edgeChanges: EdgeRemoveChange[] = [];
collapsedEdges.forEach((collapsedEdge) => {
state.edges.forEach((edge) => {
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
edgeChanges.push({ id: edge.id, type: 'remove' });
}
});
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
action.payload.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
const { nodeId, label } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@@ -289,6 +309,17 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
const nodeId = action.payload;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({
id: n.id,
type: 'select',
selected: n.id === nodeId ? true : false,
})),
state.nodes
);
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
@@ -313,9 +344,6 @@ export const nodesSlice = createSlice({
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue);
},
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
fieldValueReducer(state, action, zModelIdentifierFieldValue);
},
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
},
@@ -353,6 +381,57 @@ export const nodesSlice = createSlice({
state.nodes = [];
state.edges = [];
},
selectedAll: (state) => {
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
state.nodes
);
state.edges = applyEdgeChanges(
state.edges.map((e) => ({ id: e.id, type: 'select', selected: true })),
state.edges
);
},
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
const { nodes, edges } = action.payload;
const nodeChanges: NodeChange[] = [];
// Deselect existing nodes
state.nodes.forEach((n) => {
nodeChanges.push({
id: n.data.id,
type: 'select',
selected: false,
});
});
// Add new nodes
nodes.forEach((n) => {
nodeChanges.push({
item: n,
type: 'add',
});
});
const edgeChanges: EdgeChange[] = [];
// Deselect existing edges
state.edges.forEach((e) => {
edgeChanges.push({
id: e.id,
type: 'select',
selected: false,
});
});
// Add new edges
edges.forEach((e) => {
edgeChanges.push({
item: e,
type: 'add',
});
});
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
undo: (state) => state,
redo: (state) => state,
},
@@ -361,13 +440,13 @@ export const nodesSlice = createSlice({
const { nodes, edges } = action.payload;
state.nodes = applyNodeChanges(
nodes.map((node) => ({
type: 'add',
item: { ...node, ...SHARED_NODE_PROPERTIES },
type: 'add',
})),
[]
);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ type: 'add', item: edge })),
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
});
@@ -375,7 +454,10 @@ export const nodesSlice = createSlice({
});
export const {
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldValueReset,
fieldBoardValueChanged,
fieldBooleanValueChanged,
@@ -387,21 +469,27 @@ export const {
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
nodeAdded,
nodeReplaced,
nodeEditorReset,
nodeExclusivelySelected,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectionPasted,
edgeAdded,
undo,
redo,
} = nodesSlice.actions;
@@ -412,10 +500,7 @@ export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $edgePendingUpdate = atom<Edge | null>(null);
export const $didUpdateEdge = atom(false);
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
export const $isUpdatingEdge = atom(false);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
@@ -443,13 +528,13 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
persistDenylist: [],
};
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
const isSelectionAction = (action: UnknownAction) => {
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
if (selectionMatcher(action)) {
return true;
}
if (edgesChanged.match(action)) {
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
@@ -489,7 +574,10 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldColorValueChanged,
@@ -506,11 +594,15 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
nodesChanged,
nodeAdded,
nodeReplaced,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged
notesNodeValueChanged,
selectionPasted,
edgeAdded
);

View File

@@ -6,20 +6,19 @@ import type {
} from 'features/nodes/types/field';
import type {
AnyNode,
InvocationNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { HandleType } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
nodeId: string;
handleId: string;
handleType: HandleType;
node: InvocationNode;
template: InvocationTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};

View File

@@ -1,86 +0,0 @@
import type { FieldType } from 'features/nodes/types/field';
import { describe, expect, it } from 'vitest';
import { areTypesEqual } from './areTypesEqual';
describe(areTypesEqual.name, () => {
it('should handle equal source and target type', () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal source type and original target type', () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and target type', () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and original target type', () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'LoRAModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
});

View File

@@ -1,29 +0,0 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual, omit } from 'lodash-es';
/**
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
* considered equal. For example, if the first type and original second type match, the types are considered equal.
* @param firstType The first type to compare.
* @param secondType The second type to compare.
* @returns True if the types are equal, false otherwise.
*/
export const areTypesEqual = (firstType: FieldType, secondType: FieldType) => {
const _firstType = 'originalType' in firstType ? omit(firstType, 'originalType') : firstType;
const _secondType = 'originalType' in secondType ? omit(secondType, 'originalType') : secondType;
const _originalFirstType = 'originalType' in firstType ? firstType.originalType : null;
const _originalSecondType = 'originalType' in secondType ? secondType.originalType : null;
if (isEqual(_firstType, _secondType)) {
return true;
}
if (_originalSecondType && isEqual(_firstType, _originalSecondType)) {
return true;
}
if (_originalFirstType && isEqual(_originalFirstType, _secondType)) {
return true;
}
if (_originalFirstType && _originalSecondType && isEqual(_originalFirstType, _originalSecondType)) {
return true;
}
return false;
};

View File

@@ -0,0 +1,105 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, isEqual, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
return null;
};

View File

@@ -1,44 +0,0 @@
import { deepClone } from 'common/util/deepClone';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils';
import type { FieldType } from 'features/nodes/types/field';
import { unset } from 'lodash-es';
import { describe, expect, it } from 'vitest';
describe(getCollectItemType.name, () => {
it('should return the type of the items the collect node collects', () => {
const n1 = buildNode(add);
const n2 = buildNode(collect);
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
});
it('should return null if the collect node does not have any connections', () => {
const n1 = buildNode(collect);
const result = getCollectItemType(templates, [n1], [], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to collect's node doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType(templates, [n1], [e1], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to collect's node template doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType({ collect }, [n1, n2], [e1], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to the collect's field template doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const addWithoutOutputValue = deepClone(add);
unset(addWithoutOutputValue, 'outputs.value');
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
expect(result).toBeNull();
});
});

View File

@@ -1,38 +0,0 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
* input field.
* @param templates The current invocation templates
* @param nodes The current nodes
* @param edges The current edges
* @param nodeId The collect node's id
* @returns The type of the items the collect node collects, or null if there is no input field
*/
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
return null;
}
return fieldTemplate.type;
};

View File

@@ -1,203 +0,0 @@
import { deepClone } from 'common/util/deepClone';
import {
getFirstValidConnection,
getSourceCandidateFields,
getTargetCandidateFields,
} from 'features/nodes/store/util/getFirstValidConnection';
import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils';
import { unset } from 'lodash-es';
import { describe, expect, it } from 'vitest';
describe('getFirstValidConnection', () => {
it('should return null if the pending and candidate nodes are the same node', () => {
const n = buildNode(add);
expect(getFirstValidConnection(n.id, 'value', n.id, null, [n], [], templates, null)).toBe(null);
});
it('should return null if the sourceHandle and targetHandle are null', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
expect(getFirstValidConnection(n1.id, null, n2.id, null, [n1, n2], [], templates, null)).toBe(null);
});
it('should return itself if both sourceHandle and targetHandle are provided', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
expect(getFirstValidConnection(n1.id, 'value', n2.id, 'a', [n1, n2], [], templates, null)).toEqual({
source: n1.id,
sourceHandle: 'value',
target: n2.id,
targetHandle: 'a',
});
});
describe('connecting from a source to a target', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
it('should return the first valid connection if there are no connected fields', () => {
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is a connected field', () => {
const e = buildEdge(n1.id, 'height', n2.id, 'width');
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'height',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is an edgePendingUpdate', () => {
const e = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, e);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return null if the target has no valid fields', () => {
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
const n3 = buildNode(add);
const r = getFirstValidConnection(n3.id, 'value', n2.id, null, [n1, n2, n3], [e1, e2], templates, null);
expect(r).toEqual(null);
});
});
describe('connecting from a target to a source', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
it('should return the first valid connection if there are no connected fields', () => {
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is a connected field', () => {
const e = buildEdge(n1.id, 'height', n2.id, 'width');
const r = getFirstValidConnection(n1.id, null, n2.id, 'height', [n1, n2], [e], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'height',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is an edgePendingUpdate', () => {
const e = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [e], templates, e);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return null if the target has no valid fields', () => {
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
const n3 = buildNode(add);
const r = getFirstValidConnection(n3.id, null, n2.id, 'a', [n1, n2, n3], [e1, e2], templates, null);
expect(r).toEqual(null);
});
});
});
describe('getTargetCandidateFields', () => {
it('should return an empty array if the nodes canot be found', () => {
const r = getTargetCandidateFields('missing', 'value', 'missing', [], [], templates, null);
expect(r).toEqual([]);
});
it('should return an empty array if the templates cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], {}, null);
expect(r).toEqual([]);
});
it('should return an empty array if the source field template cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const addWithoutOutputValue = deepClone(add);
unset(addWithoutOutputValue, 'outputs.value');
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], { add: addWithoutOutputValue }, null);
expect(r).toEqual([]);
});
it('should return all valid target fields if there are no connected fields', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, null);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
it('should ignore the edgePendingUpdate if provided', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
});
describe('getSourceCandidateFields', () => {
it('should return an empty array if the nodes canot be found', () => {
const r = getSourceCandidateFields('missing', 'value', 'missing', [], [], templates, null);
expect(r).toEqual([]);
});
it('should return an empty array if the templates cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const r = getSourceCandidateFields(n2.id, 'a', n1.id, nodes, [], {}, null);
expect(r).toEqual([]);
});
it('should return an empty array if the source field template cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const addWithoutInputA = deepClone(add);
unset(addWithoutInputA, 'inputs.a');
const r = getSourceCandidateFields(n1.id, 'a', n2.id, nodes, [], { add: addWithoutInputA }, null);
expect(r).toEqual([]);
});
it('should return all valid source fields if there are no connected fields', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, null);
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
});
it('should ignore the edgePendingUpdate if provided', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
});
});

View File

@@ -1,149 +0,0 @@
import type { Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import type { Connection, Edge } from 'reactflow';
/**
*
* @param source The source (node id)
* @param sourceHandle The source handle (field name), if any
* @param target The target (node id)
* @param targetHandle The target handle (field name), if any
* @param nodes The current nodes
* @param edges The current edges
* @param templates The current templates
* @param edgePendingUpdate The edge pending update, if any
* @returns
*/
export const getFirstValidConnection = (
source: string,
sourceHandle: string | null,
target: string,
targetHandle: string | null,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
templates: Templates,
edgePendingUpdate: Edge | null
): Connection | null => {
if (source === target) {
return null;
}
if (sourceHandle && targetHandle) {
return { source, sourceHandle, target, targetHandle };
}
if (sourceHandle && !targetHandle) {
const candidates = getTargetCandidateFields(
source,
sourceHandle,
target,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
return { source, sourceHandle, target, targetHandle: firstCandidate.name };
}
if (!sourceHandle && targetHandle) {
const candidates = getSourceCandidateFields(
target,
targetHandle,
source,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
return { source, sourceHandle: firstCandidate.name, target, targetHandle };
}
return null;
};
export const getTargetCandidateFields = (
source: string,
sourceHandle: string,
target: string,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
edgePendingUpdate: Edge | null
): FieldInputTemplate[] => {
const sourceNode = nodes.find((n) => n.id === source);
const targetNode = nodes.find((n) => n.id === target);
if (!sourceNode || !targetNode) {
return [];
}
const sourceTemplate = templates[sourceNode.data.type];
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
return [];
}
const sourceField = sourceTemplate.outputs[sourceHandle];
if (!sourceField) {
return [];
}
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
const c = { source, sourceHandle, target, targetHandle: field.name };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
});
return targetCandidateFields;
};
export const getSourceCandidateFields = (
target: string,
targetHandle: string,
source: string,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
edgePendingUpdate: Edge | null
): FieldOutputTemplate[] => {
const targetNode = nodes.find((n) => n.id === target);
const sourceNode = nodes.find((n) => n.id === source);
if (!sourceNode || !targetNode) {
return [];
}
const sourceTemplate = templates[sourceNode.data.type];
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
return [];
}
const targetField = targetTemplate.inputs[targetHandle];
if (!targetField) {
return [];
}
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
const c = { source, sourceHandle: field.name, target, targetHandle };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
});
return sourceCandidateFields;
};

View File

@@ -1,22 +0,0 @@
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils';
import { describe, expect, it } from 'vitest';
describe(getHasCycles.name, () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
it('should return true if the graph WOULD have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')];
const result = getHasCycles(n3.id, n1.id, nodes, edges);
expect(result).toBe(true);
});
it('should return false if the graph WOULD NOT have cycles after adding the edge', () => {
const edges = [buildEdge(n1.id, 'value', n2.id, 'a')];
const result = getHasCycles(n2.id, n3.id, nodes, edges);
expect(result).toBe(false);
});
});

View File

@@ -1,30 +0,0 @@
import graphlib from '@dagrejs/graphlib';
import type { Edge, Node } from 'reactflow';
/**
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
* @param source The source node id
* @param target The target node id
* @param nodes The graph's current nodes
* @param edges The graph's current edges
* @returns True if the graph would be acyclic after adding the edge, false otherwise
*/
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return !graphlib.alg.isAcyclic(g);
};

View File

@@ -0,0 +1,21 @@
import graphlib from '@dagrejs/graphlib';
import type { Edge, Node } from 'reactflow';
export const getIsGraphAcyclic = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
// construct graphlib graph from editor state
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// add the candidate edge
g.setEdge(source, target);
// check if the graph is acyclic
return graphlib.alg.isAcyclic(g);
};

View File

@@ -1,67 +0,0 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
import type { Edge, HandleType } from 'reactflow';
/**
* Creates a selector that validates a pending connection.
*
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*
* @param templates The invocation templates
* @param nodeId The id of the node for which the selector is being created
* @param fieldName The name of the field for which the selector is being created
* @param handleType The type of the handle for which the selector is being created
* @returns
*/
export const makeConnectionErrorSelector = (
templates: Templates,
nodeId: string,
fieldName: string,
handleType: HandleType
) => {
return createMemoizedSelector(
selectNodesSlice,
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
(state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) =>
edgePendingUpdate,
(nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => {
const { nodes, edges } = nodesSlice;
if (!pendingConnection) {
return buildRejectResult('nodes.noConnectionInProgress');
}
if (handleType === pendingConnection.handleType) {
if (handleType === 'source') {
return buildRejectResult('nodes.cannotConnectOutputToOutput');
}
return buildRejectResult('nodes.cannotConnectInputToInput');
}
// we have to figure out which is the target and which is the source
const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
const validationResult = validateConnection(
{
source,
sourceHandle,
target,
targetHandle,
},
nodes,
edges,
templates,
edgePendingUpdate
);
return validationResult;
}
);
};

View File

@@ -0,0 +1,147 @@
import { createSelector } from '@reduxjs/toolkit';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import i18n from 'i18next';
import { isEqual } from 'lodash-es';
import type { HandleType } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = (
templates: Templates,
pendingConnection: PendingConnection | null,
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType | null
) => {
return createSelector(selectNodesSlice, (nodesSlice) => {
const { nodes, edges } = nodesSlice;
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}
if (!pendingConnection) {
return i18n.t('nodes.noConnectionInProgress');
}
const connectionNodeId = pendingConnection.node.id;
const connectionFieldName = pendingConnection.fieldTemplate.name;
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
return i18n.t('nodes.noConnectionData');
}
const targetType = handleType === 'target' ? fieldType : connectionStartFieldType;
const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType;
if (nodeId === connectionNodeId) {
return i18n.t('nodes.cannotConnectToSelf');
}
if (handleType === connectionHandleType) {
if (handleType === 'source') {
return i18n.t('nodes.cannotConnectOutputToOutput');
}
return i18n.t('nodes.cannotConnectInputToInput');
}
// we have to figure out which is the target and which is the source
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
if (
edges.find((edge) => {
edge.target === targetNodeId &&
edge.targetHandle === targetFieldName &&
edge.source === sourceNodeId &&
edge.sourceHandle === sourceFieldName;
})
) {
// We already have a connection from this source to this target
return i18n.t('nodes.cannotDuplicateConnection');
}
const targetNode = nodes.find((node) => node.id === targetNodeId);
assert(targetNode, `Target node not found: ${targetNodeId}`);
const targetTemplate = templates[targetNode.data.type];
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
return i18n.t('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!isEqual(sourceType, collectItemType)) {
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
}
if (
edges.find((edge) => {
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
}) &&
// except CollectionItem inputs can have multiples
targetType.name !== 'CollectionItemField'
) {
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
}
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
return i18n.t('nodes.fieldTypesMustMatch');
}
const isGraphAcyclic = getIsGraphAcyclic(
connectionHandleType === 'source' ? connectionNodeId : nodeId,
connectionHandleType === 'source' ? nodeId : connectionNodeId,
nodes,
edges
);
if (!isGraphAcyclic) {
return i18n.t('nodes.connectionWouldCreateCycle');
}
return;
});
};

View File

@@ -1,32 +0,0 @@
import type { Connection, Edge } from 'reactflow';
import { assert } from 'tsafe';
/**
* Gets the edge id for a connection
* Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45
* Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290
* @param connection The connection to get the id for
* @returns The edge id
*/
const getEdgeId = (connection: Connection): string => {
const { source, sourceHandle, target, targetHandle } = connection;
return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`;
};
/**
* Converts a connection to an edge
* @param connection The connection to convert to an edge
* @returns The edge
* @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle)
*/
export const connectionToEdge = (connection: Connection): Edge => {
const { source, sourceHandle, target, targetHandle } = connection;
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
return {
source,
sourceHandle,
target,
targetHandle,
id: getEdgeId({ source, sourceHandle, target, targetHandle }),
};
};

View File

@@ -1,194 +0,0 @@
import { deepClone } from 'common/util/deepClone';
import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});
describe('missing nodes', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
it('should reject missing source node', () => {
const r = validateConnection(c, [n2], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});
it('should reject missing target node', () => {
const r = validateConnection(c, [n1], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
});
});
describe('missing invocation templates', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const nodes = [n1, n2];
it('should reject missing source template', () => {
const r = validateConnection(c, nodes, [], { sub }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});
it('should reject missing target template', () => {
const r = validateConnection(c, nodes, [], { add }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
});
});
describe('missing field templates', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
it('should reject missing source field template', () => {
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});
it('should reject missing target field template', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
});
});
describe('duplicate connections', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
expect(r).toEqual(buildAcceptResult());
});
it('should reject duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
});
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, e);
expect(r).toEqual(buildAcceptResult());
});
});
it('should reject connection to direct input', () => {
// Create cloned add template w/ a direct input
const addWithDirectAField = deepClone(add);
set(addWithDirectAField, 'inputs.a.input', 'direct');
set(addWithDirectAField, 'type', 'addWithDirectAField');
const n1 = buildNode(add);
const n2 = buildNode(addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
});
it('should reject connection to a collect node with mismatched item types', () => {
const n1 = buildNode(add);
const n2 = buildNode(collect);
const n3 = buildNode(main_model_loader);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
});
it('should accept connection to a collect node with matching item types', () => {
const n1 = buildNode(add);
const n2 = buildNode(collect);
const n3 = buildNode(sub);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildAcceptResult());
});
it('should reject connections to target field that is already connected', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
});
it('should accept connections to target field that is already connected (ignored edge)', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const n3 = buildNode(add);
const nodes = [n1, n2, n3];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, e1);
expect(r).toEqual(buildAcceptResult());
});
it('should reject connections between invalid types', () => {
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
});
it('should reject connections that would create cycles', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
describe('non-strict mode', () => {
it('should reject connections from self to self in non-strict mode', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
});
it('should reject connections that create cycles in non-strict mode', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const nodes = [n1, n2];
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
});
it('should otherwise allow invalid connections in non-strict mode', () => {
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null, false);
expect(r).toEqual(buildAcceptResult());
});
});
});

View File

@@ -1,130 +0,0 @@
import type { Templates } from 'features/nodes/store/types';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { O } from 'ts-toolbelt';
type Connection = O.NonNullable<NullableConnection>;
export type ValidationResult =
| {
isValid: true;
messageTKey?: string;
}
| {
isValid: false;
messageTKey: string;
};
type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
ignoreEdge: Edge | null,
strict?: boolean
) => ValidationResult;
const getEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {
return (
e.target === c.target &&
e.targetHandle === c.targetHandle &&
e.source === c.source &&
e.sourceHandle === c.sourceHandle
);
};
const getTargetEqualityPredicate =
(c: Connection) =>
(e: Edge): boolean => {
return e.target === c.target && e.targetHandle === c.targetHandle;
};
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
if (c.source === c.target) {
return buildRejectResult('nodes.cannotConnectToSelf');
}
if (strict) {
/**
* We may need to ignore an edge when validating a connection.
*
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
* the validation will fail unexpectedly.
*/
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
}
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
// CollectionItemField inputs can have multiple input connections
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
}
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return buildRejectResult('nodes.fieldTypesMustMatch');
}
}
if (getHasCycles(c.source, c.target, nodes, edges)) {
return buildRejectResult('nodes.connectionWouldCreateCycle');
}
return buildAcceptResult();
};

View File

@@ -1,222 +0,0 @@
import { describe, expect, it } from 'vitest';
import { validateConnectionTypes } from './validateConnectionTypes';
describe(validateConnectionTypes.name, () => {
describe('generic cases', () => {
it('should accept SINGLE to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept COLLECTION to COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should reject COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject mismatched types', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'BarField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
});
describe('special cases', () => {
it('should reject a COLLECTION input to a COLLECTION input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'COLLECTION' },
{ name: 'CollectionField', cardinality: 'COLLECTION' }
);
expect(r).toBe(false);
});
it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
describe('CollectionItemField', () => {
it('should accept CollectionItemField to any SINGLE target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any SINGLE to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should reject any COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
);
expect(r).toBe(false);
});
});
describe('SINGLE_OR_COLLECTION', () => {
it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
});
describe('CollectionField', () => {
it('should accept any CollectionField to any COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
});
describe('subtype handling', () => {
type TypePair = { t1: string; t2: string };
const typePairs = [
{ t1: 'IntegerField', t2: 'FloatField' },
{ t1: 'IntegerField', t2: 'StringField' },
{ t1: 'FloatField', t2: 'StringField' },
];
it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' });
expect(r).toBe(true);
});
it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
it.each(typePairs)(
'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2',
({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
}
);
});
describe('AnyField', () => {
it('should accept any SINGLE type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE' }
);
expect(r).toBe(true);
});
it('should accept any COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'COLLECTION' }
);
expect(r).toBe(true);
});
it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' }
);
expect(r).toBe(true);
});
});
});
});

View File

@@ -1,74 +0,0 @@
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'features/nodes/types/field';
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}
if (areTypesEqual(sourceType, targetType)) {
return true;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION)
* - SINGLE can connect to CollectionItem
* - Anything (SINGLE, COLLECTION, SINGLE_OR_COLLECTION) can connect to SINGLE_OR_COLLECTION of the same base type
* - Generic CollectionField can connect to any other COLLECTION or SINGLE_OR_COLLECTION
* - Any COLLECTION can connect to a Generic Collection
*/
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType);
const isNonCollectionToCollectionItem = isSingle(sourceType) && targetType.name === 'CollectionItemField';
const isAnythingToSingleOrCollectionOfSameBaseType =
isSingleOrCollection(targetType) && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrSingleOrCollection =
sourceType.name === 'CollectionField' && !isSingle(targetType);
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && isCollection(sourceType);
const isSourceSingle = isSingle(sourceType);
const isTargetSingle = isSingle(targetType);
const isSingleToSingle = isSourceSingle && isTargetSingle;
const isSingleToSingleOrCollection = isSourceSingle && isSingleOrCollection(targetType);
const isCollectionToCollection = isCollection(sourceType) && isCollection(targetType);
const isCollectionToSingleOrCollection = isCollection(sourceType) && isSingleOrCollection(targetType);
const isSingleOrCollectionToSingleOrCollection = isSingleOrCollection(sourceType) && isSingleOrCollection(targetType);
const doesCardinalityMatch =
isSingleToSingle ||
isCollectionToCollection ||
isCollectionToSingleOrCollection ||
isSingleOrCollectionToSingleOrCollection ||
isSingleToSingleOrCollection;
const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField';
const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField';
const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
const isTargetAnyType = targetType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToSingleOrCollectionOfSameBaseType ||
isGenericCollectionToAnyCollectionOrSingleOrCollection ||
isCollectionToGenericCollection ||
isSubTypeMatch ||
isTargetAnyType
);
};

View File

@@ -0,0 +1,70 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
return false;
}
if (isEqual(sourceType, targetType)) {
return true;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
* - Generic Collection can connect to any other Collection or CollectionOrScalar
* - Any Collection can connect to a Generic Collection
*/
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';
const isTargetAnyType = targetType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
isTargetAnyType
);
};

View File

@@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithValue,
WorkflowMode,
@@ -139,31 +139,15 @@ export const workflowSlice = createSlice({
};
});
builder.addCase(nodesDeleted, (state, action) => {
action.payload.forEach((node) => {
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id);
});
});
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
builder.addCase(nodesChanged, (state, action) => {
// If a node was removed, we should remove any exposed fields that were associated with it. However, node changes
// may remove and then add the same node back. For example, when updating a workflow, we replace old nodes with
// updated nodes. In this case, we should not remove the exposed fields. To handle this, we find the last remove
// and add changes for each exposed field. If the remove change comes after the add change, we remove the exposed
// field.
const exposedFieldsToRemove: FieldIdentifier[] = [];
state.exposedFields.forEach((field) => {
const removeIndex = action.payload.findLastIndex(
(change) => change.type === 'remove' && change.id === field.nodeId
);
const addIndex = action.payload.findLastIndex(
(change) => change.type === 'add' && change.item.id === field.nodeId
);
if (removeIndex > addIndex) {
exposedFieldsToRemove.push({ nodeId: field.nodeId, fieldName: field.fieldName });
}
});
state.exposedFields = state.exposedFields.filter(
(field) => !exposedFieldsToRemove.some((f) => isEqual(f, field))
);
// Not all changes to nodes should result in the workflow being marked touched
const filteredChanges = action.payload.filter((change) => {
// We always want to mark the workflow as touched if a node is added, removed, or reset
@@ -181,7 +165,7 @@ export const workflowSlice = createSlice({
return false;
});
if (filteredChanges.length > 0 || exposedFieldsToRemove.length > 0) {
if (filteredChanges.length > 0) {
state.isTouched = true;
}
});

View File

@@ -54,10 +54,9 @@ const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
fieldKind: z.literal('output'),
});
const zCardinality = z.enum(['SINGLE', 'COLLECTION', 'SINGLE_OR_COLLECTION']);
const zFieldTypeBase = z.object({
cardinality: zCardinality,
isCollection: z.boolean(),
isCollectionOrScalar: z.boolean(),
});
export const zFieldIdentifier = z.object({
@@ -67,124 +66,16 @@ export const zFieldIdentifier = z.object({
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
// #endregion
// #region Field Types
const zStatelessFieldType = zFieldTypeBase.extend({
name: z.string().min(1), // stateless --> we accept the field's name as the type
});
// #region IntegerField
const zIntegerFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
});
const zEnumFieldType = zFieldTypeBase.extend({
name: z.literal('EnumField'),
originalType: zStatelessFieldType.optional(),
});
const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
originalType: zStatelessFieldType.optional(),
});
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
});
const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
originalType: zStatelessFieldType.optional(),
});
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zModelIdentifierFieldType = zFieldTypeBase.extend({
name: z.literal('ModelIdentifierField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
originalType: zStatelessFieldType.optional(),
});
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
});
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType =>
(statefulFieldTypeNames as string[]).includes(fieldType.name);
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE;
export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION;
export const isSingleOrCollection = (fieldType: FieldType): boolean =>
fieldType.cardinality === zCardinality.enum.SINGLE_OR_COLLECTION;
// #endregion
// #region IntegerField
export const zIntegerFieldValue = z.number().int();
const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldValue,
});
const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIntegerFieldType,
originalType: zFieldType.optional(),
default: zIntegerFieldValue,
multipleOf: z.number().int().optional(),
maximum: z.number().int().optional(),
@@ -205,14 +96,15 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn
// #endregion
// #region FloatField
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
});
export const zFloatFieldValue = z.number();
const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldValue,
});
const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFloatFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldValue,
multipleOf: z.number().optional(),
maximum: z.number().optional(),
@@ -233,14 +125,15 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
// #endregion
// #region StringField
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
});
export const zStringFieldValue = z.string();
const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldValue,
});
const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStringFieldType,
originalType: zFieldType.optional(),
default: zStringFieldValue,
maxLength: z.number().int().optional(),
minLength: z.number().int().optional(),
@@ -259,14 +152,15 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu
// #endregion
// #region BooleanField
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
});
export const zBooleanFieldValue = z.boolean();
const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBooleanFieldValue,
});
const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBooleanFieldType,
originalType: zFieldType.optional(),
default: zBooleanFieldValue,
});
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -282,14 +176,15 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn
// #endregion
// #region EnumField
const zEnumFieldType = zFieldTypeBase.extend({
name: z.literal('EnumField'),
});
export const zEnumFieldValue = z.string();
const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
value: zEnumFieldValue,
});
const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zEnumFieldType,
originalType: zFieldType.optional(),
default: zEnumFieldValue,
options: z.array(z.string()),
labels: z.record(z.string()).optional(),
@@ -307,14 +202,15 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
// #endregion
// #region ImageField
const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
});
export const zImageFieldValue = zImageField.optional();
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImageFieldValue,
});
const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImageFieldType,
originalType: zFieldType.optional(),
default: zImageFieldValue,
});
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -330,14 +226,15 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
// #endregion
// #region BoardField
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
});
export const zBoardFieldValue = zBoardField.optional();
const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBoardFieldValue,
});
const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBoardFieldType,
originalType: zFieldType.optional(),
default: zBoardFieldValue,
});
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -353,14 +250,15 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT
// #endregion
// #region ColorField
const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
});
export const zColorFieldValue = zColorField.optional();
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
value: zColorFieldValue,
});
const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zColorFieldType,
originalType: zFieldType.optional(),
default: zColorFieldValue,
});
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -376,14 +274,15 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
// #endregion
// #region MainModelField
const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
});
export const zMainModelFieldValue = zModelIdentifierField.optional();
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zMainModelFieldValue,
});
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zMainModelFieldType,
originalType: zFieldType.optional(),
default: zMainModelFieldValue,
});
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -398,37 +297,16 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
zMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ModelIdentifierField
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
value: zModelIdentifierFieldValue,
});
const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zModelIdentifierFieldType,
originalType: zFieldType.optional(),
default: zModelIdentifierFieldValue,
});
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zModelIdentifierFieldType,
});
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
zModelIdentifierFieldInputInstance.safeParse(val).success;
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
zModelIdentifierFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLMainModelField
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
});
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLMainModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLMainModelFieldValue,
});
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -443,7 +321,9 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
// #endregion
// #region SDXLRefinerModelField
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
});
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -451,7 +331,6 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
});
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSDXLRefinerModelFieldType,
originalType: zFieldType.optional(),
default: zSDXLRefinerModelFieldValue,
});
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -467,14 +346,15 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
// #endregion
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
});
export const zVAEModelFieldValue = zModelIdentifierField.optional();
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zVAEModelFieldValue,
});
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zVAEModelFieldType,
originalType: zFieldType.optional(),
default: zVAEModelFieldValue,
});
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -490,14 +370,15 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
// #endregion
// #region LoRAModelField
const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
});
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLoRAModelFieldValue,
});
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zLoRAModelFieldValue,
});
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -513,14 +394,15 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
// #endregion
// #region ControlNetModelField
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
});
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlNetModelFieldValue,
});
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zControlNetModelFieldType,
originalType: zFieldType.optional(),
default: zControlNetModelFieldValue,
});
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -536,14 +418,15 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
// #endregion
// #region IPAdapterModelField
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('IPAdapterModelField'),
});
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIPAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zIPAdapterModelFieldValue,
});
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -559,14 +442,15 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
// #endregion
// #region T2IAdapterField
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
});
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT2IAdapterModelFieldType,
originalType: zFieldType.optional(),
default: zT2IAdapterModelFieldValue,
});
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -582,14 +466,15 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
// #endregion
// #region SchedulerField
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
});
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSchedulerFieldValue,
});
const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSchedulerFieldType,
originalType: zFieldType.optional(),
default: zSchedulerFieldValue,
});
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -616,14 +501,15 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie
* - Reserved fields like IsIntermediate
* - Any other field we don't have full-on schemas for
*/
const zStatelessFieldType = zFieldTypeBase.extend({
name: z.string().min(1), // stateless --> we accept the field's name as the type
});
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStatelessFieldValue,
});
const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStatelessFieldType,
originalType: zFieldType.optional(),
default: zStatelessFieldValue,
input: z.literal('connection'), // stateless --> only accepts connection inputs
});
@@ -649,6 +535,34 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
* for all other StatelessFields.
*/
// #region StatefulFieldType & FieldType
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
zStatefulFieldType.safeParse(val).success;
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
export type FieldType = z.infer<typeof zFieldType>;
// #endregion
// #region StatefulFieldValue & FieldValue
export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
@@ -658,7 +572,6 @@ export const zStatefulFieldValue = z.union([
zEnumFieldValue,
zImageFieldValue,
zBoardFieldValue,
zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
@@ -685,7 +598,6 @@ const zStatefulFieldInputInstance = z.union([
zEnumFieldInputInstance,
zImageFieldInputInstance,
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
@@ -713,7 +625,6 @@ const zStatefulFieldInputTemplate = z.union([
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
@@ -742,7 +653,6 @@ const zStatefulFieldOutputTemplate = z.union([
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,

View File

@@ -1,4 +1,4 @@
import type { StatefulFieldType, StatelessFieldType } from 'features/nodes/types/v2/field';
import type { FieldType, StatefulFieldType } from 'features/nodes/types/field';
import type { FieldTypeV1 } from './workflowV1';
@@ -165,7 +165,7 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
* Thus, this object was manually edited to ensure it is correct.
*/
const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
[key in FieldTypeV1]?: StatelessFieldType;
[key in FieldTypeV1]?: FieldType;
} = {
Any: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false },
ClipField: {

View File

@@ -316,7 +316,6 @@ const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
const zStatelessFieldType = zFieldTypeBase.extend({
name: z.string().min(1), // stateless --> we accept the field's name as the type
});
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
type: zStatelessFieldType,
@@ -328,27 +327,6 @@ const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
// #endregion
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
/**
* Here we define the main field unions:
* - FieldType

View File

@@ -47,7 +47,6 @@ const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({
type: z.literal('default'),
sourceHandle: z.string().trim().min(1),
targetHandle: z.string().trim().min(1),
hidden: z.boolean().optional(),
});
const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
type: z.literal('collapsed'),

View File

@@ -29,7 +29,7 @@ export const addControlNetToLinearGraph = async (
assert(activeTabName !== 'generation', 'Tried to use addControlNetToLinearGraph on generation tab');
if (controlNets.length) {
// Even though denoise_latents' control input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: Invocation<'collect'> = {
id: CONTROL_NET_COLLECT,
type: 'collect',

View File

@@ -25,7 +25,7 @@ export const addIPAdapterToLinearGraph = async (
});
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: Invocation<'collect'> = {
id: IP_ADAPTER_COLLECT,
type: 'collect',

View File

@@ -28,7 +28,7 @@ export const addT2IAdaptersToLinearGraph = async (
);
if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: Invocation<'collect'> = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',

View File

@@ -330,7 +330,6 @@ export const buildCanvasImageToImageGraph = async (
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_INPAINT_GRAPH,
CANVAS_OUTPUT,
@@ -422,15 +421,6 @@ export const buildCanvasInpaintGraph = async (
});
}
addCoreMetadataNode(
graph,
{
generation_mode: 'inpaint',
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_OUTPAINT_GRAPH,
CANVAS_OUTPUT,
@@ -580,15 +579,6 @@ export const buildCanvasOutpaintGraph = async (
);
}
addCoreMetadataNode(
graph,
{
generation_mode: 'outpaint',
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);

View File

@@ -332,7 +332,6 @@ export const buildCanvasSDXLImageToImageGraph = async (
init_image: initialImage.image_name,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
@@ -433,15 +432,6 @@ export const buildCanvasSDXLInpaintGraph = async (
});
}
addCoreMetadataNode(
graph,
{
generation_mode: 'sdxl_inpaint',
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
@@ -589,15 +588,6 @@ export const buildCanvasSDXLOutpaintGraph = async (
);
}
addCoreMetadataNode(
graph,
{
generation_mode: 'sdxl_outpaint',
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);

View File

@@ -291,7 +291,6 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);

View File

@@ -280,7 +280,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
clip_skip: clipSkip,
_canvas_objects: state.canvas.layerState.objects,
},
CANVAS_OUTPUT
);

View File

@@ -11,6 +11,8 @@ export const addSDXLLoRas = (
denoise: Invocation<'denoise_latents'>,
modelLoader: Invocation<'sdxl_model_loader'>,
seamless: Invocation<'seamless'> | null,
clipSkip: Invocation<'clip_skip'>,
clipSkip2: Invocation<'clip_skip'>,
posCond: Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'sdxl_compel_prompt'>
): void => {
@@ -37,8 +39,8 @@ export const addSDXLLoRas = (
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
// Use seamless as UNet input if it exists, otherwise use the model loader
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2');
g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip');
g.addEdge(clipSkip2, 'clip', loraCollectionLoader, 'clip2');
// Reroute UNet & CLIP connections through the LoRA collection loader
g.deleteEdgesTo(denoise, ['unet']);
g.deleteEdgesTo(posCond, ['clip', 'clip2']);

View File

@@ -1,6 +1,7 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
@@ -29,6 +30,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
clipSkip: skipped_layers,
scheduler,
seed,
steps,
@@ -51,6 +53,16 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
id: SDXL_MODEL_LOADER,
model,
});
const clipSkip = g.addNode({
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers,
});
const clipSkip2 = g.addNode({
type: 'clip_skip',
id: `${CLIP_SKIP}_2`,
skipped_layers,
});
const posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
@@ -103,10 +115,12 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(modelLoader, 'clip2', clipSkip2, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(clipSkip2, 'clip', posCond, 'clip2');
g.addEdge(clipSkip2, 'clip', negCond, 'clip2');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
@@ -132,12 +146,13 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
scheduler,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond);
addSDXLLoRas(state, g, denoise, modelLoader, seamless, clipSkip, clipSkip2, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;

Some files were not shown because too many files have changed in this diff Show More