mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 15:27:55 -05:00
Compare commits
1 Commits
v4.2.2
...
image-capt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59327e827b |
@@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
|
||||
|
||||
"Custom" fields will always be treated as stateless fields.
|
||||
|
||||
##### Single and Collection Fields
|
||||
##### Collection and Scalar Fields
|
||||
|
||||
Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
|
||||
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
|
||||
|
||||
- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
|
||||
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
|
||||
- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
||||
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
|
||||
|
||||
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
||||
|
||||
## Implementation
|
||||
|
||||
@@ -173,7 +173,8 @@ Field types are represented as structured objects:
|
||||
```ts
|
||||
type FieldType = {
|
||||
name: string;
|
||||
cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
|
||||
isCollection: boolean;
|
||||
isCollectionOrScalar: boolean;
|
||||
};
|
||||
```
|
||||
|
||||
@@ -185,7 +186,7 @@ There are 4 general cases for field type parsing.
|
||||
|
||||
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
|
||||
|
||||
We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
|
||||
We create a field type name from this `type` string (e.g. `string` -> `StringField`).
|
||||
|
||||
##### Complex Types
|
||||
|
||||
@@ -199,13 +200,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
|
||||
|
||||
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
|
||||
|
||||
We use the item type for field type name. The cardinality is `"COLLECTION"`.
|
||||
We use the item type for field type name, adding `isCollection: true` to the field type.
|
||||
|
||||
##### Single or Collection Types
|
||||
##### Collection or Scalar Types
|
||||
|
||||
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
|
||||
|
||||
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
|
||||
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type.
|
||||
|
||||
##### Optional Fields
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
|
||||
|
||||
### Requirements
|
||||
|
||||
Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
|
||||
Before you start, go through the [installation requirements].
|
||||
|
||||
### Installation Walkthrough
|
||||
|
||||
@@ -79,7 +79,7 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
|
||||
|
||||
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
|
||||
|
||||
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
|
||||
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
|
||||
|
||||
!!! example "Install with an extra index URL"
|
||||
|
||||
@@ -116,4 +116,4 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
|
||||
|
||||
!!! warning
|
||||
|
||||
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
|
||||
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
|
||||
|
||||
@@ -6,12 +6,13 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@@ -41,17 +42,13 @@ async def upload_image(
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
metadata: Optional[JsonValue] = Body(
|
||||
default=None, description="The metadata to associate with the image", embed=True
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
_metadata = None
|
||||
_workflow = None
|
||||
_graph = None
|
||||
metadata = None
|
||||
workflow = None
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
@@ -65,28 +62,22 @@ async def upload_image(
|
||||
|
||||
# TODO: retain non-invokeai metadata on upload?
|
||||
# attempt to parse metadata from image
|
||||
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
|
||||
if isinstance(metadata_raw, str):
|
||||
_metadata = metadata_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
||||
if metadata_raw:
|
||||
try:
|
||||
metadata = MetadataFieldValidator.validate_json(metadata_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to parse workflow from image
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if isinstance(workflow_raw, str):
|
||||
_workflow = workflow_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to extract graph from image
|
||||
graph_raw = pil_image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph_raw, str):
|
||||
_graph = graph_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
|
||||
pass
|
||||
if workflow_raw is not None:
|
||||
try:
|
||||
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
@@ -95,9 +86,8 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
board_id=board_id,
|
||||
metadata=_metadata,
|
||||
workflow=_workflow,
|
||||
graph=_graph,
|
||||
metadata=metadata,
|
||||
workflow=workflow,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
@@ -195,21 +185,14 @@ async def get_image_metadata(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
class WorkflowAndGraphResponse(BaseModel):
|
||||
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
|
||||
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
|
||||
)
|
||||
async def get_image_workflow(
|
||||
image_name: str = Path(description="The name of image whose workflow to get"),
|
||||
) -> WorkflowAndGraphResponse:
|
||||
) -> Optional[WorkflowWithoutID]:
|
||||
try:
|
||||
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
|
||||
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
|
||||
return ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
@@ -79,13 +80,13 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, List, Union
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
@@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"auto_caption_image",
|
||||
title="Automatically Caption Image",
|
||||
tags=["image", "caption"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
)
|
||||
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Adds a caption to an image"""
|
||||
|
||||
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
|
||||
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
|
||||
|
||||
model_id = "vikhyatk/moondream2"
|
||||
model_revision = "2024-04-02"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
|
||||
moondream_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, trust_remote_code=True, revision=model_revision
|
||||
)
|
||||
output: CaptionImageOutputs = CaptionImageOutputs()
|
||||
try:
|
||||
from PIL.Image import Image
|
||||
images: List[Image] = []
|
||||
image_fields = self.images if isinstance(self.images, list) else [self.images]
|
||||
for image in image_fields:
|
||||
images.append(context.images.get_pil(image.image_name))
|
||||
answers: List[str] = moondream_model.batch_answer(
|
||||
images=images,
|
||||
prompts=[self.prompt] * len(images),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
assert isinstance(answers, list)
|
||||
for i, answer in enumerate(answers):
|
||||
output.images.append(CaptionImageOutput(
|
||||
image=image_fields[i],
|
||||
width=images[i].width,
|
||||
height=images[i].height,
|
||||
caption=answer
|
||||
))
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
del moondream_model
|
||||
del tokenizer
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
@@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
image: List[ImageField] = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -67,6 +67,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -94,46 +93,19 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
@invocation_output("model_identifier_output")
|
||||
class ModelIdentifierOutput(BaseInvocationOutput):
|
||||
"""Model identifier output"""
|
||||
|
||||
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
|
||||
|
||||
|
||||
@invocation(
|
||||
"model_identifier",
|
||||
title="Model identifier",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ModelIdentifierInvocation(BaseInvocation):
|
||||
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
|
||||
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
|
||||
error."""
|
||||
|
||||
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.2",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
@@ -162,12 +134,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -225,12 +197,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
|
||||
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
@@ -301,13 +273,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.2",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@@ -442,12 +414,12 @@ class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("captioned_image_output")
|
||||
class CaptionImageOutput(ImageOutput):
|
||||
caption: str = OutputField(description="Caption for given image")
|
||||
|
||||
|
||||
|
||||
@invocation_output("captioned_image_outputs")
|
||||
class CaptionImageOutputs(BaseInvocationOutput):
|
||||
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
|
||||
|
||||
|
||||
@invocation_output("image_collection_output")
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
@@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.2",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||
@@ -55,6 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
t2i_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.T2IAdapterModel,
|
||||
)
|
||||
|
||||
@@ -122,8 +122,6 @@ class EventServiceBase:
|
||||
source_node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
user_id: str | None,
|
||||
project_id: str | None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
@@ -137,8 +135,6 @@ class EventServiceBase:
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
@@ -30,9 +33,8 @@ class ImageFileStorageBase(ABC):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
@@ -44,11 +46,6 @@ class ImageFileStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
"""Gets the workflow of an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
"""Gets the graph of an image."""
|
||||
pass
|
||||
|
||||
@@ -7,7 +7,9 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
from .image_files_base import ImageFileStorageBase
|
||||
@@ -54,9 +56,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -67,14 +68,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
info_dict = {}
|
||||
|
||||
if metadata is not None:
|
||||
info_dict["invokeai_metadata"] = metadata
|
||||
pnginfo.add_text("invokeai_metadata", metadata)
|
||||
metadata_json = metadata.model_dump_json()
|
||||
info_dict["invokeai_metadata"] = metadata_json
|
||||
pnginfo.add_text("invokeai_metadata", metadata_json)
|
||||
if workflow is not None:
|
||||
info_dict["invokeai_workflow"] = workflow
|
||||
pnginfo.add_text("invokeai_workflow", workflow)
|
||||
if graph is not None:
|
||||
info_dict["invokeai_graph"] = graph
|
||||
pnginfo.add_text("invokeai_graph", graph)
|
||||
workflow_json = workflow.model_dump_json()
|
||||
info_dict["invokeai_workflow"] = workflow_json
|
||||
pnginfo.add_text("invokeai_workflow", workflow_json)
|
||||
|
||||
# When saving the image, the image object's info field is not populated. We need to set it
|
||||
image.info = info_dict
|
||||
@@ -129,18 +129,11 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
||||
def get_workflow(self, image_name: str) -> str | None:
|
||||
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
|
||||
image = self.get(image_name)
|
||||
workflow = image.info.get("invokeai_workflow", None)
|
||||
if isinstance(workflow, str):
|
||||
return workflow
|
||||
return None
|
||||
|
||||
def get_graph(self, image_name: str) -> str | None:
|
||||
image = self.get(image_name)
|
||||
graph = image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph, str):
|
||||
return graph
|
||||
if workflow is not None:
|
||||
return WorkflowWithoutID.model_validate_json(workflow)
|
||||
return None
|
||||
|
||||
def __validate_storage_folders(self) -> None:
|
||||
|
||||
@@ -80,7 +80,7 @@ class ImageRecordStorageBase(ABC):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@@ -328,9 +328,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = metadata.model_dump_json() if metadata is not None else None
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -357,7 +358,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
|
||||
@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
@@ -50,9 +51,8 @@ class ImageServiceABC(ABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@@ -87,12 +87,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
"""Gets an image's workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
"""Gets an image's workflow."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
@@ -41,9 +42,8 @@ class ImageService(ImageServiceABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@@ -64,7 +64,7 @@ class ImageService(ImageServiceABC):
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
has_workflow=workflow is not None or graph is not None,
|
||||
has_workflow=workflow is not None,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
@@ -75,7 +75,7 @@ class ImageService(ImageServiceABC):
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||
)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
@@ -157,7 +157,7 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image metadata")
|
||||
raise e
|
||||
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
try:
|
||||
return self.__invoker.services.image_files.get_workflow(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
@@ -167,16 +167,6 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image workflow")
|
||||
raise
|
||||
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
try:
|
||||
return self.__invoker.services.image_files.get_graph(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self.__invoker.services.logger.error("Image file not found")
|
||||
raise
|
||||
except Exception:
|
||||
self.__invoker.services.logger.error("Problem getting image graph")
|
||||
raise
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
|
||||
|
||||
@@ -237,8 +237,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
user_id=None,
|
||||
project_id=None,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@@ -180,9 +180,9 @@ class ImagesInterface(InvocationContextInterface):
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
metadata_ = metadata.model_dump_json()
|
||||
elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
|
||||
metadata_ = self._data.invocation.metadata.model_dump_json()
|
||||
metadata_ = metadata
|
||||
elif isinstance(self._data.invocation, WithMetadata):
|
||||
metadata_ = self._data.invocation.metadata
|
||||
|
||||
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||
board_id_ = None
|
||||
@@ -191,14 +191,6 @@ class ImagesInterface(InvocationContextInterface):
|
||||
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
|
||||
board_id_ = self._data.invocation.board.board_id
|
||||
|
||||
workflow_ = None
|
||||
if self._data.queue_item.workflow:
|
||||
workflow_ = self._data.queue_item.workflow.model_dump_json()
|
||||
|
||||
graph_ = None
|
||||
if self._data.queue_item.session.graph:
|
||||
graph_ = self._data.queue_item.session.graph.model_dump_json()
|
||||
|
||||
return self._services.images.create(
|
||||
image=image,
|
||||
is_intermediate=self._data.invocation.is_intermediate,
|
||||
@@ -206,8 +198,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
board_id=board_id_,
|
||||
metadata=metadata_,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
workflow=workflow_,
|
||||
graph=graph_,
|
||||
workflow=self._data.queue_item.workflow,
|
||||
session_id=self._data.queue_item.session_id,
|
||||
node_id=self._data.invocation.id,
|
||||
)
|
||||
|
||||
@@ -775,14 +775,10 @@
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||
"missingNode": "Missing invocation node",
|
||||
"missingInvocationTemplate": "Missing invocation template",
|
||||
"missingFieldTemplate": "Missing field template",
|
||||
"nodePack": "Node pack",
|
||||
"collection": "Collection",
|
||||
"singleFieldType": "{{name}} (Single)",
|
||||
"collectionFieldType": "{{name}} (Collection)",
|
||||
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
"collectionOrScalarFieldType": "{{name}} Collection|Scalar",
|
||||
"colorCodeEdges": "Color-Code Edges",
|
||||
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
|
||||
"connectionWouldCreateCycle": "Connection would create a cycle",
|
||||
@@ -884,7 +880,6 @@
|
||||
"versionUnknown": " Version Unknown",
|
||||
"workflow": "Workflow",
|
||||
"graph": "Graph",
|
||||
"noGraph": "No Graph",
|
||||
"workflowAuthor": "Author",
|
||||
"workflowContact": "Contact",
|
||||
"workflowDescription": "Short Description",
|
||||
@@ -952,7 +947,7 @@
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
|
||||
"controlAdapterNoImageSelected": "no Control Adapter image selected",
|
||||
"controlAdapterImageNotProcessed": "Control Adapter image not processed",
|
||||
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
|
||||
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64",
|
||||
"ipAdapterNoModelSelected": "no IP adapter selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
|
||||
"ipAdapterNoImageSelected": "no IP Adapter image selected",
|
||||
|
||||
@@ -21,7 +21,6 @@ import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import PreselectedImage from './PreselectedImage';
|
||||
@@ -47,7 +46,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
|
||||
useSocketIO();
|
||||
useGlobalModifiersInit();
|
||||
useGlobalHotkeys();
|
||||
useGetOpenAPISchemaQuery();
|
||||
|
||||
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { canvasSavedToGallery } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
@@ -44,9 +43,6 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: t('toast.canvasSavedGallery') },
|
||||
},
|
||||
metadata: {
|
||||
_canvas_objects: parseify(state.canvas.layerState.objects),
|
||||
},
|
||||
})
|
||||
);
|
||||
},
|
||||
|
||||
@@ -16,7 +16,6 @@ import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
@@ -48,10 +47,8 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc
|
||||
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
|
||||
effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => {
|
||||
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
|
||||
const state = getState();
|
||||
const originalState = getOriginalState();
|
||||
|
||||
// Cancel any in-progress instances of this listener
|
||||
cancelActiveListeners();
|
||||
@@ -60,33 +57,21 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
// Delay before starting actual work
|
||||
await delay(DEBOUNCE_MS);
|
||||
|
||||
// Double-check that we are still eligible for processing
|
||||
const state = getState();
|
||||
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
|
||||
|
||||
// If we have no image or there is no processor config, bail
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We should only process if the processor settings or image have changed
|
||||
const originalLayer = originalState.controlLayers.present.layers
|
||||
.filter(isControlAdapterLayer)
|
||||
.find((l) => l.id === layerId);
|
||||
const originalImage = originalLayer?.controlAdapter.image;
|
||||
const originalConfig = originalLayer?.controlAdapter.processorConfig;
|
||||
|
||||
const image = layer.controlAdapter.image;
|
||||
const config = layer.controlAdapter.processorConfig;
|
||||
|
||||
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
|
||||
// Neither config nor image have changed, we can bail
|
||||
return;
|
||||
}
|
||||
|
||||
if (!image || !config) {
|
||||
// - If we have no image, we have nothing to process
|
||||
// - If we have no processor config, we have nothing to process
|
||||
// Clear the processed image and bail
|
||||
// The user has reset the image or config, so we should clear the processed image
|
||||
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
|
||||
@@ -96,8 +81,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
|
||||
}
|
||||
|
||||
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
|
||||
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
|
||||
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
|
||||
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config);
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
|
||||
@@ -18,7 +18,6 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
nes.progressImage = progress_image ?? null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
|
||||
@@ -31,12 +31,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
}
|
||||
try {
|
||||
const updatedNode = updateNode(node, template);
|
||||
dispatch(
|
||||
nodesChanged([
|
||||
{ type: 'remove', id: updatedNode.id },
|
||||
{ type: 'add', item: updatedNode },
|
||||
])
|
||||
);
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
} catch (e) {
|
||||
if (e instanceof NodeUpdateError) {
|
||||
unableToUpdateCount++;
|
||||
|
||||
@@ -4,49 +4,31 @@ import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
|
||||
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
|
||||
if (data.workflow) {
|
||||
// Prefer to load the workflow if it's available - it has more information
|
||||
const parsed = JSON.parse(data.workflow);
|
||||
return validateWorkflow(parsed, templates);
|
||||
} else if (data.graph) {
|
||||
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
|
||||
const parsed = JSON.parse(data.graph);
|
||||
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
|
||||
return validateWorkflow(workflow, templates);
|
||||
} else {
|
||||
throw new Error('No workflow or graph provided');
|
||||
}
|
||||
};
|
||||
|
||||
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('nodes');
|
||||
const { data, asCopy } = action.payload;
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = $templates.get();
|
||||
|
||||
try {
|
||||
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
|
||||
if (asCopy) {
|
||||
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
|
||||
delete workflow.id;
|
||||
delete validatedWorkflow.id;
|
||||
}
|
||||
|
||||
dispatch(workflowLoaded(workflow));
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
if (!warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
|
||||
@@ -137,7 +137,7 @@ const createSelector = (templates: Templates) =>
|
||||
if (l.controlAdapter.type === 't2i_adapter') {
|
||||
const multiple = model?.base === 'sdxl' ? 32 : 64;
|
||||
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
|
||||
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -616,24 +616,12 @@ export const controlLayersSlice = createSlice({
|
||||
iiLayerAdded: {
|
||||
reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
|
||||
const { layerId, imageDTO } = action.payload;
|
||||
|
||||
// Retain opacity and denoising strength of existing initial image layer if exists
|
||||
let opacity = 1;
|
||||
let denoisingStrength = 0.75;
|
||||
const iiLayer = state.layers.find((l) => l.id === layerId);
|
||||
if (iiLayer) {
|
||||
assert(isInitialImageLayer(iiLayer));
|
||||
opacity = iiLayer.opacity;
|
||||
denoisingStrength = iiLayer.denoisingStrength;
|
||||
}
|
||||
|
||||
// Highlander! There can be only one!
|
||||
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
|
||||
|
||||
const layer: InitialImageLayer = {
|
||||
id: layerId,
|
||||
type: 'initial_image_layer',
|
||||
opacity,
|
||||
opacity: 1,
|
||||
x: 0,
|
||||
y: 0,
|
||||
bbox: null,
|
||||
@@ -641,7 +629,7 @@ export const controlLayersSlice = createSlice({
|
||||
isEnabled: true,
|
||||
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
|
||||
isSelected: true,
|
||||
denoisingStrength,
|
||||
denoisingStrength: 0.75,
|
||||
};
|
||||
state.layers.push(layer);
|
||||
exclusivelySelectLayer(state, layer.id);
|
||||
|
||||
@@ -11,12 +11,10 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -50,7 +48,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const isCanvasEnabled = useFeatureStatus('canvas');
|
||||
const customStarUi = useStore($customStarUI);
|
||||
const { downloadImage } = useDownloadImage();
|
||||
const templates = useStore($templates);
|
||||
|
||||
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
|
||||
useImageActions(imageDTO?.image_name);
|
||||
@@ -136,7 +133,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
<MenuItem
|
||||
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
|
||||
onClickCapture={handleLoadWorkflow}
|
||||
isDisabled={!imageDTO.has_workflow || !size(templates)}
|
||||
isDisabled={!imageDTO.has_workflow}
|
||||
>
|
||||
{t('nodes.loadWorkflow')}
|
||||
</MenuItem>
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import DataViewer from './DataViewer';
|
||||
|
||||
type Props = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
const ImageMetadataGraphTabContent = ({ image }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { currentData } = useDebouncedImageWorkflow(image);
|
||||
const graph = useMemo(() => {
|
||||
if (currentData?.graph) {
|
||||
try {
|
||||
return JSON.parse(currentData.graph);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}, [currentData]);
|
||||
|
||||
if (!graph) {
|
||||
return <IAINoContentFallback label={t('nodes.noGraph')} />;
|
||||
}
|
||||
|
||||
return <DataViewer data={graph} label={t('nodes.graph')} />;
|
||||
};
|
||||
|
||||
export default memo(ImageMetadataGraphTabContent);
|
||||
@@ -1,7 +1,6 @@
|
||||
import { ExternalLink, Flex, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import ImageMetadataGraphTabContent from 'features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent';
|
||||
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { memo } from 'react';
|
||||
@@ -53,7 +52,6 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
<Tab>{t('metadata.metadata')}</Tab>
|
||||
<Tab>{t('metadata.imageDetails')}</Tab>
|
||||
<Tab>{t('metadata.workflow')}</Tab>
|
||||
<Tab>{t('nodes.graph')}</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels>
|
||||
@@ -83,9 +81,6 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
<TabPanel>
|
||||
<ImageMetadataWorkflowTabContent image={image} />
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<ImageMetadataGraphTabContent image={image} />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -12,17 +12,7 @@ type Props = {
|
||||
|
||||
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { currentData } = useDebouncedImageWorkflow(image);
|
||||
const workflow = useMemo(() => {
|
||||
if (currentData?.workflow) {
|
||||
try {
|
||||
return JSON.parse(currentData.workflow);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}, [currentData]);
|
||||
const { workflow } = useDebouncedImageWorkflow(image);
|
||||
|
||||
if (!workflow) {
|
||||
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
@@ -13,14 +12,12 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
|
||||
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -51,7 +48,7 @@ const CurrentImageButtons = () => {
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const selection = useAppSelector((s) => s.gallery.selection);
|
||||
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
|
||||
const templates = useStore($templates);
|
||||
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
const { t } = useTranslation();
|
||||
@@ -146,7 +143,7 @@ const CurrentImageButtons = () => {
|
||||
icon={<PiFlowArrowBold />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
isDisabled={!imageDTO?.has_workflow || !size(templates)}
|
||||
isDisabled={!imageDTO?.has_workflow}
|
||||
onClick={handleLoadWorkflow}
|
||||
isLoading={getAndLoadEmbeddedWorkflowResult.isLoading}
|
||||
/>
|
||||
|
||||
@@ -9,29 +9,27 @@ import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
$cursorPos,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
closeAddNodePopover,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
connectionMade,
|
||||
nodeAdded,
|
||||
openAddNodePopover,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import type { KeyboardEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||
import type { EdgeChange, NodeChange } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
@@ -71,19 +69,17 @@ const AddNodePopover = () => {
|
||||
|
||||
const filteredTemplates = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const templatesArray = map(templates);
|
||||
if (!pendingConnection) {
|
||||
return templatesArray;
|
||||
return map(templates);
|
||||
}
|
||||
|
||||
return filter(templates, (template) => {
|
||||
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
|
||||
return some(candidateFields, (field) => {
|
||||
const sourceType =
|
||||
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType =
|
||||
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateConnectionTypes(sourceType, targetType);
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
||||
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
||||
return some(fields, (field) => {
|
||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
});
|
||||
}, [templates, pendingConnection]);
|
||||
@@ -133,37 +129,11 @@ const AddNodePopover = () => {
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find a cozy spot for the node
|
||||
const cursorPos = $cursorPos.get();
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
|
||||
node.selected = true;
|
||||
|
||||
// Deselect all other nodes and edges
|
||||
const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
nodeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
edges.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
edgeChanges.push({ type: 'select', id, selected: false });
|
||||
}
|
||||
});
|
||||
|
||||
// Onwards!
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
dispatch(nodeAdded({ node, cursorPos }));
|
||||
return node;
|
||||
},
|
||||
[buildInvocation, store, dispatch, t, toaster]
|
||||
[dispatch, buildInvocation, toaster, t]
|
||||
);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
@@ -175,28 +145,12 @@ const AddNodePopover = () => {
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { handleType } = pendingConnection;
|
||||
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const connection = getFirstValidConnection(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
||||
if (connection) {
|
||||
const newEdge = connectionToEdge(connection);
|
||||
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,23 +160,24 @@ const AddNodePopover = () => {
|
||||
);
|
||||
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
|
||||
if (!$isAddNodePopoverOpen.get()) {
|
||||
e.preventDefault();
|
||||
openAddNodePopover();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
}
|
||||
e.preventDefault();
|
||||
openAddNodePopover();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
|
||||
if ($isAddNodePopoverOpen.get()) {
|
||||
closeAddNodePopover();
|
||||
}
|
||||
closeAddNodePopover();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
|
||||
useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] });
|
||||
useHotkeys(['escape'], handleHotkeyClose);
|
||||
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
|
||||
if (e.key === 'Escape') {
|
||||
closeAddNodePopover();
|
||||
}
|
||||
}, []);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
||||
|
||||
@@ -260,6 +215,7 @@ const AddNodePopover = () => {
|
||||
filterOption={filterOption}
|
||||
onChange={onChange}
|
||||
onMenuClose={closeAddNodePopover}
|
||||
onKeyDown={onKeyDown}
|
||||
inputRef={inputRef}
|
||||
closeMenuOnSelect={false}
|
||||
/>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
@@ -8,35 +8,38 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
$cursorPos,
|
||||
$didUpdateEdge,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$lastEdgeUpdateMouseEvent,
|
||||
$isUpdatingEdge,
|
||||
$pendingConnection,
|
||||
$viewport,
|
||||
connectionMade,
|
||||
edgeAdded,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
redo,
|
||||
selectedAll,
|
||||
undo,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type {
|
||||
EdgeChange,
|
||||
NodeChange,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnEdgeUpdateFunc,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
ProOptions,
|
||||
ReactFlowProps,
|
||||
ReactFlowState,
|
||||
} from 'reactflow';
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow';
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
|
||||
|
||||
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
|
||||
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
|
||||
@@ -45,6 +48,8 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
|
||||
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
|
||||
import NotesNode from './nodes/Notes/NotesNode';
|
||||
|
||||
const DELETE_KEYS = ['Delete', 'Backspace'];
|
||||
|
||||
const edgeTypes = {
|
||||
collapsed: InvocationCollapsedEdge,
|
||||
default: InvocationDefaultEdge,
|
||||
@@ -76,8 +81,6 @@ export const Flow = memo(() => {
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
const cancelConnection = useReactFlowStore(selectCancelConnection);
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const store = useAppStore();
|
||||
useWorkflowWatcher();
|
||||
useSyncExecutionState();
|
||||
const [borderRadius] = useToken('radii', ['base']);
|
||||
@@ -90,17 +93,29 @@ export const Flow = memo(() => {
|
||||
);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(nodeChanges) => {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
(changes) => {
|
||||
dispatch(nodesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
(changes) => {
|
||||
if (changes.length > 0) {
|
||||
dispatch(edgesChanged(changes));
|
||||
}
|
||||
dispatch(edgesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
(edges) => {
|
||||
dispatch(edgesDeleted(edges));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onNodesDelete: OnNodesDelete = useCallback(
|
||||
(nodes) => {
|
||||
dispatch(nodesDeleted(nodes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@@ -142,50 +157,45 @@ export const Flow = memo(() => {
|
||||
* where the edge is deleted if you click it accidentally).
|
||||
*/
|
||||
|
||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, edge, _handleType) => {
|
||||
$edgePendingUpdate.set(edge);
|
||||
$didUpdateEdge.set(false);
|
||||
$lastEdgeUpdateMouseEvent.set(e);
|
||||
}, []);
|
||||
// We have a ref for cursor position, but it is the *projected* cursor position.
|
||||
// Easiest to just keep track of the last mouse event for this particular feature
|
||||
const edgeUpdateMouseEvent = useRef<MouseEvent>();
|
||||
|
||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
$isUpdatingEdge.set(true);
|
||||
// update mouse event
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||
(oldEdge, newConnection) => {
|
||||
// This event is fired when an edge update is successful
|
||||
$didUpdateEdge.set(true);
|
||||
// When an edge update is successful, we need to delete the old edge and create a new one
|
||||
const newEdge = connectionToEdge(newConnection);
|
||||
dispatch(
|
||||
edgesChanged([
|
||||
{ type: 'remove', id: oldEdge.id },
|
||||
{ type: 'add', item: newEdge },
|
||||
])
|
||||
);
|
||||
// Because we shift the position of handles depending on whether a field is connected or not, we must use
|
||||
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
|
||||
updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
|
||||
(_oldEdge, newConnection) => {
|
||||
// Because we deleted the edge when the update started, we must create a new edge from the connection
|
||||
dispatch(connectionMade(newConnection));
|
||||
},
|
||||
[dispatch, updateNodeInternals]
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
const didUpdateEdge = $didUpdateEdge.get();
|
||||
// Fall back to a reasonable default event
|
||||
const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 };
|
||||
// We have to narrow this event down to MouseEvents - could be TouchEvent
|
||||
const didMouseMove =
|
||||
!('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5;
|
||||
|
||||
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
|
||||
// the user probably intended to delete the edge
|
||||
if (!didUpdateEdge && didMouseMove) {
|
||||
dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
|
||||
}
|
||||
|
||||
$edgePendingUpdate.set(null);
|
||||
$didUpdateEdge.set(false);
|
||||
$isUpdatingEdge.set(false);
|
||||
$pendingConnection.set(null);
|
||||
$lastEdgeUpdateMouseEvent.set(null);
|
||||
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
|
||||
// the edge update - we need to add it back
|
||||
if (
|
||||
// ignore touch events
|
||||
!('touches' in e) &&
|
||||
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
|
||||
edgeUpdateMouseEvent.current?.clientY === e.clientY
|
||||
) {
|
||||
dispatch(edgeAdded(edge));
|
||||
}
|
||||
// reset mouse event
|
||||
edgeUpdateMouseEvent.current = undefined;
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@@ -206,27 +216,9 @@ export const Flow = memo(() => {
|
||||
const onSelectAllHotkey = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
e.preventDefault();
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (!selected) {
|
||||
nodeChanges.push({ type: 'select', id, selected: true });
|
||||
}
|
||||
});
|
||||
edges.forEach(({ id, selected }) => {
|
||||
if (!selected) {
|
||||
edgeChanges.push({ type: 'select', id, selected: true });
|
||||
}
|
||||
});
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
dispatch(selectedAll());
|
||||
},
|
||||
[dispatch, store]
|
||||
[dispatch]
|
||||
);
|
||||
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
|
||||
|
||||
@@ -263,37 +255,12 @@ export const Flow = memo(() => {
|
||||
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
|
||||
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
if (!$edgePendingUpdate.get()) {
|
||||
$pendingConnection.set(null);
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
cancelConnection();
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
cancelConnection();
|
||||
}, [cancelConnection]);
|
||||
useHotkeys('esc', onEscapeHotkey);
|
||||
|
||||
const onDeleteHotkey = useCallback(() => {
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
nodes
|
||||
.filter((n) => n.selected)
|
||||
.forEach(({ id }) => {
|
||||
nodeChanges.push({ type: 'remove', id });
|
||||
});
|
||||
edges
|
||||
.filter((e) => e.selected)
|
||||
.forEach(({ id }) => {
|
||||
edgeChanges.push({ type: 'remove', id });
|
||||
});
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
}, [dispatch, store]);
|
||||
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
id="workflow-editor"
|
||||
@@ -307,9 +274,11 @@ export const Flow = memo(() => {
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
onEdgeUpdate={onEdgeUpdate}
|
||||
onEdgeUpdateStart={onEdgeUpdateStart}
|
||||
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
||||
onNodesDelete={onNodesDelete}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
@@ -323,10 +292,9 @@ export const Flow = memo(() => {
|
||||
proOptions={proOptions}
|
||||
style={flowStyles}
|
||||
onPaneClick={handlePaneClick}
|
||||
deleteKeyCode={null}
|
||||
deleteKeyCode={DELETE_KEYS}
|
||||
selectionMode={selectionMode}
|
||||
elevateEdgesOnSelect
|
||||
nodeDragThreshold={1}
|
||||
>
|
||||
<Background />
|
||||
</ReactFlow>
|
||||
|
||||
@@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { EdgeProps } from 'reactflow';
|
||||
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||
|
||||
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
||||
|
||||
const InvocationCollapsedEdge = ({
|
||||
sourceX,
|
||||
sourceY,
|
||||
@@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
data,
|
||||
selected = false,
|
||||
selected,
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps<{ count: number }>) => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
|
||||
[templates, source, sourceHandleId, target, targetHandleId]
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
|
||||
[templates, selected, source, sourceHandleId, target, targetHandleId]
|
||||
);
|
||||
|
||||
const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector);
|
||||
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
@@ -44,8 +44,14 @@ const InvocationCollapsedEdge = ({
|
||||
const { base500 } = useChakraThemeTokens();
|
||||
|
||||
const edgeStyles = useMemo(
|
||||
() => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected),
|
||||
[areConnectedNodesSelected, base500, selected, shouldAnimateEdges]
|
||||
() => ({
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke: base500,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}),
|
||||
[base500, isSelected, shouldAnimate]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -54,15 +60,11 @@ const InvocationCollapsedEdge = ({
|
||||
{data?.count && data.count > 1 && (
|
||||
<EdgeLabelRenderer>
|
||||
<Flex
|
||||
data-testid="asdfasdfasdf"
|
||||
position="absolute"
|
||||
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
|
||||
className="nodrag nopan"
|
||||
// Unfortunately edge labels do not get the same zIndex treatment as edges do, so we need to manage this ourselves
|
||||
// See: https://github.com/xyflow/xyflow/issues/3658
|
||||
zIndex={1001}
|
||||
>
|
||||
<Badge variant="solid" bg="base.500" opacity={selected ? 0.8 : 0.5} boxShadow="base">
|
||||
<Badge variant="solid" bg="base.500" opacity={isSelected ? 0.8 : 0.5} boxShadow="base">
|
||||
{data.count}
|
||||
</Badge>
|
||||
</Flex>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { EdgeProps } from 'reactflow';
|
||||
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||
@@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
selected = false,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
@@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({
|
||||
}: EdgeProps) => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
|
||||
[templates, source, sourceHandleId, target, targetHandleId]
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
|
||||
[templates, source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector);
|
||||
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
@@ -41,9 +41,15 @@ const InvocationDefaultEdge = ({
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
const edgeStyles = useMemo(
|
||||
() => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected),
|
||||
[areConnectedNodesSelected, stroke, selected, shouldAnimateEdges]
|
||||
const edgeStyles = useMemo<CSSProperties>(
|
||||
() => ({
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}),
|
||||
[isSelected, shouldAnimate, stroke]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -59,13 +65,13 @@ const InvocationDefaultEdge = ({
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
borderWidth={1}
|
||||
borderColor={selected ? 'undefined' : 'transparent'}
|
||||
opacity={selected ? 1 : 0.5}
|
||||
borderColor={isSelected ? 'undefined' : 'transparent'}
|
||||
opacity={isSelected ? 1 : 0.5}
|
||||
py={1}
|
||||
px={3}
|
||||
shadow="md"
|
||||
>
|
||||
<Text size="sm" fontWeight="semibold" color={selected ? 'base.100' : 'base.300'}>
|
||||
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
|
||||
{label}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { FIELD_COLORS } from 'features/nodes/types/constants';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
|
||||
export const getFieldColor = (fieldType: FieldType | null): string => {
|
||||
if (!fieldType) {
|
||||
@@ -11,16 +10,3 @@ export const getFieldColor = (fieldType: FieldType | null): string => {
|
||||
|
||||
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
|
||||
};
|
||||
|
||||
export const getEdgeStyles = (
|
||||
stroke: string,
|
||||
selected: boolean,
|
||||
shouldAnimateEdges: boolean,
|
||||
areConnectedNodesSelected: boolean
|
||||
): CSSProperties => ({
|
||||
strokeWidth: 3,
|
||||
stroke,
|
||||
opacity: selected ? 1 : 0.5,
|
||||
animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||
strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none',
|
||||
});
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
@@ -9,8 +8,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
const defaultReturnValue = {
|
||||
areConnectedNodesSelected: false,
|
||||
shouldAnimateEdges: false,
|
||||
isSelected: false,
|
||||
shouldAnimate: false,
|
||||
stroke: colorTokenToCssVar('base.500'),
|
||||
label: '',
|
||||
};
|
||||
@@ -20,27 +19,21 @@ export const makeEdgeSelector = (
|
||||
source: string,
|
||||
sourceHandleId: string | null | undefined,
|
||||
target: string,
|
||||
targetHandleId: string | null | undefined
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createMemoizedSelector(
|
||||
selectNodesSlice,
|
||||
selectWorkflowSettingsSlice,
|
||||
(
|
||||
nodes,
|
||||
workflowSettings
|
||||
): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => {
|
||||
const { shouldAnimateEdges, shouldColorEdges } = workflowSettings;
|
||||
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const returnValue = deepClone(defaultReturnValue);
|
||||
returnValue.shouldAnimateEdges = shouldAnimateEdges;
|
||||
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected);
|
||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
|
||||
return returnValue;
|
||||
return defaultReturnValue;
|
||||
}
|
||||
|
||||
const sourceNodeTemplate = templates[sourceNode.data.type];
|
||||
@@ -49,10 +42,16 @@ export const makeEdgeSelector = (
|
||||
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
const stroke =
|
||||
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
|
||||
return returnValue;
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
label,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
nodeId: string;
|
||||
fieldName?: string;
|
||||
}>;
|
||||
|
||||
export const MissingFallback = memo((props: Props) => {
|
||||
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
|
||||
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
|
||||
if (!exists) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return props.children;
|
||||
});
|
||||
|
||||
MissingFallback.displayName = 'MissingFallback';
|
||||
@@ -25,11 +25,10 @@ interface Props {
|
||||
kind: 'inputs' | 'outputs';
|
||||
isMissingInput?: boolean;
|
||||
withTooltip?: boolean;
|
||||
shouldDim?: boolean;
|
||||
}
|
||||
|
||||
const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false, shouldDim = false } = props;
|
||||
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false } = props;
|
||||
const label = useFieldLabel(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
||||
const { t } = useTranslation();
|
||||
@@ -40,11 +39,13 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitleRaw: string) => {
|
||||
const newTitle = newTitleRaw.trim();
|
||||
const finalTitle = newTitle || fieldTemplateTitle || t('nodes.unknownField');
|
||||
setLocalTitle(finalTitle);
|
||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: finalTitle }));
|
||||
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
||||
return;
|
||||
}
|
||||
setLocalTitle(newTitle || fieldTemplateTitle || t('nodes.unknownField'));
|
||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
||||
},
|
||||
[fieldTemplateTitle, dispatch, nodeId, fieldName, t]
|
||||
[label, fieldTemplateTitle, dispatch, nodeId, fieldName, t]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
@@ -79,7 +80,6 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
sx={editablePreviewStyles}
|
||||
noOfLines={1}
|
||||
color={isMissingInput ? 'error.300' : 'base.300'}
|
||||
opacity={shouldDim ? 0.5 : 1}
|
||||
/>
|
||||
</Tooltip>
|
||||
<EditableInput className="nodrag" sx={editableInputStyles} />
|
||||
|
||||
@@ -2,12 +2,10 @@ import { Tooltip } from '@invoke-ai/ui-library';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { HandleType } from 'reactflow';
|
||||
import { Handle, Position } from 'reactflow';
|
||||
|
||||
@@ -16,12 +14,11 @@ type FieldHandleProps = {
|
||||
handleType: HandleType;
|
||||
isConnectionInProgress: boolean;
|
||||
isConnectionStartField: boolean;
|
||||
validationResult: ValidationResult;
|
||||
connectionError?: string;
|
||||
};
|
||||
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
|
||||
const { t } = useTranslation();
|
||||
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
|
||||
const { name } = fieldTemplate;
|
||||
const type = fieldTemplate.type;
|
||||
const fieldTypeName = useFieldTypeName(type);
|
||||
@@ -29,11 +26,11 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
const isModelType = MODEL_TYPES.some((t) => t === type.name);
|
||||
const color = getFieldColor(type);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color,
|
||||
backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: !isSingle(type) ? 4 : 0,
|
||||
borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
@@ -46,11 +43,11 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
s.insetInlineEnd = '-1rem';
|
||||
}
|
||||
|
||||
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
|
||||
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
|
||||
s.filter = 'opacity(0.4) grayscale(0.7)';
|
||||
}
|
||||
|
||||
if (isConnectionInProgress && !validationResult.isValid) {
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
if (isConnectionStartField) {
|
||||
s.cursor = 'grab';
|
||||
} else {
|
||||
@@ -61,14 +58,14 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
}
|
||||
|
||||
return s;
|
||||
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
|
||||
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && validationResult.messageTKey) {
|
||||
return t(validationResult.messageTKey);
|
||||
if (isConnectionInProgress && connectionError) {
|
||||
return connectionError;
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
|
||||
}, [connectionError, fieldTypeName, isConnectionInProgress]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
|
||||
@@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
@@ -79,7 +79,6 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
kind="inputs"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
shouldDim
|
||||
/>
|
||||
</FormControl>
|
||||
|
||||
@@ -88,7 +87,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="target"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
validationResult={validationResult}
|
||||
connectionError={connectionError}
|
||||
/>
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
@@ -126,7 +125,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="target"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
validationResult={validationResult}
|
||||
connectionError={connectionError}
|
||||
/>
|
||||
)}
|
||||
</InputFieldWrapper>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
@@ -24,8 +23,6 @@ import {
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
isMainModelFieldInputTemplate,
|
||||
isModelIdentifierFieldInputInstance,
|
||||
isModelIdentifierFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
@@ -98,10 +95,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
|
||||
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ import { CSS } from '@dnd-kit/utilities';
|
||||
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
|
||||
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
|
||||
@@ -21,7 +20,7 @@ type Props = {
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||
@@ -100,12 +99,4 @@ const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
return (
|
||||
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
|
||||
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
|
||||
</MissingFallback>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LinearViewField);
|
||||
|
||||
@@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||
|
||||
if (!fieldTemplate) {
|
||||
@@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
handleType="source"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
validationResult={validationResult}
|
||||
connectionError={connectionError}
|
||||
/>
|
||||
</OutputFieldWrapper>
|
||||
);
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
|
||||
|
||||
const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data, isLoading } = useGetModelConfigsQuery();
|
||||
const _onChange = useCallback(
|
||||
(value: AnyModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldModelIdentifierValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(data);
|
||||
}, [data]);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
groupByType: true,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelIdentifierFieldInputComponent);
|
||||
@@ -1,15 +1,14 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import type { MouseEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { NodeChange } from 'reactflow';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
@@ -19,7 +18,6 @@ type NodeWrapperProps = PropsWithChildren & {
|
||||
|
||||
const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const store = useAppStore();
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||
|
||||
const executionState = useExecutionState(nodeId);
|
||||
@@ -39,20 +37,11 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
|
||||
const { nodes } = store.getState().nodes.present;
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (selected !== (id === nodeId)) {
|
||||
nodeChanges.push({ type: 'select', id, selected: id === nodeId });
|
||||
}
|
||||
});
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
dispatch(nodeExclusivelySelected(nodeId));
|
||||
}
|
||||
onCloseGlobal();
|
||||
},
|
||||
[onCloseGlobal, store, dispatch, nodeId]
|
||||
[dispatch, onCloseGlobal, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -7,10 +7,8 @@ import WorkflowInfoTooltipContent from './viewMode/WorkflowInfoTooltipContent';
|
||||
import { WorkflowWarning } from './viewMode/WorkflowWarning';
|
||||
|
||||
export const WorkflowName = () => {
|
||||
const { name, isTouched, mode } = useAppSelector((s) => s.workflow);
|
||||
const { t } = useTranslation();
|
||||
const name = useAppSelector((s) => s.workflow.name);
|
||||
const isTouched = useAppSelector((s) => s.workflow.isTouched);
|
||||
const mode = useAppSelector((s) => s.workflow.mode);
|
||||
|
||||
return (
|
||||
<Flex gap="1" alignItems="center">
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
|
||||
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
|
||||
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
|
||||
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
|
||||
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||
@@ -15,7 +14,7 @@ type Props = {
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
const WorkflowField = ({ nodeId, fieldName }: Props) => {
|
||||
const label = useFieldLabel(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
|
||||
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
|
||||
@@ -51,12 +50,4 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
const WorkflowField = ({ nodeId, fieldName }: Props) => {
|
||||
return (
|
||||
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
|
||||
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
|
||||
</MissingFallback>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(WorkflowField);
|
||||
|
||||
@@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import DndSortable from 'features/dnd/components/DndSortable';
|
||||
import type { DragEndEvent } from 'features/dnd/types';
|
||||
import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
|
||||
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
|
||||
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
@@ -40,18 +40,16 @@ const WorkflowLinearTab = () => {
|
||||
[dispatch, fields]
|
||||
);
|
||||
|
||||
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<DndSortable onDragEnd={handleDragEnd} items={items}>
|
||||
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}>
|
||||
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
||||
{isLoading ? (
|
||||
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
|
||||
) : fields.length ? (
|
||||
fields.map(({ nodeId, fieldName }) => (
|
||||
<LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
|
||||
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
|
||||
))
|
||||
) : (
|
||||
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
@@ -8,20 +9,31 @@ import { useMemo } from 'react';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const selectConnectedFieldNames = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
|
||||
nodesSlice.edges
|
||||
.filter((e) => e.target === nodeId)
|
||||
.map((e) => e.targetHandle)
|
||||
.filter(Boolean)
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
|
||||
|
||||
const fieldNames = useMemo(() => {
|
||||
const fields = map(template.inputs).filter((field) => {
|
||||
if (connectedFieldNames.includes(field.name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
|
||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
});
|
||||
const _fieldNames = getSortedFilteredFieldNames(fields);
|
||||
if (_fieldNames.length === 0) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
return _fieldNames;
|
||||
}, [template.inputs]);
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}, [connectedFieldNames, template.inputs]);
|
||||
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -2,69 +2,58 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import {
|
||||
$didUpdateEdge,
|
||||
$edgePendingUpdate,
|
||||
$isAddNodePopoverOpen,
|
||||
$isUpdatingEdge,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
edgesChanged,
|
||||
connectionMade,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { useUpdateNodeInternals } from 'reactflow';
|
||||
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useConnection = () => {
|
||||
const store = useAppStore();
|
||||
const templates = useStore($templates);
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
|
||||
const onConnectStart = useCallback<OnConnectStart>(
|
||||
(event, { nodeId, handleId, handleType }) => {
|
||||
assert(nodeId && handleId && handleType, 'Invalid connection start event');
|
||||
(event, params) => {
|
||||
const nodes = store.getState().nodes.present.nodes;
|
||||
|
||||
const { nodeId, handleId, handleType } = params;
|
||||
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
if (!node) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
const fieldTemplate = fieldTemplates[handleId];
|
||||
if (!fieldTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
|
||||
assert(template, `Template not found for node type: ${node.data.type}`);
|
||||
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
||||
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
||||
$pendingConnection.set({
|
||||
node,
|
||||
template,
|
||||
fieldTemplate,
|
||||
});
|
||||
},
|
||||
[store, templates]
|
||||
);
|
||||
const onConnect = useCallback<OnConnect>(
|
||||
(connection) => {
|
||||
const { dispatch } = store;
|
||||
const newEdge = connectionToEdge(connection);
|
||||
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
|
||||
updateNodeInternals([newEdge.source, newEdge.target]);
|
||||
dispatch(connectionMade(connection));
|
||||
$pendingConnection.set(null);
|
||||
},
|
||||
[store, updateNodeInternals]
|
||||
[store]
|
||||
);
|
||||
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
||||
const { dispatch } = store;
|
||||
const pendingConnection = $pendingConnection.get();
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const isUpdatingEdge = $isUpdatingEdge.get();
|
||||
const mouseOverNodeId = $mouseOverNode.get();
|
||||
|
||||
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
|
||||
// update logic can finish up
|
||||
if (edgePendingUpdate && !mouseOverNodeId) {
|
||||
if (isUpdatingEdge && !mouseOverNodeId) {
|
||||
$pendingConnection.set(null);
|
||||
return;
|
||||
}
|
||||
@@ -74,41 +63,30 @@ export const useConnection = () => {
|
||||
}
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
if (mouseOverNodeId) {
|
||||
const { handleType } = pendingConnection;
|
||||
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
|
||||
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
|
||||
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
|
||||
|
||||
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
||||
if (!candidateNode) {
|
||||
// The mouse is over a non-invocation node - bail
|
||||
return;
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
templates,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
pendingConnection,
|
||||
candidateNode,
|
||||
candidateTemplate
|
||||
);
|
||||
if (connection) {
|
||||
const newEdge = connectionToEdge(connection);
|
||||
const edgeChanges: EdgeChange[] = [{ type: 'add', item: newEdge }];
|
||||
|
||||
const nodesToUpdate = [newEdge.source, newEdge.target];
|
||||
if (edgePendingUpdate) {
|
||||
$didUpdateEdge.set(true);
|
||||
edgeChanges.push({ type: 'remove', id: edgePendingUpdate.id });
|
||||
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
|
||||
}
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
updateNodeInternals(nodesToUpdate);
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
} else {
|
||||
// The mouse is not over a node - we should open the add node popover
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
}
|
||||
}, [store, templates, updateNodeInternals]);
|
||||
}, [store, templates]);
|
||||
|
||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||
return api;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
@@ -8,22 +9,31 @@ import { useMemo } from 'react';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const selectConnectedFieldNames = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
|
||||
nodesSlice.edges
|
||||
.filter((e) => e.target === nodeId)
|
||||
.map((e) => e.targetHandle)
|
||||
.filter(Boolean)
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
|
||||
|
||||
const fieldNames = useMemo(() => {
|
||||
// get the visible fields
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
|
||||
const fields = map(template.inputs).filter((field) => {
|
||||
if (connectedFieldNames.includes(field.name)) {
|
||||
return true;
|
||||
}
|
||||
return (
|
||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
|
||||
const _fieldNames = getSortedFilteredFieldNames(fields);
|
||||
|
||||
if (_fieldNames.length === 0) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return _fieldNames;
|
||||
}, [template.inputs]);
|
||||
);
|
||||
});
|
||||
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}, [connectedFieldNames, template.inputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useFieldType } from './useFieldType.ts';
|
||||
|
||||
type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@@ -14,7 +16,7 @@ type UseConnectionStateProps = {
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const edgePendingUpdate = useStore($edgePendingUpdate);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
() =>
|
||||
@@ -31,9 +33,17 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const selectValidationResult = useMemo(
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
|
||||
[templates, nodeId, fieldName, kind]
|
||||
const selectConnectionError = useMemo(
|
||||
() =>
|
||||
makeConnectionErrorSelector(
|
||||
templates,
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
@@ -43,23 +53,23 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
pendingConnection.nodeId === nodeId &&
|
||||
pendingConnection.handleId === fieldName &&
|
||||
pendingConnection.node.id === nodeId &&
|
||||
pendingConnection.fieldTemplate.name === fieldName &&
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
|
||||
const connectionError = useAppSelector(selectConnectionError);
|
||||
|
||||
const shouldDim = useMemo(
|
||||
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
|
||||
[validationResult, isConnectionInProgress, isConnectionStartField]
|
||||
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
|
||||
[connectionError, isConnectionInProgress, isConnectionStartField]
|
||||
);
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
validationResult,
|
||||
connectionError,
|
||||
shouldDim,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -5,13 +5,11 @@ import {
|
||||
$copiedNodes,
|
||||
$cursorPos,
|
||||
$edgesToCopiedNodes,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
selectionPasted,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { isEqual, uniqWith } from 'lodash-es';
|
||||
import type { EdgeChange, NodeChange } from 'reactflow';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const copySelection = () => {
|
||||
@@ -28,7 +26,7 @@ const copySelection = () => {
|
||||
|
||||
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
const { getState, dispatch } = getStore();
|
||||
const { nodes, edges } = selectNodesSlice(getState());
|
||||
const currentNodes = selectNodesSlice(getState()).nodes;
|
||||
const cursorPos = $cursorPos.get();
|
||||
|
||||
const copiedNodes = deepClone($copiedNodes.get());
|
||||
@@ -48,7 +46,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
const offsetY = cursorPos ? cursorPos.y - minY : 50;
|
||||
|
||||
copiedNodes.forEach((node) => {
|
||||
const { x, y } = findUnoccupiedPosition(nodes, node.position.x + offsetX, node.position.y + offsetY);
|
||||
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
|
||||
node.position.x = x;
|
||||
node.position.y = y;
|
||||
// Pasted nodes are selected
|
||||
@@ -70,48 +68,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
node.data.id = id;
|
||||
});
|
||||
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
// Deselect existing nodes
|
||||
nodes.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
nodeChanges.push({
|
||||
type: 'select',
|
||||
id,
|
||||
selected: false,
|
||||
});
|
||||
}
|
||||
});
|
||||
// Add new nodes
|
||||
copiedNodes.forEach((n) => {
|
||||
nodeChanges.push({
|
||||
type: 'add',
|
||||
item: n,
|
||||
});
|
||||
});
|
||||
// Deselect existing edges
|
||||
edges.forEach(({ id, selected }) => {
|
||||
if (selected) {
|
||||
edgeChanges.push({
|
||||
type: 'select',
|
||||
id,
|
||||
selected: false,
|
||||
});
|
||||
}
|
||||
});
|
||||
// Add new edges
|
||||
copiedEdges.forEach((e) => {
|
||||
edgeChanges.push({
|
||||
type: 'add',
|
||||
item: e,
|
||||
});
|
||||
});
|
||||
if (nodeChanges.length > 0) {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
}
|
||||
if (edgeChanges.length > 0) {
|
||||
dispatch(edgesChanged(edgeChanges));
|
||||
}
|
||||
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
|
||||
};
|
||||
|
||||
const api = { copySelection, pasteSelection };
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
|
||||
const doesFieldExist = useAppSelector((s) => {
|
||||
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
if (fieldName === undefined) {
|
||||
return true;
|
||||
}
|
||||
if (!node.data.inputs[fieldName]) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
return doesFieldExist;
|
||||
};
|
||||
@@ -0,0 +1,9 @@
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
|
||||
return fieldType;
|
||||
};
|
||||
@@ -1,10 +1,14 @@
|
||||
// TODO: enable this at some point
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection } from 'reactflow';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@@ -21,21 +25,75 @@ export const useIsValidConnection = () => {
|
||||
if (!(source && sourceHandle && target && targetHandle)) {
|
||||
return false;
|
||||
}
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
|
||||
const validationResult = validateConnection(
|
||||
{ source, sourceHandle, target, targetHandle },
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate,
|
||||
shouldValidateGraph
|
||||
);
|
||||
if (source === target) {
|
||||
// Don't allow nodes to connect to themselves, even if validation is disabled
|
||||
return false;
|
||||
}
|
||||
|
||||
return validationResult.isValid;
|
||||
const state = store.getState();
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
|
||||
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
|
||||
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceFieldTemplate && targetFieldTemplate)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shouldValidateGraph) {
|
||||
// manual override!
|
||||
return true;
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
[templates, shouldValidateGraph, store]
|
||||
[shouldValidateGraph, templates, store]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -10,13 +10,13 @@ export const useFieldTypeName = (fieldType?: FieldType): string => {
|
||||
return '';
|
||||
}
|
||||
const { name } = fieldType;
|
||||
if (isCollection(fieldType)) {
|
||||
if (fieldType.isCollection) {
|
||||
return t('nodes.collectionFieldType', { name });
|
||||
}
|
||||
if (isSingleOrCollection(fieldType)) {
|
||||
if (fieldType.isCollectionOrScalar) {
|
||||
return t('nodes.collectionOrScalarFieldType', { name });
|
||||
}
|
||||
return t('nodes.singleFieldType', { name });
|
||||
return name;
|
||||
}, [fieldType, t]);
|
||||
|
||||
return name;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { Graph, GraphAndWorkflowResponse } from 'services/api/types';
|
||||
import type { Graph } from 'services/api/types';
|
||||
|
||||
const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
|
||||
const imageToImageGraphBuilt = createAction<Graph>('nodes/imageToImageGraphBuilt');
|
||||
@@ -15,7 +15,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<{
|
||||
data: GraphAndWorkflowResponse;
|
||||
workflow: unknown;
|
||||
asCopy: boolean;
|
||||
}>('nodes/workflowLoadRequested');
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import type {
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
ModelIdentifierFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
@@ -36,7 +35,6 @@ import {
|
||||
zIPAdapterModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
@@ -47,13 +45,13 @@ import {
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { atom } from 'nanostores';
|
||||
import type { MouseEvent } from 'react';
|
||||
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { NodesState, PendingConnection, Templates } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
@@ -92,47 +90,44 @@ export const nodesSlice = createSlice({
|
||||
reducers: {
|
||||
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
|
||||
state.nodes = applyNodeChanges(action.payload, state.nodes);
|
||||
// Remove edges that are no longer valid, due to a removed or otherwise changed node
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
state.edges.forEach((e) => {
|
||||
const sourceExists = state.nodes.some((n) => n.id === e.source);
|
||||
const targetExists = state.nodes.some((n) => n.id === e.target);
|
||||
if (!(sourceExists && targetExists)) {
|
||||
edgeChanges.push({ type: 'remove', id: e.id });
|
||||
}
|
||||
});
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
},
|
||||
nodeReplaced: (state, action: PayloadAction<{ nodeId: string; node: Node }>) => {
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === action.payload.nodeId);
|
||||
if (nodeIndex < 0) {
|
||||
return;
|
||||
}
|
||||
state.nodes[nodeIndex] = action.payload.node;
|
||||
},
|
||||
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
|
||||
const { node, cursorPos } = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
cursorPos?.x ?? node.position.x,
|
||||
cursorPos?.y ?? node.position.y
|
||||
);
|
||||
node.position = position;
|
||||
node.selected = true;
|
||||
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })),
|
||||
state.nodes
|
||||
);
|
||||
|
||||
state.edges = applyEdgeChanges(
|
||||
state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })),
|
||||
state.edges
|
||||
);
|
||||
|
||||
state.nodes.push(node);
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
const changes: EdgeChange[] = [];
|
||||
// We may need to massage the edge changes or otherwise handle them
|
||||
action.payload.forEach((change) => {
|
||||
if (change.type === 'remove' || change.type === 'select') {
|
||||
const edge = state.edges.find((e) => e.id === change.id);
|
||||
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
|
||||
if (edge && edge.type === 'collapsed') {
|
||||
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
|
||||
if (change.type === 'remove') {
|
||||
hiddenEdges.forEach(({ id }) => {
|
||||
changes.push({ type: 'remove', id });
|
||||
});
|
||||
}
|
||||
if (change.type === 'select') {
|
||||
hiddenEdges.forEach(({ id }) => {
|
||||
changes.push({ type: 'select', id, selected: change.selected });
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (change.type === 'add') {
|
||||
if (!change.item.type) {
|
||||
// We must add the edge type!
|
||||
change.item.type = 'default';
|
||||
}
|
||||
}
|
||||
changes.push(change);
|
||||
});
|
||||
state.edges = applyEdgeChanges(changes, state.edges);
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
||||
},
|
||||
fieldLabelChanged: (
|
||||
state,
|
||||
@@ -237,7 +232,6 @@ export const nodesSlice = createSlice({
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
selected: edge.selected,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -258,7 +252,6 @@ export const nodesSlice = createSlice({
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
selected: edge.selected,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -271,6 +264,33 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
},
|
||||
edgeDeleted: (state, action: PayloadAction<string>) => {
|
||||
state.edges = state.edges.filter((e) => e.id !== action.payload);
|
||||
},
|
||||
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
||||
const edges = action.payload;
|
||||
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
||||
|
||||
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
|
||||
if (collapsedEdges.length) {
|
||||
const edgeChanges: EdgeRemoveChange[] = [];
|
||||
collapsedEdges.forEach((collapsedEdge) => {
|
||||
state.edges.forEach((edge) => {
|
||||
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
|
||||
edgeChanges.push({ id: edge.id, type: 'remove' });
|
||||
}
|
||||
});
|
||||
});
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
}
|
||||
},
|
||||
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
|
||||
action.payload.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
});
|
||||
},
|
||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||
const { nodeId, label } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
@@ -289,6 +309,17 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = notes;
|
||||
},
|
||||
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
|
||||
const nodeId = action.payload;
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({
|
||||
id: n.id,
|
||||
type: 'select',
|
||||
selected: n.id === nodeId ? true : false,
|
||||
})),
|
||||
state.nodes
|
||||
);
|
||||
},
|
||||
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStatefulFieldValue);
|
||||
},
|
||||
@@ -313,9 +344,6 @@ export const nodesSlice = createSlice({
|
||||
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zMainModelFieldValue);
|
||||
},
|
||||
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
|
||||
fieldValueReducer(state, action, zModelIdentifierFieldValue);
|
||||
},
|
||||
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
|
||||
},
|
||||
@@ -353,6 +381,57 @@ export const nodesSlice = createSlice({
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
},
|
||||
selectedAll: (state) => {
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
|
||||
state.nodes
|
||||
);
|
||||
state.edges = applyEdgeChanges(
|
||||
state.edges.map((e) => ({ id: e.id, type: 'select', selected: true })),
|
||||
state.edges
|
||||
);
|
||||
},
|
||||
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
|
||||
const { nodes, edges } = action.payload;
|
||||
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
|
||||
// Deselect existing nodes
|
||||
state.nodes.forEach((n) => {
|
||||
nodeChanges.push({
|
||||
id: n.data.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
});
|
||||
});
|
||||
// Add new nodes
|
||||
nodes.forEach((n) => {
|
||||
nodeChanges.push({
|
||||
item: n,
|
||||
type: 'add',
|
||||
});
|
||||
});
|
||||
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
// Deselect existing edges
|
||||
state.edges.forEach((e) => {
|
||||
edgeChanges.push({
|
||||
id: e.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
});
|
||||
});
|
||||
// Add new edges
|
||||
edges.forEach((e) => {
|
||||
edgeChanges.push({
|
||||
item: e,
|
||||
type: 'add',
|
||||
});
|
||||
});
|
||||
|
||||
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
},
|
||||
undo: (state) => state,
|
||||
redo: (state) => state,
|
||||
},
|
||||
@@ -361,13 +440,13 @@ export const nodesSlice = createSlice({
|
||||
const { nodes, edges } = action.payload;
|
||||
state.nodes = applyNodeChanges(
|
||||
nodes.map((node) => ({
|
||||
type: 'add',
|
||||
item: { ...node, ...SHARED_NODE_PROPERTIES },
|
||||
type: 'add',
|
||||
})),
|
||||
[]
|
||||
);
|
||||
state.edges = applyEdgeChanges(
|
||||
edges.map((edge) => ({ type: 'add', item: edge })),
|
||||
edges.map((edge) => ({ item: edge, type: 'add' })),
|
||||
[]
|
||||
);
|
||||
});
|
||||
@@ -375,7 +454,10 @@ export const nodesSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
connectionMade,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
fieldValueReset,
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
@@ -387,21 +469,27 @@ export const {
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
nodeAdded,
|
||||
nodeReplaced,
|
||||
nodeEditorReset,
|
||||
nodeExclusivelySelected,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
nodeUseCacheChanged,
|
||||
notesNodeValueChanged,
|
||||
selectedAll,
|
||||
selectionPasted,
|
||||
edgeAdded,
|
||||
undo,
|
||||
redo,
|
||||
} = nodesSlice.actions;
|
||||
@@ -412,10 +500,7 @@ export const $copiedNodes = atom<AnyNode[]>([]);
|
||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
|
||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||
export const $edgePendingUpdate = atom<Edge | null>(null);
|
||||
export const $didUpdateEdge = atom(false);
|
||||
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
||||
|
||||
export const $isUpdatingEdge = atom(false);
|
||||
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
|
||||
export const $isAddNodePopoverOpen = atom(false);
|
||||
export const closeAddNodePopover = () => {
|
||||
@@ -443,13 +528,13 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
|
||||
|
||||
const isSelectionAction = (action: UnknownAction) => {
|
||||
if (nodesChanged.match(action)) {
|
||||
if (action.payload.every((change) => change.type === 'select')) {
|
||||
return true;
|
||||
}
|
||||
if (selectionMatcher(action)) {
|
||||
return true;
|
||||
}
|
||||
if (edgesChanged.match(action)) {
|
||||
if (nodesChanged.match(action)) {
|
||||
if (action.payload.every((change) => change.type === 'select')) {
|
||||
return true;
|
||||
}
|
||||
@@ -489,7 +574,10 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
connectionMade,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
fieldColorValueChanged,
|
||||
@@ -506,11 +594,15 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
nodesChanged,
|
||||
nodeAdded,
|
||||
nodeReplaced,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
nodesDeleted,
|
||||
nodeUseCacheChanged,
|
||||
notesNodeValueChanged
|
||||
notesNodeValueChanged,
|
||||
selectionPasted,
|
||||
edgeAdded
|
||||
);
|
||||
|
||||
@@ -6,20 +6,19 @@ import type {
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
InvocationNode,
|
||||
InvocationNodeEdge,
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { HandleType } from 'reactflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||
|
||||
export type PendingConnection = {
|
||||
nodeId: string;
|
||||
handleId: string;
|
||||
handleType: HandleType;
|
||||
node: InvocationNode;
|
||||
template: InvocationTemplate;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { areTypesEqual } from './areTypesEqual';
|
||||
|
||||
describe(areTypesEqual.name, () => {
|
||||
it('should handle equal source and target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'Foo',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'Bar',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle equal source type and original target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'Foo',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle equal original source type and target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'Bar',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle equal original source type and original target type', () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'LoRAModelField',
|
||||
cardinality: 'SINGLE',
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,29 +0,0 @@
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { isEqual, omit } from 'lodash-es';
|
||||
|
||||
/**
|
||||
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
|
||||
* considered equal. For example, if the first type and original second type match, the types are considered equal.
|
||||
* @param firstType The first type to compare.
|
||||
* @param secondType The second type to compare.
|
||||
* @returns True if the types are equal, false otherwise.
|
||||
*/
|
||||
export const areTypesEqual = (firstType: FieldType, secondType: FieldType) => {
|
||||
const _firstType = 'originalType' in firstType ? omit(firstType, 'originalType') : firstType;
|
||||
const _secondType = 'originalType' in secondType ? omit(secondType, 'originalType') : secondType;
|
||||
const _originalFirstType = 'originalType' in firstType ? firstType.originalType : null;
|
||||
const _originalSecondType = 'originalType' in secondType ? secondType.originalType : null;
|
||||
if (isEqual(_firstType, _secondType)) {
|
||||
return true;
|
||||
}
|
||||
if (_originalSecondType && isEqual(_firstType, _originalSecondType)) {
|
||||
return true;
|
||||
}
|
||||
if (_originalFirstType && isEqual(_originalFirstType, _secondType)) {
|
||||
return true;
|
||||
}
|
||||
if (_originalFirstType && _originalSecondType && isEqual(_originalFirstType, _originalSecondType)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -0,0 +1,105 @@
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||
import type { Connection } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getFirstValidConnection = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
pendingConnection: PendingConnection,
|
||||
candidateNode: InvocationNode,
|
||||
candidateTemplate: InvocationTemplate
|
||||
): Connection | null => {
|
||||
if (pendingConnection.node.id === candidateNode.id) {
|
||||
// Cannot connect to self
|
||||
return null;
|
||||
}
|
||||
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
|
||||
if (pendingFieldKind === 'source') {
|
||||
// Connecting from a source to a target
|
||||
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
if (candidateNode.data.type === 'collect') {
|
||||
// Special handling for collect node - the `item` field takes any number of connections
|
||||
return {
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: 'item',
|
||||
};
|
||||
}
|
||||
// Only one connection per target field is allowed - look for an unconnected target field
|
||||
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
|
||||
const candidateConnectedFields = edges
|
||||
.filter((edge) => edge.target === candidateNode.id)
|
||||
.map((edge) => {
|
||||
// Edges must always have a targetHandle, safe to assert here
|
||||
assert(edge.targetHandle);
|
||||
return edge.targetHandle;
|
||||
});
|
||||
const candidateUnconnectedFields = differenceWith(
|
||||
candidateFields,
|
||||
candidateConnectedFields,
|
||||
(field, connectedFieldName) => field.name === connectedFieldName
|
||||
);
|
||||
const candidateField = candidateUnconnectedFields.find((field) =>
|
||||
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||
);
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: candidateField.name,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// Connecting from a target to a source
|
||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
||||
const isCollect = pendingConnection.node.data.type === 'collect';
|
||||
const isTargetAlreadyConnected = edges.some(
|
||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||
);
|
||||
if (!isCollect && isTargetAlreadyConnected) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||
let candidateFields = map(candidateTemplate.outputs);
|
||||
if (isCollect) {
|
||||
// Narrow candidates to same field type as already is connected to the collect node
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||
if (collectItemType) {
|
||||
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
|
||||
}
|
||||
}
|
||||
const candidateField = candidateFields.find((field) => {
|
||||
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
||||
return isValid && !isAlreadyConnected;
|
||||
});
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: candidateNode.id,
|
||||
sourceHandle: candidateField.name,
|
||||
target: pendingConnection.node.id,
|
||||
targetHandle: pendingConnection.fieldTemplate.name,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
@@ -1,44 +0,0 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { unset } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe(getCollectItemType.name, () => {
|
||||
it('should return the type of the items the collect node collects', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
|
||||
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
|
||||
});
|
||||
it('should return null if the collect node does not have any connections', () => {
|
||||
const n1 = buildNode(collect);
|
||||
const result = getCollectItemType(templates, [n1], [], n1.id);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
it("should return null if the first edge to collect's node doesn't exist", () => {
|
||||
const n1 = buildNode(collect);
|
||||
const n2 = buildNode(add);
|
||||
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
|
||||
const result = getCollectItemType(templates, [n1], [e1], n1.id);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
it("should return null if the first edge to collect's node template doesn't exist", () => {
|
||||
const n1 = buildNode(collect);
|
||||
const n2 = buildNode(add);
|
||||
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
|
||||
const result = getCollectItemType({ collect }, [n1, n2], [e1], n1.id);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
it("should return null if the first edge to the collect's field template doesn't exist", () => {
|
||||
const n1 = buildNode(collect);
|
||||
const n2 = buildNode(add);
|
||||
const addWithoutOutputValue = deepClone(add);
|
||||
unset(addWithoutOutputValue, 'outputs.value');
|
||||
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
|
||||
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
|
||||
/**
|
||||
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
|
||||
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
|
||||
* input field.
|
||||
* @param templates The current invocation templates
|
||||
* @param nodes The current nodes
|
||||
* @param edges The current edges
|
||||
* @param nodeId The collect node's id
|
||||
* @returns The type of the items the collect node collects, or null if there is no input field
|
||||
*/
|
||||
export const getCollectItemType = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
nodeId: string
|
||||
): FieldType | null => {
|
||||
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||
if (!firstEdgeToCollect?.sourceHandle) {
|
||||
return null;
|
||||
}
|
||||
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
|
||||
if (!fieldTemplate) {
|
||||
return null;
|
||||
}
|
||||
return fieldTemplate.type;
|
||||
};
|
||||
@@ -1,203 +0,0 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
getFirstValidConnection,
|
||||
getSourceCandidateFields,
|
||||
getTargetCandidateFields,
|
||||
} from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils';
|
||||
import { unset } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe('getFirstValidConnection', () => {
|
||||
it('should return null if the pending and candidate nodes are the same node', () => {
|
||||
const n = buildNode(add);
|
||||
expect(getFirstValidConnection(n.id, 'value', n.id, null, [n], [], templates, null)).toBe(null);
|
||||
});
|
||||
|
||||
it('should return null if the sourceHandle and targetHandle are null', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
expect(getFirstValidConnection(n1.id, null, n2.id, null, [n1, n2], [], templates, null)).toBe(null);
|
||||
});
|
||||
|
||||
it('should return itself if both sourceHandle and targetHandle are provided', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
expect(getFirstValidConnection(n1.id, 'value', n2.id, 'a', [n1, n2], [], templates, null)).toEqual({
|
||||
source: n1.id,
|
||||
sourceHandle: 'value',
|
||||
target: n2.id,
|
||||
targetHandle: 'a',
|
||||
});
|
||||
});
|
||||
|
||||
describe('connecting from a source to a target', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
|
||||
it('should return the first valid connection if there are no connected fields', () => {
|
||||
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [], templates, null);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'width',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return the first valid connection if there is a connected field', () => {
|
||||
const e = buildEdge(n1.id, 'height', n2.id, 'width');
|
||||
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, null);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'height',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return the first valid connection if there is an edgePendingUpdate', () => {
|
||||
const e = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, e);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'width',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return null if the target has no valid fields', () => {
|
||||
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
|
||||
const n3 = buildNode(add);
|
||||
const r = getFirstValidConnection(n3.id, 'value', n2.id, null, [n1, n2, n3], [e1, e2], templates, null);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('connecting from a target to a source', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
|
||||
it('should return the first valid connection if there are no connected fields', () => {
|
||||
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [], templates, null);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'width',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return the first valid connection if there is a connected field', () => {
|
||||
const e = buildEdge(n1.id, 'height', n2.id, 'width');
|
||||
const r = getFirstValidConnection(n1.id, null, n2.id, 'height', [n1, n2], [e], templates, null);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'height',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return the first valid connection if there is an edgePendingUpdate', () => {
|
||||
const e = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [e], templates, e);
|
||||
const c = {
|
||||
source: n1.id,
|
||||
sourceHandle: 'width',
|
||||
target: n2.id,
|
||||
targetHandle: 'width',
|
||||
};
|
||||
expect(r).toEqual(c);
|
||||
});
|
||||
it('should return null if the target has no valid fields', () => {
|
||||
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
|
||||
const n3 = buildNode(add);
|
||||
const r = getFirstValidConnection(n3.id, null, n2.id, 'a', [n1, n2, n3], [e1, e2], templates, null);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTargetCandidateFields', () => {
|
||||
it('should return an empty array if the nodes canot be found', () => {
|
||||
const r = getTargetCandidateFields('missing', 'value', 'missing', [], [], templates, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return an empty array if the templates cannot be found', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const nodes = [n1, n2];
|
||||
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], {}, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return an empty array if the source field template cannot be found', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const nodes = [n1, n2];
|
||||
|
||||
const addWithoutOutputValue = deepClone(add);
|
||||
unset(addWithoutOutputValue, 'outputs.value');
|
||||
|
||||
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], { add: addWithoutOutputValue }, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return all valid target fields if there are no connected fields', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, null);
|
||||
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
|
||||
});
|
||||
it('should ignore the edgePendingUpdate if provided', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate);
|
||||
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSourceCandidateFields', () => {
|
||||
it('should return an empty array if the nodes canot be found', () => {
|
||||
const r = getSourceCandidateFields('missing', 'value', 'missing', [], [], templates, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return an empty array if the templates cannot be found', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const nodes = [n1, n2];
|
||||
const r = getSourceCandidateFields(n2.id, 'a', n1.id, nodes, [], {}, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return an empty array if the source field template cannot be found', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const nodes = [n1, n2];
|
||||
|
||||
const addWithoutInputA = deepClone(add);
|
||||
unset(addWithoutInputA, 'inputs.a');
|
||||
|
||||
const r = getSourceCandidateFields(n1.id, 'a', n2.id, nodes, [], { add: addWithoutInputA }, null);
|
||||
expect(r).toEqual([]);
|
||||
});
|
||||
it('should return all valid source fields if there are no connected fields', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, null);
|
||||
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
|
||||
});
|
||||
it('should ignore the edgePendingUpdate if provided', () => {
|
||||
const n1 = buildNode(img_resize);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
|
||||
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate);
|
||||
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
|
||||
});
|
||||
});
|
||||
@@ -1,149 +0,0 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { map } from 'lodash-es';
|
||||
import type { Connection, Edge } from 'reactflow';
|
||||
|
||||
/**
|
||||
*
|
||||
* @param source The source (node id)
|
||||
* @param sourceHandle The source handle (field name), if any
|
||||
* @param target The target (node id)
|
||||
* @param targetHandle The target handle (field name), if any
|
||||
* @param nodes The current nodes
|
||||
* @param edges The current edges
|
||||
* @param templates The current templates
|
||||
* @param edgePendingUpdate The edge pending update, if any
|
||||
* @returns
|
||||
*/
|
||||
export const getFirstValidConnection = (
|
||||
source: string,
|
||||
sourceHandle: string | null,
|
||||
target: string,
|
||||
targetHandle: string | null,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
templates: Templates,
|
||||
edgePendingUpdate: Edge | null
|
||||
): Connection | null => {
|
||||
if (source === target) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (sourceHandle && targetHandle) {
|
||||
return { source, sourceHandle, target, targetHandle };
|
||||
}
|
||||
|
||||
if (sourceHandle && !targetHandle) {
|
||||
const candidates = getTargetCandidateFields(
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
|
||||
const firstCandidate = candidates[0];
|
||||
if (!firstCandidate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { source, sourceHandle, target, targetHandle: firstCandidate.name };
|
||||
}
|
||||
|
||||
if (!sourceHandle && targetHandle) {
|
||||
const candidates = getSourceCandidateFields(
|
||||
target,
|
||||
targetHandle,
|
||||
source,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
|
||||
const firstCandidate = candidates[0];
|
||||
if (!firstCandidate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { source, sourceHandle: firstCandidate.name, target, targetHandle };
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export const getTargetCandidateFields = (
|
||||
source: string,
|
||||
sourceHandle: string,
|
||||
target: string,
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
templates: Templates,
|
||||
edgePendingUpdate: Edge | null
|
||||
): FieldInputTemplate[] => {
|
||||
const sourceNode = nodes.find((n) => n.id === source);
|
||||
const targetNode = nodes.find((n) => n.id === target);
|
||||
if (!sourceNode || !targetNode) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!sourceTemplate || !targetTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceField = sourceTemplate.outputs[sourceHandle];
|
||||
|
||||
if (!sourceField) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
|
||||
const c = { source, sourceHandle, target, targetHandle: field.name };
|
||||
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return r.isValid;
|
||||
});
|
||||
|
||||
return targetCandidateFields;
|
||||
};
|
||||
|
||||
export const getSourceCandidateFields = (
|
||||
target: string,
|
||||
targetHandle: string,
|
||||
source: string,
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
templates: Templates,
|
||||
edgePendingUpdate: Edge | null
|
||||
): FieldOutputTemplate[] => {
|
||||
const targetNode = nodes.find((n) => n.id === target);
|
||||
const sourceNode = nodes.find((n) => n.id === source);
|
||||
if (!sourceNode || !targetNode) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!sourceTemplate || !targetTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const targetField = targetTemplate.inputs[targetHandle];
|
||||
|
||||
if (!targetField) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
|
||||
const c = { source, sourceHandle: field.name, target, targetHandle };
|
||||
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return r.isValid;
|
||||
});
|
||||
|
||||
return sourceCandidateFields;
|
||||
};
|
||||
@@ -1,22 +0,0 @@
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe(getHasCycles.name, () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
|
||||
it('should return true if the graph WOULD have cycles after adding the edge', () => {
|
||||
const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')];
|
||||
const result = getHasCycles(n3.id, n1.id, nodes, edges);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false if the graph WOULD NOT have cycles after adding the edge', () => {
|
||||
const edges = [buildEdge(n1.id, 'value', n2.id, 'a')];
|
||||
const result = getHasCycles(n2.id, n3.id, nodes, edges);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,30 +0,0 @@
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
|
||||
/**
|
||||
* Check if adding an edge between the source and target nodes would create a cycle in the graph.
|
||||
* @param source The source node id
|
||||
* @param target The target node id
|
||||
* @param nodes The graph's current nodes
|
||||
* @param edges The graph's current edges
|
||||
* @returns True if the graph would be acyclic after adding the edge, false otherwise
|
||||
*/
|
||||
|
||||
export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
|
||||
// construct graphlib graph from editor state
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// add the candidate edge
|
||||
g.setEdge(source, target);
|
||||
|
||||
// check if the graph is acyclic
|
||||
return !graphlib.alg.isAcyclic(g);
|
||||
};
|
||||
@@ -0,0 +1,21 @@
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
|
||||
export const getIsGraphAcyclic = (source: string, target: string, nodes: Node[], edges: Edge[]) => {
|
||||
// construct graphlib graph from editor state
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// add the candidate edge
|
||||
g.setEdge(source, target);
|
||||
|
||||
// check if the graph is acyclic
|
||||
return graphlib.alg.isAcyclic(g);
|
||||
};
|
||||
@@ -1,67 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import type { Edge, HandleType } from 'reactflow';
|
||||
|
||||
/**
|
||||
* Creates a selector that validates a pending connection.
|
||||
*
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*
|
||||
* @param templates The invocation templates
|
||||
* @param nodeId The id of the node for which the selector is being created
|
||||
* @param fieldName The name of the field for which the selector is being created
|
||||
* @param handleType The type of the handle for which the selector is being created
|
||||
* @returns
|
||||
*/
|
||||
export const makeConnectionErrorSelector = (
|
||||
templates: Templates,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType
|
||||
) => {
|
||||
return createMemoizedSelector(
|
||||
selectNodesSlice,
|
||||
(state: RootState, pendingConnection: PendingConnection | null) => pendingConnection,
|
||||
(state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) =>
|
||||
edgePendingUpdate,
|
||||
(nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!pendingConnection) {
|
||||
return buildRejectResult('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
if (handleType === pendingConnection.handleType) {
|
||||
if (handleType === 'source') {
|
||||
return buildRejectResult('nodes.cannotConnectOutputToOutput');
|
||||
}
|
||||
return buildRejectResult('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const source = handleType === 'source' ? nodeId : pendingConnection.nodeId;
|
||||
const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId;
|
||||
const target = handleType === 'target' ? nodeId : pendingConnection.nodeId;
|
||||
const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId;
|
||||
|
||||
const validationResult = validateConnection(
|
||||
{
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
},
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
edgePendingUpdate
|
||||
);
|
||||
|
||||
return validationResult;
|
||||
}
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,147 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import i18n from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { HandleType } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getCollectItemType = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
nodeId: string
|
||||
): FieldType | null => {
|
||||
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||
if (!firstEdgeToCollect?.sourceHandle) {
|
||||
return null;
|
||||
}
|
||||
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
templates: Templates,
|
||||
pendingConnection: PendingConnection | null,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType | null
|
||||
) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!fieldType) {
|
||||
return i18n.t('nodes.noFieldType');
|
||||
}
|
||||
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
const connectionNodeId = pendingConnection.node.id;
|
||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
||||
return i18n.t('nodes.noConnectionData');
|
||||
}
|
||||
|
||||
const targetType = handleType === 'target' ? fieldType : connectionStartFieldType;
|
||||
const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
return i18n.t('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
if (handleType === connectionHandleType) {
|
||||
if (handleType === 'source') {
|
||||
return i18n.t('nodes.cannotConnectOutputToOutput');
|
||||
}
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === targetNodeId &&
|
||||
edge.targetHandle === targetFieldName &&
|
||||
edge.source === sourceNodeId &&
|
||||
edge.sourceHandle === sourceFieldName;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === targetNodeId);
|
||||
assert(targetNode, `Target node not found: ${targetNodeId}`);
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
|
||||
|
||||
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
|
||||
return i18n.t('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!isEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType.name !== 'CollectionItemField'
|
||||
) {
|
||||
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(sourceType, targetType)) {
|
||||
return i18n.t('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(
|
||||
connectionHandleType === 'source' ? connectionNodeId : nodeId,
|
||||
connectionHandleType === 'source' ? nodeId : connectionNodeId,
|
||||
nodes,
|
||||
edges
|
||||
);
|
||||
|
||||
if (!isGraphAcyclic) {
|
||||
return i18n.t('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return;
|
||||
});
|
||||
};
|
||||
@@ -1,32 +0,0 @@
|
||||
import type { Connection, Edge } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* Gets the edge id for a connection
|
||||
* Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45
|
||||
* Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290
|
||||
* @param connection The connection to get the id for
|
||||
* @returns The edge id
|
||||
*/
|
||||
const getEdgeId = (connection: Connection): string => {
|
||||
const { source, sourceHandle, target, targetHandle } = connection;
|
||||
return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* Converts a connection to an edge
|
||||
* @param connection The connection to convert to an edge
|
||||
* @returns The edge
|
||||
* @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle)
|
||||
*/
|
||||
export const connectionToEdge = (connection: Connection): Edge => {
|
||||
const { source, sourceHandle, target, targetHandle } = connection;
|
||||
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
|
||||
return {
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
id: getEdgeId({ source, sourceHandle, target, targetHandle }),
|
||||
};
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,194 +0,0 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { set } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
|
||||
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||
|
||||
describe(validateConnection.name, () => {
|
||||
it('should reject invalid connection to self', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
});
|
||||
|
||||
describe('missing nodes', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
|
||||
it('should reject missing source node', () => {
|
||||
const r = validateConnection(c, [n2], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||
});
|
||||
|
||||
it('should reject missing target node', () => {
|
||||
const r = validateConnection(c, [n1], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||
});
|
||||
});
|
||||
|
||||
describe('missing invocation templates', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const nodes = [n1, n2];
|
||||
|
||||
it('should reject missing source template', () => {
|
||||
const r = validateConnection(c, nodes, [], { sub }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||
});
|
||||
|
||||
it('should reject missing target template', () => {
|
||||
const r = validateConnection(c, nodes, [], { add }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||
});
|
||||
});
|
||||
|
||||
describe('missing field templates', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
|
||||
it('should reject missing source field template', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||
});
|
||||
|
||||
it('should reject missing target field template', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||
});
|
||||
});
|
||||
|
||||
describe('duplicate connections', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
it('should accept non-duplicate connections', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], templates, null);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
it('should reject duplicate connections', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const r = validateConnection(c, [n1, n2], [e], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
|
||||
});
|
||||
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const r = validateConnection(c, [n1, n2], [e], templates, e);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
});
|
||||
|
||||
it('should reject connection to direct input', () => {
|
||||
// Create cloned add template w/ a direct input
|
||||
const addWithDirectAField = deepClone(add);
|
||||
set(addWithDirectAField, 'inputs.a.input', 'direct');
|
||||
set(addWithDirectAField, 'type', 'addWithDirectAField');
|
||||
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(addWithDirectAField);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
|
||||
});
|
||||
|
||||
it('should reject connection to a collect node with mismatched item types', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const n3 = buildNode(main_model_loader);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
|
||||
});
|
||||
|
||||
it('should accept connection to a collect node with matching item types', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const n3 = buildNode(sub);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
|
||||
it('should reject connections to target field that is already connected', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
|
||||
});
|
||||
|
||||
it('should accept connections to target field that is already connected (ignored edge)', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(add);
|
||||
const n3 = buildNode(add);
|
||||
const nodes = [n1, n2, n3];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, e1);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
|
||||
it('should reject connections between invalid types', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
|
||||
});
|
||||
|
||||
it('should reject connections that would create cycles', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
|
||||
describe('non-strict mode', () => {
|
||||
it('should reject connections from self to self in non-strict mode', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
});
|
||||
it('should reject connections that create cycles in non-strict mode', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const nodes = [n1, n2];
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
});
|
||||
it('should otherwise allow invalid connections in non-strict mode', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(img_resize);
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null, false);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,130 +0,0 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
type Connection = O.NonNullable<NullableConnection>;
|
||||
|
||||
export type ValidationResult =
|
||||
| {
|
||||
isValid: true;
|
||||
messageTKey?: string;
|
||||
}
|
||||
| {
|
||||
isValid: false;
|
||||
messageTKey: string;
|
||||
};
|
||||
|
||||
type ValidateConnectionFunc = (
|
||||
connection: Connection,
|
||||
nodes: AnyNode[],
|
||||
edges: Edge[],
|
||||
templates: Templates,
|
||||
ignoreEdge: Edge | null,
|
||||
strict?: boolean
|
||||
) => ValidationResult;
|
||||
|
||||
const getEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
(e: Edge): boolean => {
|
||||
return (
|
||||
e.target === c.target &&
|
||||
e.targetHandle === c.targetHandle &&
|
||||
e.source === c.source &&
|
||||
e.sourceHandle === c.sourceHandle
|
||||
);
|
||||
};
|
||||
|
||||
const getTargetEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
(e: Edge): boolean => {
|
||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||
};
|
||||
|
||||
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
|
||||
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||
if (c.source === c.target) {
|
||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||
}
|
||||
|
||||
if (strict) {
|
||||
/**
|
||||
* We may need to ignore an edge when validating a connection.
|
||||
*
|
||||
* For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection,
|
||||
* the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else
|
||||
* the validation will fail unexpectedly.
|
||||
*/
|
||||
const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id);
|
||||
|
||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||
// We already have a connection from this source to this target
|
||||
return buildRejectResult('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!sourceNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((n) => n.id === c.target);
|
||||
if (!targetNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return buildRejectResult('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types.
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
|
||||
// CollectionItemField inputs can have multiple input connections
|
||||
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
|
||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
}
|
||||
}
|
||||
|
||||
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||
return buildRejectResult('nodes.connectionWouldCreateCycle');
|
||||
}
|
||||
|
||||
return buildAcceptResult();
|
||||
};
|
||||
@@ -1,222 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { validateConnectionTypes } from './validateConnectionTypes';
|
||||
|
||||
describe(validateConnectionTypes.name, () => {
|
||||
describe('generic cases', () => {
|
||||
it('should accept SINGLE to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept COLLECTION to COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should reject COLLECTION to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject mismatched types', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'BarField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('special cases', () => {
|
||||
it('should reject a COLLECTION input to a COLLECTION input', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION' },
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
|
||||
it('should accept equal types', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
|
||||
describe('CollectionItemField', () => {
|
||||
it('should accept CollectionItemField to any SINGLE target', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should reject any COLLECTION to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('SINGLE_OR_COLLECTION', () => {
|
||||
it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('CollectionField', () => {
|
||||
it('should accept any CollectionField to any COLLECTION type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('subtype handling', () => {
|
||||
type TypePair = { t1: string; t2: string };
|
||||
const typePairs = [
|
||||
{ t1: 'IntegerField', t2: 'FloatField' },
|
||||
{ t1: 'IntegerField', t2: 'StringField' },
|
||||
{ t1: 'FloatField', t2: 'StringField' },
|
||||
];
|
||||
it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' });
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'SINGLE' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'COLLECTION' },
|
||||
{ name: t2, cardinality: 'COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'COLLECTION' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)(
|
||||
'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2',
|
||||
({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
describe('AnyField', () => {
|
||||
it('should accept any SINGLE type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any COLLECTION type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,74 +0,0 @@
|
||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'features/nodes/types/field';
|
||||
|
||||
/**
|
||||
* Validates that the source and target types are compatible for a connection.
|
||||
* @param sourceType The type of the source field.
|
||||
* @param targetType The type of the target field.
|
||||
* @returns True if the connection is valid, false otherwise.
|
||||
*/
|
||||
export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
|
||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||
// Once this is resolved, we can remove this check.
|
||||
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (areTypesEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION)
|
||||
* - SINGLE can connect to CollectionItem
|
||||
* - Anything (SINGLE, COLLECTION, SINGLE_OR_COLLECTION) can connect to SINGLE_OR_COLLECTION of the same base type
|
||||
* - Generic CollectionField can connect to any other COLLECTION or SINGLE_OR_COLLECTION
|
||||
* - Any COLLECTION can connect to a Generic Collection
|
||||
*/
|
||||
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem = isSingle(sourceType) && targetType.name === 'CollectionItemField';
|
||||
|
||||
const isAnythingToSingleOrCollectionOfSameBaseType =
|
||||
isSingleOrCollection(targetType) && sourceType.name === targetType.name;
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrSingleOrCollection =
|
||||
sourceType.name === 'CollectionField' && !isSingle(targetType);
|
||||
|
||||
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && isCollection(sourceType);
|
||||
|
||||
const isSourceSingle = isSingle(sourceType);
|
||||
const isTargetSingle = isSingle(targetType);
|
||||
const isSingleToSingle = isSourceSingle && isTargetSingle;
|
||||
const isSingleToSingleOrCollection = isSourceSingle && isSingleOrCollection(targetType);
|
||||
const isCollectionToCollection = isCollection(sourceType) && isCollection(targetType);
|
||||
const isCollectionToSingleOrCollection = isCollection(sourceType) && isSingleOrCollection(targetType);
|
||||
const isSingleOrCollectionToSingleOrCollection = isSingleOrCollection(sourceType) && isSingleOrCollection(targetType);
|
||||
const doesCardinalityMatch =
|
||||
isSingleToSingle ||
|
||||
isCollectionToCollection ||
|
||||
isCollectionToSingleOrCollection ||
|
||||
isSingleOrCollectionToSingleOrCollection ||
|
||||
isSingleToSingleOrCollection;
|
||||
|
||||
const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
|
||||
const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField';
|
||||
const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField';
|
||||
|
||||
const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
|
||||
|
||||
const isTargetAnyType = targetType.name === 'AnyField';
|
||||
|
||||
// One of these must be true for the connection to be valid
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToSingleOrCollectionOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrSingleOrCollection ||
|
||||
isCollectionToGenericCollection ||
|
||||
isSubTypeMatch ||
|
||||
isTargetAnyType
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,70 @@
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
/**
|
||||
* Validates that the source and target types are compatible for a connection.
|
||||
* @param sourceType The type of the source field.
|
||||
* @param targetType The type of the target field.
|
||||
* @returns True if the connection is valid, false otherwise.
|
||||
*/
|
||||
export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => {
|
||||
// TODO: There's a bug with Collect -> Iterate nodes:
|
||||
// https://github.com/invoke-ai/InvokeAI/issues/3956
|
||||
// Once this is resolved, we can remove this check.
|
||||
if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
|
||||
* - Generic Collection can connect to any other Collection or CollectionOrScalar
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection;
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar;
|
||||
|
||||
const isAnythingToCollectionOrScalarOfSameBaseType =
|
||||
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
|
||||
sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar);
|
||||
|
||||
const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection;
|
||||
|
||||
const areBothTypesSingle =
|
||||
!sourceType.isCollection &&
|
||||
!sourceType.isCollectionOrScalar &&
|
||||
!targetType.isCollection &&
|
||||
!targetType.isCollectionOrScalar;
|
||||
|
||||
const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
|
||||
|
||||
const isIntOrFloatToString =
|
||||
areBothTypesSingle &&
|
||||
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
|
||||
targetType.name === 'StringField';
|
||||
|
||||
const isTargetAnyType = targetType.name === 'AnyField';
|
||||
|
||||
// One of these must be true for the connection to be valid
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToCollectionOrScalarOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat ||
|
||||
isIntOrFloatToString ||
|
||||
isTargetAnyType
|
||||
);
|
||||
};
|
||||
@@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FieldIdentifierWithValue,
|
||||
WorkflowMode,
|
||||
@@ -139,31 +139,15 @@ export const workflowSlice = createSlice({
|
||||
};
|
||||
});
|
||||
|
||||
builder.addCase(nodesDeleted, (state, action) => {
|
||||
action.payload.forEach((node) => {
|
||||
state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id);
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
|
||||
|
||||
builder.addCase(nodesChanged, (state, action) => {
|
||||
// If a node was removed, we should remove any exposed fields that were associated with it. However, node changes
|
||||
// may remove and then add the same node back. For example, when updating a workflow, we replace old nodes with
|
||||
// updated nodes. In this case, we should not remove the exposed fields. To handle this, we find the last remove
|
||||
// and add changes for each exposed field. If the remove change comes after the add change, we remove the exposed
|
||||
// field.
|
||||
const exposedFieldsToRemove: FieldIdentifier[] = [];
|
||||
state.exposedFields.forEach((field) => {
|
||||
const removeIndex = action.payload.findLastIndex(
|
||||
(change) => change.type === 'remove' && change.id === field.nodeId
|
||||
);
|
||||
const addIndex = action.payload.findLastIndex(
|
||||
(change) => change.type === 'add' && change.item.id === field.nodeId
|
||||
);
|
||||
if (removeIndex > addIndex) {
|
||||
exposedFieldsToRemove.push({ nodeId: field.nodeId, fieldName: field.fieldName });
|
||||
}
|
||||
});
|
||||
|
||||
state.exposedFields = state.exposedFields.filter(
|
||||
(field) => !exposedFieldsToRemove.some((f) => isEqual(f, field))
|
||||
);
|
||||
|
||||
// Not all changes to nodes should result in the workflow being marked touched
|
||||
const filteredChanges = action.payload.filter((change) => {
|
||||
// We always want to mark the workflow as touched if a node is added, removed, or reset
|
||||
@@ -181,7 +165,7 @@ export const workflowSlice = createSlice({
|
||||
return false;
|
||||
});
|
||||
|
||||
if (filteredChanges.length > 0 || exposedFieldsToRemove.length > 0) {
|
||||
if (filteredChanges.length > 0) {
|
||||
state.isTouched = true;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -54,10 +54,9 @@ const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
|
||||
const zCardinality = z.enum(['SINGLE', 'COLLECTION', 'SINGLE_OR_COLLECTION']);
|
||||
|
||||
const zFieldTypeBase = z.object({
|
||||
cardinality: zCardinality,
|
||||
isCollection: z.boolean(),
|
||||
isCollectionOrScalar: z.boolean(),
|
||||
});
|
||||
|
||||
export const zFieldIdentifier = z.object({
|
||||
@@ -67,124 +66,16 @@ export const zFieldIdentifier = z.object({
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
// #endregion
|
||||
|
||||
// #region Field Types
|
||||
const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
// #region IntegerField
|
||||
const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zModelIdentifierFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ModelIdentifierField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zModelIdentifierFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
|
||||
export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType =>
|
||||
(statefulFieldTypeNames as string[]).includes(fieldType.name);
|
||||
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
|
||||
export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE;
|
||||
export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION;
|
||||
export const isSingleOrCollection = (fieldType: FieldType): boolean =>
|
||||
fieldType.cardinality === zCardinality.enum.SINGLE_OR_COLLECTION;
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField
|
||||
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIntegerFieldValue,
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().int().optional(),
|
||||
@@ -205,14 +96,15 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn
|
||||
// #endregion
|
||||
|
||||
// #region FloatField
|
||||
|
||||
const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
});
|
||||
export const zFloatFieldValue = z.number();
|
||||
const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFloatFieldValue,
|
||||
multipleOf: z.number().optional(),
|
||||
maximum: z.number().optional(),
|
||||
@@ -233,14 +125,15 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
});
|
||||
export const zStringFieldValue = z.string();
|
||||
const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStringFieldValue,
|
||||
maxLength: z.number().int().optional(),
|
||||
minLength: z.number().int().optional(),
|
||||
@@ -259,14 +152,15 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu
|
||||
// #endregion
|
||||
|
||||
// #region BooleanField
|
||||
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
});
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zBooleanFieldValue,
|
||||
});
|
||||
const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -282,14 +176,15 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn
|
||||
// #endregion
|
||||
|
||||
// #region EnumField
|
||||
|
||||
const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
});
|
||||
export const zEnumFieldValue = z.string();
|
||||
const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zEnumFieldValue,
|
||||
options: z.array(z.string()),
|
||||
labels: z.record(z.string()).optional(),
|
||||
@@ -307,14 +202,15 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
|
||||
const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
});
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImageFieldValue,
|
||||
});
|
||||
const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -330,14 +226,15 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
});
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zBoardFieldValue,
|
||||
});
|
||||
const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -353,14 +250,15 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
|
||||
const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
});
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zColorFieldValue,
|
||||
});
|
||||
const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -376,14 +274,15 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
export const zMainModelFieldValue = zModelIdentifierField.optional();
|
||||
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zMainModelFieldValue,
|
||||
});
|
||||
const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -398,37 +297,16 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
|
||||
zMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ModelIdentifierField
|
||||
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
|
||||
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zModelIdentifierFieldValue,
|
||||
});
|
||||
const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zModelIdentifierFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zModelIdentifierFieldValue,
|
||||
});
|
||||
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zModelIdentifierFieldType,
|
||||
});
|
||||
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
|
||||
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
|
||||
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
|
||||
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
|
||||
zModelIdentifierFieldInputInstance.safeParse(val).success;
|
||||
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
|
||||
zModelIdentifierFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
|
||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
});
|
||||
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLMainModelFieldValue,
|
||||
});
|
||||
const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -443,7 +321,9 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
});
|
||||
/** @alias */ // tells knip to ignore this duplicate export
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -451,7 +331,6 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
});
|
||||
const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -467,14 +346,15 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
});
|
||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zVAEModelFieldValue,
|
||||
});
|
||||
const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -490,14 +370,15 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
});
|
||||
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLoRAModelFieldValue,
|
||||
});
|
||||
const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -513,14 +394,15 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
});
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zControlNetModelFieldValue,
|
||||
});
|
||||
const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -536,14 +418,15 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIPAdapterModelFieldValue,
|
||||
});
|
||||
const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -559,14 +442,15 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -582,14 +466,15 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
});
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSchedulerFieldValue,
|
||||
});
|
||||
const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -616,14 +501,15 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie
|
||||
* - Reserved fields like IsIntermediate
|
||||
* - Any other field we don't have full-on schemas for
|
||||
*/
|
||||
|
||||
const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStatelessFieldValue,
|
||||
input: z.literal('connection'), // stateless --> only accepts connection inputs
|
||||
});
|
||||
@@ -649,6 +535,34 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
|
||||
* for all other StatelessFields.
|
||||
*/
|
||||
|
||||
// #region StatefulFieldType & FieldType
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
|
||||
zStatefulFieldType.safeParse(val).success;
|
||||
|
||||
const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldValue & FieldValue
|
||||
export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
@@ -658,7 +572,6 @@ export const zStatefulFieldValue = z.union([
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
zBoardFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
@@ -685,7 +598,6 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
@@ -713,7 +625,6 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
@@ -742,7 +653,6 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { StatefulFieldType, StatelessFieldType } from 'features/nodes/types/v2/field';
|
||||
import type { FieldType, StatefulFieldType } from 'features/nodes/types/field';
|
||||
|
||||
import type { FieldTypeV1 } from './workflowV1';
|
||||
|
||||
@@ -165,7 +165,7 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
|
||||
* Thus, this object was manually edited to ensure it is correct.
|
||||
*/
|
||||
const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
|
||||
[key in FieldTypeV1]?: StatelessFieldType;
|
||||
[key in FieldTypeV1]?: FieldType;
|
||||
} = {
|
||||
Any: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false },
|
||||
ClipField: {
|
||||
|
||||
@@ -316,7 +316,6 @@ const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
|
||||
const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
@@ -328,27 +327,6 @@ const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
|
||||
// #endregion
|
||||
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
|
||||
/**
|
||||
* Here we define the main field unions:
|
||||
* - FieldType
|
||||
|
||||
@@ -47,7 +47,6 @@ const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('default'),
|
||||
sourceHandle: z.string().trim().min(1),
|
||||
targetHandle: z.string().trim().min(1),
|
||||
hidden: z.boolean().optional(),
|
||||
});
|
||||
const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('collapsed'),
|
||||
|
||||
@@ -29,7 +29,7 @@ export const addControlNetToLinearGraph = async (
|
||||
assert(activeTabName !== 'generation', 'Tried to use addControlNetToLinearGraph on generation tab');
|
||||
|
||||
if (controlNets.length) {
|
||||
// Even though denoise_latents' control input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
|
||||
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
|
||||
const controlNetIterateNode: Invocation<'collect'> = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
|
||||
@@ -25,7 +25,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
});
|
||||
|
||||
if (ipAdapters.length) {
|
||||
// Even though denoise_latents' ip adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
|
||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const ipAdapterCollectNode: Invocation<'collect'> = {
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
|
||||
@@ -28,7 +28,7 @@ export const addT2IAdaptersToLinearGraph = async (
|
||||
);
|
||||
|
||||
if (t2iAdapters.length) {
|
||||
// Even though denoise_latents' t2i adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect
|
||||
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const t2iAdapterCollectNode: Invocation<'collect'> = {
|
||||
id: T2I_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
|
||||
@@ -330,7 +330,6 @@ export const buildCanvasImageToImageGraph = async (
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
CANVAS_OUTPUT,
|
||||
@@ -422,15 +421,6 @@ export const buildCanvasInpaintGraph = async (
|
||||
});
|
||||
}
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'inpaint',
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_OUTPAINT_GRAPH,
|
||||
CANVAS_OUTPUT,
|
||||
@@ -580,15 +579,6 @@ export const buildCanvasOutpaintGraph = async (
|
||||
);
|
||||
}
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'outpaint',
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
@@ -332,7 +332,6 @@ export const buildCanvasSDXLImageToImageGraph = async (
|
||||
init_image: initialImage.image_name,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_CREATE_MASK,
|
||||
@@ -433,15 +432,6 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
});
|
||||
}
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'sdxl_inpaint',
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_CREATE_MASK,
|
||||
@@ -589,15 +588,6 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
);
|
||||
}
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'sdxl_outpaint',
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
@@ -291,7 +291,6 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
@@ -280,7 +280,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
clip_skip: clipSkip,
|
||||
_canvas_objects: state.canvas.layerState.objects,
|
||||
},
|
||||
CANVAS_OUTPUT
|
||||
);
|
||||
|
||||
@@ -11,6 +11,8 @@ export const addSDXLLoRas = (
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||
seamless: Invocation<'seamless'> | null,
|
||||
clipSkip: Invocation<'clip_skip'>,
|
||||
clipSkip2: Invocation<'clip_skip'>,
|
||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'sdxl_compel_prompt'>
|
||||
): void => {
|
||||
@@ -37,8 +39,8 @@ export const addSDXLLoRas = (
|
||||
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
|
||||
// Use seamless as UNet input if it exists, otherwise use the model loader
|
||||
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
|
||||
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2');
|
||||
g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip');
|
||||
g.addEdge(clipSkip2, 'clip', loraCollectionLoader, 'clip2');
|
||||
// Reroute UNet & CLIP connections through the LoRA collection loader
|
||||
g.deleteEdgesTo(denoise, ['unet']);
|
||||
g.deleteEdgesTo(posCond, ['clip', 'clip2']);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
LATENTS_TO_IMAGE,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
@@ -29,6 +30,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
clipSkip: skipped_layers,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
@@ -51,6 +53,16 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
});
|
||||
const clipSkip = g.addNode({
|
||||
type: 'clip_skip',
|
||||
id: CLIP_SKIP,
|
||||
skipped_layers,
|
||||
});
|
||||
const clipSkip2 = g.addNode({
|
||||
type: 'clip_skip',
|
||||
id: `${CLIP_SKIP}_2`,
|
||||
skipped_layers,
|
||||
});
|
||||
const posCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
@@ -103,10 +115,12 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
|
||||
|
||||
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
|
||||
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
|
||||
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', clipSkip2, 'clip');
|
||||
g.addEdge(clipSkip, 'clip', posCond, 'clip');
|
||||
g.addEdge(clipSkip, 'clip', negCond, 'clip');
|
||||
g.addEdge(clipSkip2, 'clip', posCond, 'clip2');
|
||||
g.addEdge(clipSkip2, 'clip', negCond, 'clip2');
|
||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||
@@ -132,12 +146,13 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
scheduler,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
clip_skip: skipped_layers,
|
||||
vae: vae ?? undefined,
|
||||
});
|
||||
|
||||
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||
|
||||
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond);
|
||||
addSDXLLoRas(state, g, denoise, modelLoader, seamless, clipSkip, clipSkip2, posCond, negCond);
|
||||
|
||||
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user