Compare commits

..

33 Commits

Author SHA1 Message Date
psychedelicious
8b76d112be tests(config): set root to a tmp dir if didn't parse args
This prevents tests from triggering config related parsing on your "live" root.
2024-05-14 18:02:22 +10:00
psychedelicious
6e40142a59 tests(config): test migrations directly, not via load_and_migrate_config 2024-05-14 17:16:50 +10:00
psychedelicious
00ccd73d53 Merge branch 'main' into lstein/feat/config-migration 2024-05-14 16:56:40 +10:00
psychedelicious
964adb817c tests(config): update tests for config migration 2024-05-14 16:55:53 +10:00
psychedelicious
7d8b011f89 fix(config): restore missing config field assignment in migration 2024-05-14 16:55:44 +10:00
psychedelicious
d487102904 fix(config): fix config _check_for_discontinuities
Need to sort the migrations first.
2024-05-14 16:55:22 +10:00
psychedelicious
4c081d58e0 tidy(config): add note about circular deps in config_migrate.py 2024-05-14 16:31:11 +10:00
psychedelicious
18b5aafade tidy(config): add "config" to class names to differentiate from SQLite migration classes 2024-05-14 16:21:50 +10:00
psychedelicious
6946a3871f feat(config): simplify config migrator logic
- Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry`
- Move common classes to `config_common`
- Tidy docstrings, variable names
2024-05-14 16:20:40 +10:00
Lincoln Stein
fc23b16a73 add more checking of migration step operations 2024-05-03 06:49:16 -04:00
Lincoln Stein
31f63028fd merge 2024-05-03 06:44:37 -04:00
Lincoln Stein
2dd42d0917 check that right no. of migration steps run 2024-05-03 06:44:03 -04:00
Lincoln Stein
2bba7f38b9 Merge branch 'main' into lstein/feat/config-migration 2024-05-03 00:04:37 -04:00
Lincoln Stein
a48abfacf4 make config migrator into an instance; refactor location of get_config() 2024-05-02 23:45:34 -04:00
Lincoln Stein
d5aee87684 Merge branch 'main' into lstein/feat/config-migration 2024-04-30 18:08:46 -04:00
Lincoln Stein
36b14343c7 Merge branch 'main' into lstein/feat/config-migration 2024-04-28 14:32:41 -04:00
Lincoln Stein
59deef97c5 Merge branch 'main' into lstein/feat/config-migration 2024-04-28 14:31:43 -04:00
Lincoln Stein
d852ca7a8d added test for non-contiguous migration routines 2024-04-28 14:31:38 -04:00
Lincoln Stein
d24877561d reinstated failing deny_nodes validation test for Graph 2024-04-25 00:22:09 -04:00
Lincoln Stein
8144a263de updated and reinstated the test_deny_nodes() unit test 2024-04-24 22:14:48 -04:00
Lincoln Stein
ab086a7069 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:46 -04:00
Lincoln Stein
048306b417 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:12 -04:00
Lincoln Stein
6eaed9a9cb check for strictly contiguous from_version->to_version ranges 2024-04-24 21:36:28 -04:00
psychedelicious
ab9ebef345 tests(config): fix typo 2024-04-23 17:52:51 +10:00
psychedelicious
984dd93798 tests(config): add failing test case to for config migrator 2024-04-23 17:50:31 +10:00
psychedelicious
d12fb7db68 fix(config): fix duplicate migration logic
This was checking a `Version` object against a `MigrationEntry`, but what we want is to check the version object against `MigrationEntry.from_version`
2024-04-23 17:25:53 +10:00
psychedelicious
5d411e446a tidy(config): use a type alias for the migration function 2024-04-23 17:21:05 +10:00
psychedelicious
6f128c86b4 tidy(config): use dataclass for MigrationEntry
The only pydantic usage was to convert strings to `Version` objects. The reason to do this conversion was to allow the register decorator to accept strings. MigrationEntry is only created inside this class, so we can just create versions from each migration when instantiating MigrationEntry instead.

Also, pydantic doesn't provide runtime time checking for arbitrary classes like Version, so we don't get any real benefit.
2024-04-23 17:19:54 +10:00
psychedelicious
aca9e44a3a fix(config): use TypeAlias instead of TypeVar
TypeVar is for generics, but the usage here is as an alias
2024-04-23 17:12:19 +10:00
psychedelicious
e39f035264 tidy(config): removed extraneous ABC
We don't need separate implementations for this class, let's not complicate it with an ABC
2024-04-23 17:11:13 +10:00
psychedelicious
b612c73954 tidy(config): remove unused TYPE_CHECKING block 2024-04-23 17:09:50 +10:00
Lincoln Stein
36495b730d use packaging.version rather than version-parse 2024-04-18 23:07:54 -04:00
Lincoln Stein
6ad1948a44 add InvokeAIAppConfig schema migration system 2024-04-18 21:33:54 -04:00
155 changed files with 3943 additions and 5951 deletions

View File

@@ -98,7 +98,7 @@ Updating is exactly the same as installing - download the latest installer, choo
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
[installation requirements]: INSTALL_REQUIREMENTS.md
[installation requirements]: INSTALLATION.md#installation-requirements
[FAQ]: ../help/FAQ.md
[install some models]: 050_INSTALLING_MODELS.md
[configuration docs]: ../features/CONFIGURATION.md

View File

@@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
@@ -164,12 +164,6 @@ def custom_openapi() -> dict[str, Any]:
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
@@ -178,8 +172,6 @@ def custom_openapi() -> dict[str, Any]:
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?

View File

@@ -3,7 +3,7 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
custom_nodes_path = Path(get_config().custom_nodes_path)
custom_nodes_path.mkdir(parents=True, exist_ok=True)

View File

@@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
FieldKind,
Input,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string

View File

@@ -1,11 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, List, Union
from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from transformers import AutoModelForCausalLM, AutoTokenizer
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
@@ -16,7 +15,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
@@ -67,56 +66,6 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation(
"auto_caption_image",
title="Automatically Caption Image",
tags=["image", "caption"],
category="image",
version="1.2.2",
)
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adds a caption to an image"""
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
model_id = "vikhyatk/moondream2"
model_revision = "2024-04-02"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
moondream_model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=model_revision
)
output: CaptionImageOutputs = CaptionImageOutputs()
try:
from PIL.Image import Image
images: List[Image] = []
image_fields = self.images if isinstance(self.images, list) else [self.images]
for image in image_fields:
images.append(context.images.get_pil(image.image_name))
answers: List[str] = moondream_model.batch_answer(
images=images,
prompts=[self.prompt] * len(images),
tokenizer=tokenizer,
)
assert isinstance(answers, list)
for i, answer in enumerate(answers):
output.images.append(CaptionImageOutput(
image=image_fields[i],
width=images[i].width,
height=images[i].height,
caption=answer
))
except:
raise
finally:
del moondream_model
del tokenizer
return output
@invocation(
"img_crop",
title="Crop Image",
@@ -245,7 +194,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask."""
image: List[ImageField] = InputField(description="The image to create the mask from")
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -190,75 +190,6 @@ class LoRALoaderInvocation(BaseInvocation):
return output
@invocation_output("lora_selector_output")
class LoRASelectorOutput(BaseInvocationOutput):
"""Model loader output"""
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
output = LoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
@@ -348,72 +279,6 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output
@invocation(
"sdxl_lora_collection_loader",
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
output = SDXLLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base is BaseModelType.StableDiffusionXL
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional, List
from typing import Optional
import torch
@@ -247,17 +247,6 @@ class ImageOutput(BaseInvocationOutput):
)
@invocation_output("captioned_image_output")
class CaptionImageOutput(ImageOutput):
caption: str = OutputField(description="Caption for given image")
@invocation_output("captioned_image_outputs")
class CaptionImageOutputs(BaseInvocationOutput):
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
@invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""

View File

@@ -2,6 +2,7 @@
from invokeai.app.services.config.config_common import PagingArgumentParser
from .config_default import InvokeAIAppConfig, get_config
from .config_default import InvokeAIAppConfig
from .config_migrate import get_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]

View File

@@ -12,6 +12,10 @@ from __future__ import annotations
import argparse
import pydoc
from dataclasses import dataclass
from typing import Any, Callable, TypeAlias
from packaging.version import Version
class PagingArgumentParser(argparse.ArgumentParser):
@@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
def print_help(self, file=None) -> None:
text = self.format_help()
pydoc.pager(text)
AppConfigDict: TypeAlias = dict[str, Any]
ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class ConfigMigration:
"""Defines an individual config migration."""
from_version: Version
to_version: Version
function: ConfigMigrationFunction
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))

View File

@@ -3,11 +3,8 @@
from __future__ import annotations
import locale
import os
import re
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional
@@ -16,9 +13,7 @@ import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@@ -346,169 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config

View File

@@ -0,0 +1,177 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema.
"""
import locale
import shutil
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable
import yaml
from packaging.version import Version
import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
def __init__(self) -> None:
self._migrations: set[ConfigMigration] = set()
def register(self, migration: ConfigMigration) -> None:
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered:
raise ValueError(
f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
)
self._migrations.add(migration)
@staticmethod
def _check_for_discontinuities(migrations: Iterable[ConfigMigration]) -> None:
current_version = Version("3.0.0")
sorted_migrations = sorted(migrations, key=lambda x: x.from_version)
for m in sorted_migrations:
if current_version != m.from_version:
raise ValueError(
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
)
current_version = m.to_version
def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
"""
Use the registered migrations to bring config up to latest version.
Args:
original_config: The original configuration.
Returns:
The new configuration, lifted up to the latest version.
"""
# Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
self._check_for_discontinuities(sorted_migrations)
# Do not mutate the incoming dict - we don't know who else may be using it
migrated_config = deepcopy(original_config)
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
if "InvokeAI" in migrated_config:
version = Version("3.0.0")
else:
version = Version(migrated_config["schema_version"])
for migration in sorted_migrations:
if version == migration.from_version:
migrated_config = migration.function(migrated_config)
version = migration.to_version
# We must end on the latest version
assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
return migrated_config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict: AppConfigDict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
migrator = ConfigMigrator()
migrator.register(config_migration_1)
migrator.register(config_migration_2)
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
# TODO(psyche): This must must be in this file to avoid circular dependencies
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
tmpdir = TemporaryDirectory()
config._root = Path(tmpdir.name)
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config

View File

@@ -0,0 +1,102 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Schema migrations to perform on an InvokeAIAppConfig object.
The Migrations class defined in this module defines a series of
schema version migration steps for the InvokeAIConfig object.
To define a new migration, add a migration function to
Migrations.load_migrations() following the existing examples.
"""
from pathlib import Path
from packaging.version import Version
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from .config_default import InvokeAIAppConfig
def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v3.0.0 config dict to v4.0.0.
Changes in this migration:
- `outdir` was renamed to `outputs_dir`
- `max_cache_size` was renamed to `ram`
- `max_vram_cache_size` was renamed to `vram`
- `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
to the database
- `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
Args:
config_dict: The v3.0.0 config dict to migrate.
Returns:
The migrated v4.0.0 config dict.
"""
migrated_config: AppConfigDict = {}
for _category_name, category_dict in original_config["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
migrated_config["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
migrated_config["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
migrated_config["vram"] = v
if k == "conf_path":
migrated_config["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
migrated_config["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
migrated_config["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
migrated_config[k] = v
migrated_config["schema_version"] = "4.0.0"
return migrated_config
config_migration_1 = ConfigMigration(
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
)
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v4.0.0 config dict to v4.0.1.
Changes in this migration:
- `precision: "autocast"` was removed, fall back to "auto"
Args:
config_dict: The v4.0.0 config dict to migrate.
Returns:
The migrated v4.0.1 config dict.
"""
migrated_config: AppConfigDict = {}
for k, v in original_config.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
migrated_config["precision"] = "auto"
# skip unknown fields
elif k in InvokeAIAppConfig.model_fields:
migrated_config[k] = v
migrated_config["schema_version"] = "4.0.1"
return migrated_config
config_migration_2 = ConfigMigration(
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
)

View File

@@ -9,7 +9,7 @@ from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices

View File

@@ -1,6 +1,6 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1

View File

@@ -1,7 +1,7 @@
import sqlite3
from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

View File

@@ -9,7 +9,7 @@ from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize

View File

@@ -5,7 +5,7 @@
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice

View File

@@ -6,7 +6,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice

View File

@@ -9,7 +9,7 @@ import numpy as np
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
class PatchMatch:

View File

@@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
config = get_config()

View File

@@ -12,7 +12,7 @@ from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings

View File

@@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
import torch
from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData,
Range,

View File

@@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import get_config
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]

View File

@@ -180,8 +180,7 @@ import urllib.parse
from pathlib import Path
from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config import InvokeAIAppConfig, get_config
try:
import syslog

View File

@@ -10,8 +10,6 @@ module.exports = {
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
},
overrides: [
/**

View File

@@ -43,5 +43,4 @@ stats.html
yalc.lock
# vitest
tsconfig.vitest-temp.json
coverage/
tsconfig.vitest-temp.json

View File

@@ -35,7 +35,6 @@
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"test": "vitest",
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
@@ -133,8 +132,6 @@
"@types/react-dom": "^18.3.0",
"@types/uuid": "^9.0.8",
"@vitejs/plugin-react-swc": "^3.6.0",
"@vitest/coverage-v8": "^1.5.0",
"@vitest/ui": "^1.5.0",
"concurrently": "^8.2.2",
"dpdm": "^3.14.0",
"eslint": "^8.57.0",

View File

@@ -229,12 +229,6 @@ devDependencies:
'@vitejs/plugin-react-swc':
specifier: ^3.6.0
version: 3.6.0(vite@5.2.11)
'@vitest/coverage-v8':
specifier: ^1.5.0
version: 1.6.0(vitest@1.6.0)
'@vitest/ui':
specifier: ^1.5.0
version: 1.6.0(vitest@1.6.0)
concurrently:
specifier: ^8.2.2
version: 8.2.2
@@ -294,7 +288,7 @@ devDependencies:
version: 4.3.2(typescript@5.4.5)(vite@5.2.11)
vitest:
specifier: ^1.6.0
version: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
version: 1.6.0(@types/node@20.12.10)
packages:
@@ -1685,10 +1679,6 @@ packages:
resolution: {integrity: sha512-4iri8i1AqYHJE2DstZYkyEprg6Pq6sKx3xn5FpySk9sNhH7qN2LLlHJCfDTZRILNwQNPD7mATWM0TBui7uC1pA==}
dev: true
/@bcoe/v8-coverage@0.2.3:
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
dev: true
/@chakra-ui/accordion@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.3.1):
resolution: {integrity: sha512-FSXRm8iClFyU+gVaXisOSEw0/4Q+qZbFRiuhIAkVU6Boj0FxAMrlo9a8AV5TuF77rgaHytCdHk0Ng+cyUijrag==}
peerDependencies:
@@ -3645,11 +3635,6 @@ packages:
wrap-ansi-cjs: /wrap-ansi@7.0.0
dev: true
/@istanbuljs/schema@0.1.3:
resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==}
engines: {node: '>=8'}
dev: true
/@jest/schemas@29.6.3:
resolution: {integrity: sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==}
engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0}
@@ -3837,10 +3822,6 @@ packages:
dev: true
optional: true
/@polka/url@1.0.0-next.25:
resolution: {integrity: sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==}
dev: true
/@popperjs/core@2.11.8:
resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==}
dev: false
@@ -5165,7 +5146,7 @@ packages:
dom-accessibility-api: 0.6.3
lodash: 4.17.21
redent: 3.0.0
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
vitest: 1.6.0(@types/node@20.12.10)
dev: true
/@testing-library/user-event@14.5.2(@testing-library/dom@9.3.4):
@@ -5844,29 +5825,6 @@ packages:
- '@swc/helpers'
dev: true
/@vitest/coverage-v8@1.6.0(vitest@1.6.0):
resolution: {integrity: sha512-KvapcbMY/8GYIG0rlwwOKCVNRc0OL20rrhFkg/CHNzncV03TE2XWvO5w9uZYoxNiMEBacAJt3unSOiZ7svePew==}
peerDependencies:
vitest: 1.6.0
dependencies:
'@ampproject/remapping': 2.3.0
'@bcoe/v8-coverage': 0.2.3
debug: 4.3.4
istanbul-lib-coverage: 3.2.2
istanbul-lib-report: 3.0.1
istanbul-lib-source-maps: 5.0.4
istanbul-reports: 3.1.7
magic-string: 0.30.10
magicast: 0.3.4
picocolors: 1.0.0
std-env: 3.7.0
strip-literal: 2.1.0
test-exclude: 6.0.0
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
transitivePeerDependencies:
- supports-color
dev: true
/@vitest/expect@1.3.1:
resolution: {integrity: sha512-xofQFwIzfdmLLlHa6ag0dPV8YsnKOCP1KdAeVVh34vSjN2dcUiXYCD9htu/9eM7t8Xln4v03U9HLxLpPlsXdZw==}
dependencies:
@@ -5911,21 +5869,6 @@ packages:
tinyspy: 2.2.1
dev: true
/@vitest/ui@1.6.0(vitest@1.6.0):
resolution: {integrity: sha512-k3Lyo+ONLOgylctiGovRKy7V4+dIN2yxstX3eY5cWFXH6WP+ooVX79YSyi0GagdTQzLmT43BF27T0s6dOIPBXA==}
peerDependencies:
vitest: 1.6.0
dependencies:
'@vitest/utils': 1.6.0
fast-glob: 3.3.2
fflate: 0.8.2
flatted: 3.3.1
pathe: 1.1.2
picocolors: 1.0.0
sirv: 2.0.4
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
dev: true
/@vitest/utils@1.3.1:
resolution: {integrity: sha512-d3Waie/299qqRyHTm2DjADeTaNdNSVsnwHPWrs20JMpjh6eiVq7ggggweO8rc4arhf6rRkWuHKwvxGvejUXZZQ==}
dependencies:
@@ -8578,10 +8521,6 @@ packages:
resolution: {integrity: sha512-3yurQZ2hD9VISAhJJP9bpYFNQrHHBXE2JxxjY5aLEcDi46RmAzJE2OC9FAde0yis5ElW0jTTzs0zfg/Cca4XqQ==}
dev: true
/fflate@0.8.2:
resolution: {integrity: sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==}
dev: true
/file-entry-cache@6.0.1:
resolution: {integrity: sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==}
engines: {node: ^10.12.0 || >=12.0.0}
@@ -9145,10 +9084,6 @@ packages:
resolution: {integrity: sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==}
dev: true
/html-escaper@2.0.2:
resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==}
dev: true
/html-parse-stringify@3.0.1:
resolution: {integrity: sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==}
dependencies:
@@ -9578,39 +9513,6 @@ packages:
engines: {node: '>=0.10.0'}
dev: true
/istanbul-lib-coverage@3.2.2:
resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==}
engines: {node: '>=8'}
dev: true
/istanbul-lib-report@3.0.1:
resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==}
engines: {node: '>=10'}
dependencies:
istanbul-lib-coverage: 3.2.2
make-dir: 4.0.0
supports-color: 7.2.0
dev: true
/istanbul-lib-source-maps@5.0.4:
resolution: {integrity: sha512-wHOoEsNJTVltaJp8eVkm8w+GVkVNHT2YDYo53YdzQEL2gWm1hBX5cGFR9hQJtuGLebidVX7et3+dmDZrmclduw==}
engines: {node: '>=10'}
dependencies:
'@jridgewell/trace-mapping': 0.3.25
debug: 4.3.4
istanbul-lib-coverage: 3.2.2
transitivePeerDependencies:
- supports-color
dev: true
/istanbul-reports@3.1.7:
resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==}
engines: {node: '>=8'}
dependencies:
html-escaper: 2.0.2
istanbul-lib-report: 3.0.1
dev: true
/iterable-lookahead@1.0.0:
resolution: {integrity: sha512-hJnEP2Xk4+44DDwJqUQGdXal5VbyeWLaPyDl2AQc242Zr7iqz4DgpQOrEzglWVMGHMDCkguLHEKxd1+rOsmgSQ==}
engines: {node: '>=4'}
@@ -10010,14 +9912,6 @@ packages:
'@jridgewell/sourcemap-codec': 1.4.15
dev: true
/magicast@0.3.4:
resolution: {integrity: sha512-TyDF/Pn36bBji9rWKHlZe+PZb6Mx5V8IHCSxk7X4aljM4e/vyDvZZYwHewdVaqiA0nb3ghfHU/6AUpDxWoER2Q==}
dependencies:
'@babel/parser': 7.24.5
'@babel/types': 7.24.5
source-map-js: 1.2.0
dev: true
/make-dir@2.1.0:
resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==}
engines: {node: '>=6'}
@@ -10033,13 +9927,6 @@ packages:
semver: 6.3.1
dev: true
/make-dir@4.0.0:
resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==}
engines: {node: '>=10'}
dependencies:
semver: 7.6.0
dev: true
/map-obj@2.0.0:
resolution: {integrity: sha512-TzQSV2DiMYgoF5RycneKVUzIa9bQsj/B3tTgsE3dOGqlzHnGIDaC7XBE7grnA+8kZPnfqSGFe95VHc2oc0VFUQ==}
engines: {node: '>=4'}
@@ -10214,11 +10101,6 @@ packages:
resolution: {integrity: sha512-iSAJLHYKnX41mKcJKjqvnAN9sf0LMDTXDEvFv+ffuRR9a1MIuXLjMNL6EsnDHSkKLTWNqQQ5uo61P4EbU4NU+Q==}
dev: false
/mrmime@2.0.0:
resolution: {integrity: sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==}
engines: {node: '>=10'}
dev: true
/ms@2.0.0:
resolution: {integrity: sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==}
dev: true
@@ -11884,15 +11766,6 @@ packages:
engines: {node: '>=14'}
dev: true
/sirv@2.0.4:
resolution: {integrity: sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==}
engines: {node: '>= 10'}
dependencies:
'@polka/url': 1.0.0-next.25
mrmime: 2.0.0
totalist: 3.0.1
dev: true
/sisteransi@1.0.5:
resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==}
dev: true
@@ -12318,15 +12191,6 @@ packages:
unique-string: 2.0.0
dev: true
/test-exclude@6.0.0:
resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==}
engines: {node: '>=8'}
dependencies:
'@istanbuljs/schema': 0.1.3
glob: 7.2.3
minimatch: 3.1.2
dev: true
/text-table@0.2.0:
resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==}
dev: true
@@ -12400,11 +12264,6 @@ packages:
engines: {node: '>=0.6'}
dev: true
/totalist@3.0.1:
resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==}
engines: {node: '>=6'}
dev: true
/tr46@0.0.3:
resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==}
@@ -12978,7 +12837,7 @@ packages:
fsevents: 2.3.3
dev: true
/vitest@1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0):
/vitest@1.6.0(@types/node@20.12.10):
resolution: {integrity: sha512-H5r/dN06swuFnzNFhq/dnz37bPXnq8xB2xB5JOVk8K09rUtoeNN+LHWkoQ0A/i3hvbUKKcCei9KpbxqHMLhLLA==}
engines: {node: ^18.0.0 || >=20.0.0}
hasBin: true
@@ -13008,7 +12867,6 @@ packages:
'@vitest/runner': 1.6.0
'@vitest/snapshot': 1.6.0
'@vitest/spy': 1.6.0
'@vitest/ui': 1.6.0(vitest@1.6.0)
'@vitest/utils': 1.6.0
acorn-walk: 8.3.2
chai: 4.4.1

View File

@@ -774,7 +774,6 @@
"cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",

View File

@@ -1,4 +1,3 @@
/* eslint-disable no-console */
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';

View File

@@ -67,8 +67,6 @@ export const useSocketIO = () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = $socketOptions;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket initialized', socket);
}
@@ -77,8 +75,6 @@ export const useSocketIO = () => {
return () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = undefined;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket teardown', socket);
}
socket.disconnect();

View File

@@ -1,6 +1,3 @@
/* eslint-disable no-console */
// This is only enabled manually for debugging, console is allowed.
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { diff } from 'jsondiffpatch';

View File

@@ -1,6 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -24,6 +25,13 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {

View File

@@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const { canvas, nodes, controlAdapters, controlLayers } = getState();
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas());

View File

@@ -148,6 +148,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
console.log(error);
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {

View File

@@ -8,8 +8,8 @@ import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';

View File

@@ -1,9 +1,9 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue';
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
@@ -18,7 +18,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph;
if (model?.base === 'sdxl') {
if (model && model.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);

View File

@@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { nodes, edges } = state.nodes.present;
const { nodes, edges } = state.nodes;
const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes.present);
const graph = buildNodesGraph(state.nodes);
const builtWorkflow = buildWorkflowWithValidation({
nodes,
edges,

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
@@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { getState }) => {
effect: (action, { dispatch, getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
@@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
dispatch(nodeTemplatesBuilt(nodeTemplates));
},
});

View File

@@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
state.nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}

View File

@@ -1,8 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
const log = logger('socketio');
@@ -12,13 +9,6 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
}
},
});
};

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
@@ -10,9 +9,7 @@ import {
isImageViewerOpenChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
@@ -31,7 +28,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, node, queue_batch_id, source_node_id } = data;
const { result, node, queue_batch_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
@@ -113,16 +110,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
}
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,8 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
@@ -12,15 +9,6 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;
nes.progress = null;
nes.progressImage = null;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,8 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
const log = logger('socketio');
@@ -12,12 +9,6 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,9 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
@@ -58,21 +54,6 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
dispatch(
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
);
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
}
},
});
};

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
@@ -14,8 +14,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
actionCreator: updateAllNodesRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const { nodes } = getState().nodes.present;
const templates = $templates.get();
const { nodes, templates } = getState().nodes;
let unableToUpdateCount = 0;
@@ -25,7 +24,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
unableToUpdateCount++;
return;
}
if (!getNeedsUpdate(node.data, template)) {
if (!getNeedsUpdate(node, template)) {
// No need to increment the count here, since we're not actually updating
return;
}

View File

@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
@@ -15,10 +14,10 @@ import { fromZodError } from 'zod-validation-error';
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const nodeTemplates = $templates.get();
const nodeTemplates = getState().nodes.templates;
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);

View File

@@ -21,8 +21,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
@@ -51,7 +50,7 @@ const allReducers = {
[canvasSlice.name]: canvasSlice.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[nodesSlice.name]: nodesSlice.reducer,
[postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
@@ -67,7 +66,6 @@ const allReducers = {
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
};
@@ -113,7 +111,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -1,4 +1,3 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@@ -10,16 +9,13 @@ import { selectControlLayersSlice } from 'features/controlLayers/store/controlLa
import type { Layer } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
@@ -29,208 +25,199 @@ const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
regional_guidance_layer: 'controlLayers.regionalGuidance',
};
const createSelector = (templates: Templates) =>
createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectWorkflowSettingsSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const selector = createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const { isConnected } = system;
const { isConnected } = system;
const reasons: { prefix?: string; content: string }[] = [];
const reasons: { prefix?: string; content: string }[] = [];
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
if (activeTabName === 'workflows') {
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
if (activeTabName === 'workflows') {
if (nodes.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
const nodeTemplate = nodes.templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
const nodeTemplate = templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
});
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
if (l.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64
if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
if (l.type === 'ip_adapter_layer') {
if (l.type === 'ip_adapter_layer') {
// Must have model
if (!l.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (l.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!l.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!l.ipAdapter.model) {
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (l.ipAdapter.model?.base !== model?.base) {
if (ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!l.ipAdapter.image) {
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
}
});
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({
content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }),
});
}
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) });
}
});
}
return { isReady: !reasons.length, reasons };
}
);
return { isReady: !reasons.length, reasons };
}
);
export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const selector = useMemo(() => createSelector(templates), [templates]);
const value = useAppSelector(selector);
return value;
};

View File

@@ -5,7 +5,22 @@ import type {
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import type { components } from 'services/api/schema';
import type { Invocation } from 'services/api/types';
import type {
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation,
HedImageProcessorInvocation,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
PidiImageProcessorInvocation,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import type { O } from 'ts-toolbelt';
import { z } from 'zod';
@@ -13,20 +28,20 @@ import { z } from 'zod';
* Any ControlNet processor node
*/
export type ControlAdapterProcessorNode =
| Invocation<'canny_image_processor'>
| Invocation<'color_map_image_processor'>
| Invocation<'content_shuffle_image_processor'>
| Invocation<'depth_anything_image_processor'>
| Invocation<'hed_image_processor'>
| Invocation<'lineart_anime_image_processor'>
| Invocation<'lineart_image_processor'>
| Invocation<'mediapipe_face_processor'>
| Invocation<'midas_depth_image_processor'>
| Invocation<'mlsd_image_processor'>
| Invocation<'normalbae_image_processor'>
| Invocation<'dw_openpose_image_processor'>
| Invocation<'pidi_image_processor'>
| Invocation<'zoe_depth_image_processor'>;
| CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
| DepthAnythingImageProcessorInvocation
| HedImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| LineartImageProcessorInvocation
| MediapipeFaceProcessorInvocation
| MidasDepthImageProcessorInvocation
| MlsdImageProcessorInvocation
| NormalbaeImageProcessorInvocation
| DWOpenposeImageProcessorInvocation
| PidiImageProcessorInvocation
| ZoeDepthImageProcessorInvocation;
/**
* Any ControlNet processor type
@@ -56,7 +71,7 @@ export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterPr
* The Canny processor node, with parameters flagged as required
*/
export type RequiredCannyImageProcessorInvocation = O.Required<
Invocation<'canny_image_processor'>,
CannyImageProcessorInvocation,
'type' | 'low_threshold' | 'high_threshold' | 'image_resolution' | 'detect_resolution'
>;
@@ -64,7 +79,7 @@ export type RequiredCannyImageProcessorInvocation = O.Required<
* The Color Map processor node, with parameters flagged as required
*/
export type RequiredColorMapImageProcessorInvocation = O.Required<
Invocation<'color_map_image_processor'>,
ColorMapImageProcessorInvocation,
'type' | 'color_map_tile_size'
>;
@@ -72,7 +87,7 @@ export type RequiredColorMapImageProcessorInvocation = O.Required<
* The ContentShuffle processor node, with parameters flagged as required
*/
export type RequiredContentShuffleImageProcessorInvocation = O.Required<
Invocation<'content_shuffle_image_processor'>,
ContentShuffleImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
>;
@@ -80,7 +95,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
* The DepthAnything processor node, with parameters flagged as required
*/
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
Invocation<'depth_anything_image_processor'>,
DepthAnythingImageProcessorInvocation,
'type' | 'model_size' | 'resolution' | 'offload'
>;
@@ -93,7 +108,7 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz
* The HED processor node, with parameters flagged as required
*/
export type RequiredHedImageProcessorInvocation = O.Required<
Invocation<'hed_image_processor'>,
HedImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'scribble'
>;
@@ -101,7 +116,7 @@ export type RequiredHedImageProcessorInvocation = O.Required<
* The Lineart Anime processor node, with parameters flagged as required
*/
export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
Invocation<'lineart_anime_image_processor'>,
LineartAnimeImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution'
>;
@@ -109,7 +124,7 @@ export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
* The Lineart processor node, with parameters flagged as required
*/
export type RequiredLineartImageProcessorInvocation = O.Required<
Invocation<'lineart_image_processor'>,
LineartImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'coarse'
>;
@@ -117,7 +132,7 @@ export type RequiredLineartImageProcessorInvocation = O.Required<
* The MediapipeFace processor node, with parameters flagged as required
*/
export type RequiredMediapipeFaceProcessorInvocation = O.Required<
Invocation<'mediapipe_face_processor'>,
MediapipeFaceProcessorInvocation,
'type' | 'max_faces' | 'min_confidence' | 'image_resolution' | 'detect_resolution'
>;
@@ -125,7 +140,7 @@ export type RequiredMediapipeFaceProcessorInvocation = O.Required<
* The MidasDepth processor node, with parameters flagged as required
*/
export type RequiredMidasDepthImageProcessorInvocation = O.Required<
Invocation<'midas_depth_image_processor'>,
MidasDepthImageProcessorInvocation,
'type' | 'a_mult' | 'bg_th' | 'image_resolution' | 'detect_resolution'
>;
@@ -133,7 +148,7 @@ export type RequiredMidasDepthImageProcessorInvocation = O.Required<
* The MLSD processor node, with parameters flagged as required
*/
export type RequiredMlsdImageProcessorInvocation = O.Required<
Invocation<'mlsd_image_processor'>,
MlsdImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'thr_v' | 'thr_d'
>;
@@ -141,7 +156,7 @@ export type RequiredMlsdImageProcessorInvocation = O.Required<
* The NormalBae processor node, with parameters flagged as required
*/
export type RequiredNormalbaeImageProcessorInvocation = O.Required<
Invocation<'normalbae_image_processor'>,
NormalbaeImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution'
>;
@@ -149,7 +164,7 @@ export type RequiredNormalbaeImageProcessorInvocation = O.Required<
* The DW Openpose processor node, with parameters flagged as required
*/
export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
Invocation<'dw_openpose_image_processor'>,
DWOpenposeImageProcessorInvocation,
'type' | 'image_resolution' | 'draw_body' | 'draw_face' | 'draw_hands'
>;
@@ -157,14 +172,14 @@ export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
* The Pidi processor node, with parameters flagged as required
*/
export type RequiredPidiImageProcessorInvocation = O.Required<
Invocation<'pidi_image_processor'>,
PidiImageProcessorInvocation,
'type' | 'detect_resolution' | 'image_resolution' | 'safe' | 'scribble'
>;
/**
* The ZoeDepth processor node, with parameters flagged as required
*/
export type RequiredZoeDepthImageProcessorInvocation = O.Required<Invocation<'zoe_depth_image_processor'>, 'type'>;
export type RequiredZoeDepthImageProcessorInvocation = O.Required<ZoeDepthImageProcessorInvocation, 'type'>;
/**
* Any ControlNet Processor node, with its parameters flagged as required

View File

@@ -1,9 +1,23 @@
import type { Invocation } from 'services/api/types';
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type {
_CannyProcessorConfig,
_ColorMapProcessorConfig,
_ContentShuffleProcessorConfig,
_DepthAnythingProcessorConfig,
_DWOpenposeProcessorConfig,
_HedProcessorConfig,
_LineartAnimeProcessorConfig,
_LineartProcessorConfig,
_MediapipeFaceProcessorConfig,
_MidasDepthProcessorConfig,
_MlsdProcessorConfig,
_NormalbaeProcessorConfig,
_PidiProcessorConfig,
_ZoeDepthProcessorConfig,
CannyProcessorConfig,
CLIPVisionModelV2,
ColorMapProcessorConfig,
@@ -31,16 +45,16 @@ describe('Control Adapter Types', () => {
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
});
test('IP Adapter Method', () => {
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>();
});
test('CLIP Vision Model', () => {
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>();
});
test('Control Mode', () => {
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>();
});
test('DepthAnything Model Size', () => {
assert<Equals<NonNullable<Invocation<'depth_anything_image_processor'>['model_size']>, DepthAnythingModelSize>>();
assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>();
});
test('Processor Configs', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
@@ -61,33 +75,3 @@ describe('Control Adapter Types', () => {
assert<Equals<_ZoeDepthProcessorConfig, ZoeDepthProcessorConfig>>();
});
});
// Types derived from OpenAPI
type _CannyProcessorConfig = Required<
Pick<Invocation<'canny_image_processor'>, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
>;
type _ColorMapProcessorConfig = Required<
Pick<Invocation<'color_map_image_processor'>, 'id' | 'type' | 'color_map_tile_size'>
>;
type _ContentShuffleProcessorConfig = Required<
Pick<Invocation<'content_shuffle_image_processor'>, 'id' | 'type' | 'w' | 'h' | 'f'>
>;
type _DepthAnythingProcessorConfig = Required<
Pick<Invocation<'depth_anything_image_processor'>, 'id' | 'type' | 'model_size'>
>;
type _HedProcessorConfig = Required<Pick<Invocation<'hed_image_processor'>, 'id' | 'type' | 'scribble'>>;
type _LineartAnimeProcessorConfig = Required<Pick<Invocation<'lineart_anime_image_processor'>, 'id' | 'type'>>;
type _LineartProcessorConfig = Required<Pick<Invocation<'lineart_image_processor'>, 'id' | 'type' | 'coarse'>>;
type _MediapipeFaceProcessorConfig = Required<
Pick<Invocation<'mediapipe_face_processor'>, 'id' | 'type' | 'max_faces' | 'min_confidence'>
>;
type _MidasDepthProcessorConfig = Required<
Pick<Invocation<'midas_depth_image_processor'>, 'id' | 'type' | 'a_mult' | 'bg_th'>
>;
type _MlsdProcessorConfig = Required<Pick<Invocation<'mlsd_image_processor'>, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
type _NormalbaeProcessorConfig = Required<Pick<Invocation<'normalbae_image_processor'>, 'id' | 'type'>>;
type _DWOpenposeProcessorConfig = Required<
Pick<Invocation<'dw_openpose_image_processor'>, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
type _PidiProcessorConfig = Required<Pick<Invocation<'pidi_image_processor'>, 'id' | 'type' | 'safe' | 'scribble'>>;
type _ZoeDepthProcessorConfig = Required<Pick<Invocation<'zoe_depth_image_processor'>, 'id' | 'type'>>;

View File

@@ -1,7 +1,27 @@
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type {
BaseModelType,
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
ControlNetModelConfig,
DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation,
Graph,
HedImageProcessorInvocation,
ImageDTO,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
PidiImageProcessorInvocation,
T2IAdapterModelConfig,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
@@ -12,6 +32,9 @@ const zCannyProcessorConfig = z.object({
low_threshold: z.number().int().gte(0).lte(255),
high_threshold: z.number().int().gte(0).lte(255),
});
export type _CannyProcessorConfig = Required<
Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
>;
export type CannyProcessorConfig = z.infer<typeof zCannyProcessorConfig>;
const zColorMapProcessorConfig = z.object({
@@ -19,6 +42,9 @@ const zColorMapProcessorConfig = z.object({
type: z.literal('color_map_image_processor'),
color_map_tile_size: z.number().int().gte(1),
});
export type _ColorMapProcessorConfig = Required<
Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
>;
export type ColorMapProcessorConfig = z.infer<typeof zColorMapProcessorConfig>;
const zContentShuffleProcessorConfig = z.object({
@@ -28,6 +54,9 @@ const zContentShuffleProcessorConfig = z.object({
h: z.number().int().gte(0),
f: z.number().int().gte(0),
});
export type _ContentShuffleProcessorConfig = Required<
Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
>;
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
@@ -39,6 +68,9 @@ const zDepthAnythingProcessorConfig = z.object({
type: z.literal('depth_anything_image_processor'),
model_size: zDepthAnythingModelSize,
});
export type _DepthAnythingProcessorConfig = Required<
Pick<DepthAnythingImageProcessorInvocation, 'id' | 'type' | 'model_size'>
>;
export type DepthAnythingProcessorConfig = z.infer<typeof zDepthAnythingProcessorConfig>;
const zHedProcessorConfig = z.object({
@@ -46,12 +78,14 @@ const zHedProcessorConfig = z.object({
type: z.literal('hed_image_processor'),
scribble: z.boolean(),
});
export type _HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
export type HedProcessorConfig = z.infer<typeof zHedProcessorConfig>;
const zLineartAnimeProcessorConfig = z.object({
id: zId,
type: z.literal('lineart_anime_image_processor'),
});
export type _LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
export type LineartAnimeProcessorConfig = z.infer<typeof zLineartAnimeProcessorConfig>;
const zLineartProcessorConfig = z.object({
@@ -59,6 +93,7 @@ const zLineartProcessorConfig = z.object({
type: z.literal('lineart_image_processor'),
coarse: z.boolean(),
});
export type _LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
export type LineartProcessorConfig = z.infer<typeof zLineartProcessorConfig>;
const zMediapipeFaceProcessorConfig = z.object({
@@ -67,6 +102,9 @@ const zMediapipeFaceProcessorConfig = z.object({
max_faces: z.number().int().gte(1),
min_confidence: z.number().gte(0).lte(1),
});
export type _MediapipeFaceProcessorConfig = Required<
Pick<MediapipeFaceProcessorInvocation, 'id' | 'type' | 'max_faces' | 'min_confidence'>
>;
export type MediapipeFaceProcessorConfig = z.infer<typeof zMediapipeFaceProcessorConfig>;
const zMidasDepthProcessorConfig = z.object({
@@ -75,6 +113,9 @@ const zMidasDepthProcessorConfig = z.object({
a_mult: z.number().gte(0),
bg_th: z.number().gte(0),
});
export type _MidasDepthProcessorConfig = Required<
Pick<MidasDepthImageProcessorInvocation, 'id' | 'type' | 'a_mult' | 'bg_th'>
>;
export type MidasDepthProcessorConfig = z.infer<typeof zMidasDepthProcessorConfig>;
const zMlsdProcessorConfig = z.object({
@@ -83,12 +124,14 @@ const zMlsdProcessorConfig = z.object({
thr_v: z.number().gte(0),
thr_d: z.number().gte(0),
});
export type _MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
export type MlsdProcessorConfig = z.infer<typeof zMlsdProcessorConfig>;
const zNormalbaeProcessorConfig = z.object({
id: zId,
type: z.literal('normalbae_image_processor'),
});
export type _NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
export type NormalbaeProcessorConfig = z.infer<typeof zNormalbaeProcessorConfig>;
const zDWOpenposeProcessorConfig = z.object({
@@ -98,6 +141,9 @@ const zDWOpenposeProcessorConfig = z.object({
draw_face: z.boolean(),
draw_hands: z.boolean(),
});
export type _DWOpenposeProcessorConfig = Required<
Pick<DWOpenposeImageProcessorInvocation, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
export type DWOpenposeProcessorConfig = z.infer<typeof zDWOpenposeProcessorConfig>;
const zPidiProcessorConfig = z.object({
@@ -106,12 +152,14 @@ const zPidiProcessorConfig = z.object({
safe: z.boolean(),
scribble: z.boolean(),
});
export type _PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
export type PidiProcessorConfig = z.infer<typeof zPidiProcessorConfig>;
const zZoeDepthProcessorConfig = z.object({
id: zId,
type: z.literal('zoe_depth_image_processor'),
});
export type _ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
const zProcessorConfig = z.discriminatedUnion('type', [

View File

@@ -1,21 +1,23 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $viewport } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
const selectZoom = createSelector([selectNodesSlice, activeTabNameSelector], (nodes, activeTabName) =>
activeTabName === 'workflows' ? nodes.viewport.zoom : 1
);
/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const activeTabName = useAppSelector(activeTabNameSelector);
const workflowsViewport = useStore($viewport);
const zoom = useAppSelector(selectZoom);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const zoom = activeTabName === 'workflows' ? workflowsViewport.zoom : 1;
const activatorCoordinates = getEventCoordinates(activatorEvent);
if (!activatorCoordinates) {
@@ -40,7 +42,7 @@ export const useScaledModifer = () => {
return transform;
},
[activeTabName, workflowsViewport.zoom]
[zoom]
);
return modifier;

View File

@@ -75,8 +75,8 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CompositeNumberInput
value={lora.weight}
onChange={handleChange}
min={-10}
max={10}
min={-5}
max={5}
step={0.01}
w={20}
flexShrink={0}

View File

@@ -2,34 +2,26 @@ import 'reactflow/dist/style.css';
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
$cursorPos,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
addNodePopoverClosed,
addNodePopoverOpened,
nodeAdded,
openAddNodePopover,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { memo, useCallback, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize(
(inputValue: string) =>
@@ -62,30 +54,26 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null);
const templates = useStore($templates);
const pendingConnection = useStore($pendingConnection);
const isOpen = useStore($isAddNodePopoverOpen);
const store = useAppStore();
const filteredTemplates = useMemo(() => {
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
// If we have a connection in progress, we need to filter the node choices
if (!pendingConnection) {
return map(templates);
}
const filteredNodeTemplates = fieldFilter
? filter(nodes.templates, (template) => {
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
return some(handles, (handle) => {
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
const options = useMemo(() => {
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(nodes.templates);
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
return {
label: template.title,
value: template.type,
@@ -95,15 +83,15 @@ const AddNodePopover = () => {
});
//We only want these nodes if we're not filtered
if (!pendingConnection) {
_options.push({
if (fieldFilter === null) {
options.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
_options.push({
options.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
@@ -111,15 +99,18 @@ const AddNodePopover = () => {
});
}
_options.sort((a, b) => a.label.localeCompare(b.label));
options.sort((a, b) => a.label.localeCompare(b.label));
return _options;
}, [filteredTemplates, pendingConnection, t]);
return { options };
});
const { options } = useAppSelector(selector);
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen);
const addNode = useCallback(
(nodeType: string): AnyNode | null => {
const node = buildInvocation(nodeType);
if (!node) {
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const errorMessage = t('nodes.unknownNode', {
nodeType: nodeType,
});
@@ -127,11 +118,10 @@ const AddNodePopover = () => {
status: 'error',
title: errorMessage,
});
return null;
return;
}
const cursorPos = $cursorPos.get();
dispatch(nodeAdded({ node, cursorPos }));
return node;
dispatch(nodeAdded(invocation));
},
[dispatch, buildInvocation, toaster, t]
);
@@ -141,50 +131,52 @@ const AddNodePopover = () => {
if (!v) {
return;
}
const node = addNode(v.value);
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
if (connection) {
dispatch(connectionMade(connection));
}
}
closeAddNodePopover();
addNode(v.value);
dispatch(addNodePopoverClosed());
},
[addNode, dispatch, pendingConnection, store, templates]
[addNode, dispatch]
);
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}, []);
const onClose = useCallback(() => {
dispatch(addNodePopoverClosed());
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
const handleHotkeyOpen: HotkeyCallback = useCallback(
(e) => {
e.preventDefault();
onOpen();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
},
[onOpen]
);
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
closeAddNodePopover();
}, []);
onClose();
}, [onClose]);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
if (e.key === 'Escape') {
closeAddNodePopover();
}
}, []);
const onKeyDown: KeyboardEventHandler = useCallback(
(e) => {
if (e.key === 'Escape') {
onClose();
}
},
[onClose]
);
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
return (
<Popover
isOpen={isOpen}
onClose={closeAddNodePopover}
onClose={onClose}
placement="bottom"
openDelay={0}
closeDelay={0}
@@ -214,7 +206,7 @@ const AddNodePopover = () => {
noOptionsMessage={noOptionsMessage}
filterOption={filterOption}
onChange={onChange}
onMenuClose={closeAddNodePopover}
onMenuClose={onClose}
onKeyDown={onKeyDown}
inputRef={inputRef}
closeMenuOnSelect={false}

View File

@@ -1,33 +1,34 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
$cursorPos,
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$viewport,
connectionEnded,
connectionMade,
connectionStarted,
edgeAdded,
edgeChangeStarted,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
nodesDeleted,
redo,
selectedAll,
undo,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionPasted,
viewportChanged,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type {
OnConnect,
OnConnectEnd,
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc,
@@ -35,11 +36,12 @@ import type {
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
OnSelectionChangeFunc,
ProOptions,
ReactFlowProps,
ReactFlowState,
XYPosition,
} from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
import { Background, ReactFlow } from 'reactflow';
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
@@ -66,23 +68,17 @@ const proOptions: ProOptions = { hideAttribution: true };
const snapGrid: [number, number] = [25, 25];
const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection;
export const Flow = memo(() => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useStore($viewport);
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
const nodes = useAppSelector((s) => s.nodes.nodes);
const edges = useAppSelector((s) => s.nodes.edges);
const viewport = useAppSelector((s) => s.nodes.viewport);
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.nodes.selectionMode);
const flowWrapper = useRef<HTMLDivElement>(null);
const cursorPosition = useRef<XYPosition | null>(null);
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
useWorkflowWatcher();
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
const flowStyles = useMemo<CSSProperties>(
@@ -106,6 +102,32 @@ export const Flow = memo(() => {
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(() => {
if (!cursorPosition.current) {
return;
}
dispatch(
connectionEnded({
cursorPosition: cursorPosition.current,
mouseOverNodeId: $mouseOverNode.get(),
})
);
}, [dispatch]);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
@@ -120,9 +142,20 @@ export const Flow = memo(() => {
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
$viewport.set(viewport);
}, []);
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
({ nodes, edges }) => {
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
},
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback(
(e, viewport) => {
dispatch(viewportChanged(viewport));
},
[dispatch]
);
const { onCloseGlobal } = useGlobalMenuClose();
const handlePaneClick = useCallback(() => {
@@ -136,12 +169,11 @@ export const Flow = memo(() => {
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
if (flowWrapper.current?.getBoundingClientRect()) {
$cursorPos.set(
cursorPosition.current =
$flow.get()?.screenToFlowPosition({
x: event.clientX,
y: event.clientY,
}) ?? null
);
}) ?? null;
}
}, []);
@@ -163,18 +195,19 @@ export const Flow = memo(() => {
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(true);
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
dispatch(edgeChangeStarted());
},
[dispatch]
);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(_oldEdge, newConnection) => {
// Because we deleted the edge when the update started, we must create a new edge from the connection
// instead of updating the edge (we deleted it earlier), we instead create
// a new one.
dispatch(connectionMade(newConnection));
},
[dispatch]
@@ -182,10 +215,8 @@ export const Flow = memo(() => {
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(false);
$pendingConnection.set(null);
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
// the edge update - we need to add it back
// Handle the case where user begins a drag but didn't move the cursor -
// bc we deleted the edge, we need to add it back
if (
// ignore touch events
!('touches' in e) &&
@@ -202,64 +233,23 @@ export const Flow = memo(() => {
// #endregion
const { copySelection, pasteSelection } = useCopyPaste();
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
});
const onCopyHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
copySelection();
},
[copySelection]
);
useHotkeys(['Ctrl+c', 'Meta+c'], onCopyHotkey);
useHotkeys(['Ctrl+a', 'Meta+a'], (e) => {
e.preventDefault();
dispatch(selectedAll());
});
const onSelectAllHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
dispatch(selectedAll());
},
[dispatch]
);
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
const onPasteHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
pasteSelection();
},
[pasteSelection]
);
useHotkeys(['Ctrl+v', 'Meta+v'], onPasteHotkey);
const onPasteWithEdgesToNodesHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
pasteSelection(true);
},
[pasteSelection]
);
useHotkeys(['Ctrl+shift+v', 'Meta+shift+v'], onPasteWithEdgesToNodesHotkey);
const onUndoHotkey = useCallback(() => {
if (mayUndo) {
dispatch(undo());
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
if (!cursorPosition.current) {
return;
}
}, [dispatch, mayUndo]);
useHotkeys(['meta+z', 'ctrl+z'], onUndoHotkey);
const onRedoHotkey = useCallback(() => {
if (mayRedo) {
dispatch(redo());
}
}, [dispatch, mayRedo]);
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
const onEscapeHotkey = useCallback(() => {
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
}, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey);
e.preventDefault();
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
});
return (
<ReactFlow
@@ -284,6 +274,7 @@ export const Flow = memo(() => {
onConnectEnd={onConnectEnd}
onMoveEnd={handleMoveEnd}
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
isValidConnection={isValidConnection}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
@@ -294,7 +285,6 @@ export const Flow = memo(() => {
onPaneClick={handlePaneClick}
deleteKeyCode={DELETE_KEYS}
selectionMode={selectionMode}
elevateEdgesOnSelect
>
<Background />
</ReactFlow>

View File

@@ -1,33 +1,26 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { $pendingConnection } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import type { ConnectionLineComponentProps } from 'reactflow';
import { getBezierPath } from 'reactflow';
const selectStroke = createSelector(selectNodesSlice, (nodes) =>
nodes.shouldColorEdges ? getFieldColor(nodes.connectionStartFieldType) : colorTokenToCssVar('base.500')
);
const selectClassName = createSelector(selectNodesSlice, (nodes) =>
nodes.shouldAnimateEdges ? 'react-flow__custom_connection-path animated' : 'react-flow__custom_connection-path'
);
const pathStyles: CSSProperties = { opacity: 0.8 };
const CustomConnectionLine = ({ fromX, fromY, fromPosition, toX, toY, toPosition }: ConnectionLineComponentProps) => {
const pendingConnection = useStore($pendingConnection);
const shouldColorEdges = useAppSelector((state) => state.workflowSettings.shouldColorEdges);
const shouldAnimateEdges = useAppSelector((state) => state.workflowSettings.shouldAnimateEdges);
const stroke = useMemo(() => {
if (shouldColorEdges && pendingConnection) {
return getFieldColor(pendingConnection.fieldTemplate.type);
} else {
return colorTokenToCssVar('base.500');
}
}, [pendingConnection, shouldColorEdges]);
const className = useMemo(() => {
if (shouldAnimateEdges) {
return 'react-flow__custom_connection-path animated';
} else {
return 'react-flow__custom_connection-path';
}
}, [shouldAnimateEdges]);
const stroke = useAppSelector(selectStroke);
const className = useAppSelector(selectClassName);
const pathParams = {
sourceX: fromX,

View File

@@ -1,8 +1,6 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { $templates } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
@@ -24,10 +22,9 @@ const InvocationCollapsedEdge = ({
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, selected, source, sourceHandleId, target, targetHandleId]
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
[selected, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate } = useAppSelector(selector);

View File

@@ -1,7 +1,5 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
@@ -23,14 +21,13 @@ const InvocationDefaultEdge = ({
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, source, sourceHandleId, target, targetHandleId, selected]
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,

View File

@@ -1,8 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@@ -15,7 +14,6 @@ const defaultReturnValue = {
};
export const makeEdgeSelector = (
templates: Templates,
source: string,
sourceHandleId: string | null | undefined,
target: string,
@@ -24,8 +22,7 @@ export const makeEdgeSelector = (
) =>
createMemoizedSelector(
selectNodesSlice,
selectWorkflowSettingsSlice,
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
@@ -36,20 +33,19 @@ export const makeEdgeSelector = (
return defaultReturnValue;
}
const sourceNodeTemplate = templates[sourceNode.data.type];
const targetNodeTemplate = templates[targetNode.data.type];
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke =
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
label,
};

View File

@@ -1,10 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
@@ -22,7 +24,12 @@ const circleStyles: SystemStyleObject = {
};
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
const nodeExecutionState = useExecutionState(nodeId);
const selectNodeExecutionState = useMemo(
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
[nodeId]
);
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
if (!nodeExecutionState) {
return null;

View File

@@ -1,7 +1,7 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import type { NodeProps } from 'reactflow';
@@ -11,13 +11,13 @@ import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data, selected } = props;
const { id: nodeId, type, isOpen, label } = data;
const templates = useStore($templates);
const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]);
const nodeExists = useAppSelector((s) => Boolean(s.nodes.present.nodes.find((n) => n.id === nodeId)));
if (!nodeExists) {
return null;
}
const hasTemplateSelector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
[type]
);
const hasTemplate = useAppSelector(hasTemplateSelector);
if (!hasTemplate) {
return (

View File

@@ -37,8 +37,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
const [localTitle, setLocalTitle] = useState(label || fieldTemplateTitle || t('nodes.unknownField'));
const handleSubmit = useCallback(
async (newTitleRaw: string) => {
const newTitle = newTitleRaw.trim();
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
@@ -58,22 +57,22 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
}, [label, fieldTemplateTitle, t]);
return (
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
ref={ref}
position="relative"
overflow="hidden"
alignItems="center"
justifyContent="flex-start"
gap={1}
w="full"
<Tooltip
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Tooltip
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
ref={ref}
position="relative"
overflow="hidden"
alignItems="center"
justifyContent="flex-start"
gap={1}
w="full"
>
<EditablePreview
fontWeight="semibold"
@@ -81,10 +80,10 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
noOfLines={1}
color={isMissingInput ? 'error.300' : 'base.300'}
/>
</Tooltip>
<EditableInput className="nodrag" sx={editableInputStyles} />
<EditableControls />
</Editable>
<EditableInput className="nodrag" sx={editableInputStyles} />
<EditableControls />
</Editable>
</Tooltip>
);
});
@@ -128,15 +127,7 @@ const EditableControls = memo(() => {
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="min-content"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
<Flex onClick={handleClick} position="absolute" w="full" h="full" top={0} insetInlineStart={0} cursor="text" />
);
});

View File

@@ -69,7 +69,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
);
}
if (fieldTemplate.input === 'connection' || isConnected) {
if (fieldTemplate.input === 'connection') {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} px={2}>
@@ -95,15 +95,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
isInvalid={isMissingInput}
isDisabled={isConnected}
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
pointerEvents={isConnected ? 'none' : 'auto'}
orientation="vertical"
px={2}
>
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} orientation="vertical" px={2}>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex>
<EditableFieldTitle

View File

@@ -1,14 +1,14 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { MouseEvent, PropsWithChildren } from 'react';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@@ -20,8 +20,16 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
const executionState = useExecutionState(nodeId);
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
const selectIsInProgress = useMemo(
() =>
createSelector(
selectNodesSlice,
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
),
[nodeId]
);
const isInProgress = useAppSelector(selectIsInProgress);
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
'nodeInProgress',
@@ -31,7 +39,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const dispatch = useAppDispatch();
const opacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
const opacity = useAppSelector((s) => s.nodes.nodeOpacity);
const { onCloseGlobal } = useGlobalMenuClose();
const handleClick = useCallback(

View File

@@ -1,12 +1,12 @@
import { CompositeSlider, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeOpacityChanged } from 'features/nodes/store/workflowSettingsSlice';
import { nodeOpacityChanged } from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const NodeOpacitySlider = () => {
const dispatch = useAppDispatch();
const nodeOpacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity);
const { t } = useTranslation();
const handleChange = useCallback(

View File

@@ -1,6 +1,9 @@
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldShowMinimapPanelChanged } from 'features/nodes/store/workflowSettingsSlice';
import {
// shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
@@ -16,9 +19,9 @@ const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.present.shouldShowFieldTypeLegend
// (s) => s.nodes.shouldShowFieldTypeLegend
// );
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => {
zoomIn();

View File

@@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
};
const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
return (
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>

View File

@@ -1,18 +1,23 @@
import { IconButton } from '@invoke-ai/ui-library';
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { useAppDispatch } from 'app/store/storeHooks';
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
const AddNodeButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleOpenAddNodePopover = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
return (
<IconButton
tooltip={t('nodes.addNodeToolTip')}
aria-label={t('nodes.addNode')}
icon={<PiPlusBold />}
onClick={openAddNodePopover}
onClick={handleOpenAddNodePopover}
pointerEvents="auto"
/>
);

View File

@@ -21,13 +21,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ReloadNodeTemplatesButton from 'features/nodes/components/flow/panels/TopRightPanel/ReloadSchemaButton';
import {
selectionModeChanged,
selectWorkflowSettingsSlice,
selectNodesSlice,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from 'features/nodes/store/workflowSettingsSlice';
} from 'features/nodes/store/nodesSlice';
import type { ChangeEvent, ReactNode } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -35,7 +35,7 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 };
const selector = createMemoizedSelector(selectWorkflowSettingsSlice, (workflowSettings) => {
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
@@ -43,7 +43,7 @@ const selector = createMemoizedSelector(selectWorkflowSettingsSlice, (workflowSe
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = workflowSettings;
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,

View File

@@ -3,21 +3,27 @@ import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => selectLastSelectedNode(nodes));
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
return {
data: lastSelectedNode?.data,
};
});
const InspectorDataTab = () => {
const { t } = useTranslation();
const lastSelectedNode = useAppSelector(selector);
const { data } = useAppSelector(selector);
if (!lastSelectedNode) {
if (!data) {
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
}
return <DataViewer data={lastSelectedNode.data} label="Node Data" />;
return <DataViewer data={data} label="Node Data" />;
};
export default memo(InspectorDataTab);

View File

@@ -1,39 +1,36 @@
import { Box, Flex, FormControl, FormLabel, HStack, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import EditableNodeTitle from './details/EditableNodeTitle';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
});
const InspectorDetailsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
}),
[templates]
);
const data = useAppSelector(selector);
const { t } = useTranslation();

View File

@@ -1,49 +1,46 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
return;
}
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
});
const InspectorOutputsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.id,
outputType: lastSelectedNodeTemplate.outputType,
};
}),
[templates]
);
const data = useAppSelector(selector);
const nes = useExecutionState(data?.nodeId);
const { t } = useTranslation();
if (!data || !nes) {
if (!data) {
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
}
if (nes.outputs.length === 0) {
if (data.outputs.length === 0) {
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
}
@@ -52,11 +49,11 @@ const InspectorOutputsTab = () => {
<ScrollableContent>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{data.outputType === 'image_output' ? (
nes.outputs.map((result, i) => (
data.outputs.map((result, i) => (
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
))
) : (
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
)}
</Flex>
</ScrollableContent>

View File

@@ -1,26 +1,25 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { memo, useMemo } from 'react';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const NodeTemplateInspector = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
return lastSelectedNodeTemplate;
}),
[templates]
);
const template = useAppSelector(selector);
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
return {
template: lastSelectedNodeTemplate,
};
});
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
const { t } = useTranslation();
if (!template) {

View File

@@ -1,39 +1,31 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
const fields = map(template.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return false;
}
return (
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
const fieldNames = useAppSelector(selector);
return fieldNames;
};

View File

@@ -1,5 +1,4 @@
import { useStore } from '@nanostores/react';
import { $templates } from 'features/nodes/store/nodesSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
@@ -9,7 +8,8 @@ import { useCallback } from 'react';
import { useReactFlow } from 'reactflow';
export const useBuildNode = () => {
const templates = useStore($templates);
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
const flow = useReactFlow();
return useCallback(
@@ -41,10 +41,10 @@ export const useBuildNode = () => {
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = templates[type] as InvocationTemplate;
const template = nodeTemplates[type] as InvocationTemplate;
return buildInvocationNode(position, template);
},
[templates, flow]
[nodeTemplates, flow]
);
};

View File

@@ -1,93 +0,0 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import {
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$templates,
connectionMade,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const onConnectStart = useCallback<OnConnectStart>(
(event, params) => {
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
},
[store, templates]
);
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
dispatch(connectionMade(connection));
$pendingConnection.set(null);
},
[store]
);
const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store;
const pendingConnection = $pendingConnection.get();
const isUpdatingEdge = $isUpdatingEdge.get();
const mouseOverNodeId = $mouseOverNode.get();
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
// update logic can finish up
if (isUpdatingEdge && !mouseOverNodeId) {
$pendingConnection.set(null);
return;
}
if (!pendingConnection) {
return;
}
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection(
templates,
nodes,
edges,
pendingConnection,
candidateNode,
candidateTemplate
);
if (connection) {
dispatch(connectionMade(connection));
}
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
}
}, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
return api;
};

View File

@@ -1,39 +1,34 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
// get the visible fields
const fields = map(template.inputs).filter(
(field) =>
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
// get the visible fields
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return true;
}
return (
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
const fieldNames = useAppSelector(selector);
return fieldNames;
};

View File

@@ -1,12 +1,16 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { useMemo } from 'react';
import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
selectNodesSlice,
(nodes) => nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null
);
type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
@@ -14,8 +18,6 @@ type UseConnectionStateProps = {
};
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
@@ -34,30 +36,25 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
);
const selectConnectionError = useMemo(
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
[nodeId, fieldName, kind, fieldType]
);
const selectIsConnectionStartField = useMemo(
() =>
makeConnectionErrorSelector(
templates,
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
createSelector(selectNodesSlice, (nodes) =>
Boolean(
nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
)
),
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
[fieldName, kind, nodeId]
);
const isConnected = useAppSelector(selectIsConnected);
const isConnectionInProgress = useMemo(() => Boolean(pendingConnection), [pendingConnection]);
const isConnectionStartField = useMemo(() => {
if (!pendingConnection) {
return false;
}
return (
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
const connectionError = useAppSelector(selectConnectionError);
const shouldDim = useMemo(

View File

@@ -1,78 +0,0 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import {
$copiedEdges,
$copiedNodes,
$cursorPos,
$edgesToCopiedNodes,
selectionPasted,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { isEqual, uniqWith } from 'lodash-es';
import { v4 as uuidv4 } from 'uuid';
const copySelection = () => {
// Use the imperative API here so we don't have to pass the whole slice around
const { getState } = getStore();
const { nodes, edges } = selectNodesSlice(getState());
const selectedNodes = nodes.filter((node) => node.selected);
const selectedEdges = edges.filter((edge) => edge.selected);
const edgesToSelectedNodes = edges.filter((edge) => selectedNodes.some((node) => node.id === edge.target));
$copiedNodes.set(selectedNodes);
$copiedEdges.set(selectedEdges);
$edgesToCopiedNodes.set(edgesToSelectedNodes);
};
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const { getState, dispatch } = getStore();
const currentNodes = selectNodesSlice(getState()).nodes;
const cursorPos = $cursorPos.get();
const copiedNodes = deepClone($copiedNodes.get());
let copiedEdges = deepClone($copiedEdges.get());
if (withEdgesToCopiedNodes) {
const edgesToCopiedNodes = deepClone($edgesToCopiedNodes.get());
copiedEdges = uniqWith([...copiedEdges, ...edgesToCopiedNodes], isEqual);
}
// Calculate an offset to reposition nodes to surround the cursor position, maintaining relative positioning
const xCoords = copiedNodes.map((node) => node.position.x);
const yCoords = copiedNodes.map((node) => node.position.y);
const minX = Math.min(...xCoords);
const minY = Math.min(...yCoords);
const offsetX = cursorPos ? cursorPos.x - minX : 50;
const offsetY = cursorPos ? cursorPos.y - minY : 50;
copiedNodes.forEach((node) => {
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
node.position.x = x;
node.position.y = y;
// Pasted nodes are selected
node.selected = true;
// Also give em a fresh id
const id = uuidv4();
// Update the edges to point to the new node id
for (const edge of copiedEdges) {
if (edge.source === node.id) {
edge.source = id;
edge.id = edge.id.replace(node.data.id, id);
}
if (edge.target === node.id) {
edge.target = id;
edge.id = edge.id.replace(node.data.id, id);
}
}
node.id = id;
node.data.id = id;
});
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
};
const api = { copySelection, pasteSelection };
export const useCopyPaste = () => {
return api;
};

View File

@@ -1,56 +0,0 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodeExecutionStates } from 'features/nodes/store/types';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { map } from 'nanostores';
import { useEffect, useMemo } from 'react';
export const $nodeExecutionStates = map<NodeExecutionStates>({});
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
export const useExecutionState = (nodeId?: string) => {
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
return executionState;
};
const removeNodeExecutionState = (nodeId: string) => {
$nodeExecutionStates.setKey(nodeId, undefined);
};
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
const state = $nodeExecutionStates.get()[nodeId];
if (!state) {
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
} else {
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
}
};
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
export const useSyncExecutionState = () => {
const nodeIds = useAppSelector(selectNodeIds);
useEffect(() => {
const nodeExecutionStates = $nodeExecutionStates.get();
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
for (const id of nodeIdsToAdd) {
upsertExecutionState(id);
}
for (const id of nodeIdsToRemove) {
removeNodeExecutionState(id);
}
}, [nodeIds]);
};

View File

@@ -1,9 +1,20 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};

View File

@@ -1,9 +1,20 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import type { FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.outputs[fieldName] ?? null, [fieldName, template.outputs]);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};

View File

@@ -1,36 +1,27 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'inputs' | 'outputs'
): FieldInputTemplate | FieldOutputTemplate => {
const templates = useStore($templates);
const selectNodeType = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
[nodeId]
): FieldInputTemplate | FieldOutputTemplate | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, kind, nodeId]
);
const nodeType = useAppSelector(selectNodeType);
const fieldTemplate = useMemo(() => {
const template = templates[nodeType];
assert(template, `Template for node type ${nodeType} not found`);
if (kind === 'inputs') {
const fieldTemplate = template.inputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
} else {
const fieldTemplate = template.outputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
}
}, [fieldName, kind, nodeType, templates]);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate;
};

View File

@@ -1,8 +1,22 @@
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldTemplateTitle = useMemo(() => fieldTemplate.title, [fieldTemplate]);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldTemplateTitle = useAppSelector(selector);
return fieldTemplateTitle;
};

View File

@@ -1,9 +1,23 @@
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType;
};

View File

@@ -1,26 +1,20 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import { useMemo } from 'react';
const selector = createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = nodes.templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
})
);
export const useGetNodesNeedUpdate = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node.data, template);
})
),
[templates]
);
const needsUpdate = useAppSelector(selector);
return needsUpdate;
const getNeedsUpdate = useAppSelector(selector);
return getNeedsUpdate;
};

View File

@@ -1,20 +1,26 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { some } from 'lodash-es';
import { useMemo } from 'react';
export const useHasImageOutput = (nodeId: string): boolean => {
const template = useNodeTemplate(nodeId);
const hasImageOutput = useMemo(
const selector = useMemo(
() =>
some(
template?.outputs,
(output) =>
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
template?.type !== 'image'
),
[template]
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
return some(
template?.outputs,
(output) =>
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
template?.type !== 'image'
);
}),
[nodeId]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};

View File

@@ -1,12 +1,8 @@
// TODO: enable this at some point
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
@@ -17,8 +13,7 @@ import type { Connection, Node } from 'reactflow';
export const useIsValidConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const shouldValidateGraph = useAppSelector((s) => s.workflowSettings.shouldValidateGraph);
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
// Connection must have valid targets
@@ -32,7 +27,7 @@ export const useIsValidConnection = () => {
}
const state = store.getState();
const { nodes, edges } = state.nodes.present;
const { nodes, edges, templates } = state.nodes;
// Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
@@ -45,10 +40,6 @@ export const useIsValidConnection = () => {
return false;
}
if (targetFieldTemplate.input === 'direct') {
return false;
}
if (!shouldValidateGraph) {
// manual override!
return true;
@@ -66,14 +57,6 @@ export const useIsValidConnection = () => {
return false;
}
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
}
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
@@ -93,7 +76,7 @@ export const useIsValidConnection = () => {
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},
[shouldValidateGraph, templates, store]
[shouldValidateGraph, store]
);
return isValidConnection;

View File

@@ -1,9 +1,19 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import type { Classification } from 'features/nodes/types/common';
import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string): Classification => {
const template = useNodeTemplate(nodeId);
const classification = useMemo(() => template.classification, [template]);
return classification;
export const useNodeClassification = (nodeId: string): Classification | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@@ -5,7 +5,7 @@ import { selectNodeData } from 'features/nodes/store/selectors';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useNodeData = (nodeId: string): InvocationNodeData => {
export const useNodeData = (nodeId: string): InvocationNodeData | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {

View File

@@ -1,11 +1,25 @@
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import { useMemo } from 'react';
export const useNodeNeedsUpdate = (nodeId: string) => {
const data = useNodeData(nodeId);
const template = useNodeTemplate(nodeId);
const needsUpdate = useMemo(() => getNeedsUpdate(data, template), [data, template]);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const node = selectInvocationNode(nodes, nodeId);
const template = selectNodeTemplate(nodes, nodeId);
if (!node || !template) {
return false;
}
return getNeedsUpdate(node, template);
}),
[nodeId]
);
const needsUpdate = useAppSelector(selector);
return needsUpdate;
};

View File

@@ -1,23 +1,20 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
const templates = useStore($templates);
const selectNodeType = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId);
}),
[nodeId]
);
const nodeType = useAppSelector(selectNodeType);
const template = useMemo(() => {
const t = templates[nodeType];
assert(t, `Template for node type ${nodeType} not found`);
return t;
}, [nodeType, templates]);
return template;
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};

View File

@@ -1,8 +1,18 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
export const useNodeTemplateTitle = (nodeId: string): string | null => {
const template = useNodeTemplate(nodeId);
const title = useMemo(() => template.title, [template.title]);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@@ -1,10 +1,26 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { map } from 'lodash-es';
import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const fieldNames = useMemo(() => getSortedFilteredFieldNames(map(template.outputs)), [template.outputs]);
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
return getSortedFilteredFieldNames(map(template.outputs));
}),
[nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames;
};

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@@ -42,21 +43,76 @@ import {
zT2IAdapterModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import type {
Connection,
Edge,
EdgeChange,
EdgeRemoveChange,
Node,
NodeChange,
OnConnectStartParams,
Viewport,
XYPosition,
} from 'reactflow';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
getConnectedEdges,
getIncomers,
getOutgoers,
SelectionMode,
} from 'reactflow';
import {
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import type { z } from 'zod';
import type { NodesState, PendingConnection, Templates } from './types';
import type { NodesState } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
const initialNodesState: NodesState = {
_version: 1,
nodes: [],
edges: [],
templates: {},
connectionStartParams: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
selectedEdges: [],
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
nodesToCopy: [],
edgesToCopy: [],
selectionMode: SelectionMode.Partial,
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
@@ -98,12 +154,12 @@ export const nodesSlice = createSlice({
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
const { node, cursorPos } = action.payload;
nodeAdded: (state, action: PayloadAction<AnyNode>) => {
const node = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
cursorPos?.x ?? node.position.x,
cursorPos?.y ?? node.position.y
state.addNewNodePosition?.x ?? node.position.x,
state.addNewNodePosition?.y ?? node.position.y
);
node.position = position;
node.selected = true;
@@ -119,6 +175,40 @@ export const nodesSlice = createSlice({
);
state.nodes.push(node);
if (!isInvocationNode(node)) {
return;
}
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
if (state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
node,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
@@ -126,8 +216,71 @@ export const nodesSlice = createSlice({
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
state.connectionMade = state.modifyingEdge;
const { nodeId, handleId, handleType } = action.payload;
if (!nodeId || !handleId) {
return;
}
const node = state.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const template = state.templates[node.data.type];
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
state.connectionStartFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
state.connectionMade = true;
},
connectionEnded: (
state,
action: PayloadAction<{
cursorPosition: XYPosition;
mouseOverNodeId: string | null;
}>
) => {
const { cursorPosition, mouseOverNodeId } = action.payload;
if (!state.connectionMade) {
if (mouseOverNodeId) {
const nodeIndex = state.nodes.findIndex((n) => n.id === mouseOverNodeId);
const mouseOverNode = state.nodes?.[nodeIndex];
if (mouseOverNode && state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
fieldLabelChanged: (
state,
@@ -289,6 +442,7 @@ export const nodesSlice = createSlice({
if (!isInvocationNode(node)) {
return;
}
delete state.nodeExecutionStates[node.id];
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
@@ -320,6 +474,12 @@ export const nodesSlice = createSlice({
state.nodes
);
},
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedNodes = action.payload;
},
selectedEdgesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedEdges = action.payload;
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
@@ -377,10 +537,34 @@ export const nodesSlice = createSlice({
}
node.data.notes = value;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
},
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
shouldColorEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldColorEdges = action.payload;
},
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
state.nodeOpacity = action.payload;
},
viewportChanged: (state, action: PayloadAction<Viewport>) => {
state.viewport = action.payload;
},
selectedAll: (state) => {
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
@@ -391,49 +575,136 @@ export const nodesSlice = createSlice({
state.edges
);
},
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
const { nodes, edges } = action.payload;
selectionCopied: (state) => {
const nodesToCopy: AnyNode[] = [];
const edgesToCopy: Edge[] = [];
const nodeChanges: NodeChange[] = [];
for (const node of state.nodes) {
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}
// Deselect existing nodes
state.nodes.forEach((n) => {
nodeChanges.push({
id: n.data.id,
type: 'select',
selected: false,
});
});
// Add new nodes
nodes.forEach((n) => {
nodeChanges.push({
item: n,
type: 'add',
});
});
for (const edge of state.edges) {
if (edge.selected) {
edgesToCopy.push(deepClone(edge));
}
}
const edgeChanges: EdgeChange[] = [];
// Deselect existing edges
state.edges.forEach((e) => {
edgeChanges.push({
id: e.id,
type: 'select',
selected: false,
});
});
// Add new edges
edges.forEach((e) => {
edgeChanges.push({
item: e,
type: 'add',
});
});
state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
state.nodesToCopy.forEach((e) => {
const xOffset = 0.15 * (e.width ?? 0);
const yOffset = 0.5 * (e.height ?? 0);
averagePosition.x += e.position.x + xOffset;
averagePosition.y += e.position.y + yOffset;
});
averagePosition.x /= state.nodesToCopy.length;
averagePosition.y /= state.nodesToCopy.length;
state.nodesToCopy.forEach((e) => {
e.position.x -= averagePosition.x;
e.position.y -= averagePosition.y;
});
}
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes: AnyNode[] = [];
for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}
const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges: Edge[] = [];
for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}
newEdges.forEach((e) => (e.selected = true));
newNodes.forEach((node) => {
const newNodeId = uuidv4();
newEdges.forEach((edge) => {
if (edge.source === node.data.id) {
edge.source = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
if (edge.target === node.data.id) {
edge.target = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
});
node.selected = true;
node.id = newNodeId;
node.data.id = newNodeId;
const position = findUnoccupiedPosition(
state.nodes,
node.position.x + (cursorPosition?.x ?? 0),
node.position.y + (cursorPosition?.y ?? 0)
);
node.position = position;
});
const nodeAdditions: NodeChange[] = newNodes.map((n) => ({
item: n,
type: 'add',
}));
const nodeSelectionChanges: NodeChange[] = state.nodes.map((n) => ({
id: n.data.id,
type: 'select',
selected: false,
}));
const edgeAdditions: EdgeChange[] = newEdges.map((e) => ({
item: e,
type: 'add',
}));
const edgeSelectionChanges: EdgeChange[] = state.edges.map((e) => ({
id: e.id,
type: 'select',
selected: false,
}));
state.nodes = applyNodeChanges(nodeAdditions.concat(nodeSelectionChanges), state.nodes);
state.edges = applyEdgeChanges(edgeAdditions.concat(edgeSelectionChanges), state.edges);
newNodes.forEach((node) => {
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
});
},
addNodePopoverOpened: (state) => {
state.addNewNodePosition = null; //Create the node in viewport center by default
state.isAddNodePopoverOpen = true;
},
addNodePopoverClosed: (state) => {
state.isAddNodePopoverOpen = false;
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
},
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
state.templates = action.payload;
},
undo: (state) => state,
redo: (state) => state,
},
extraReducers: (builder) => {
builder.addCase(workflowLoaded, (state, action) => {
@@ -449,13 +720,75 @@ export const nodesSlice = createSlice({
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
acc[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
return acc;
}, {});
});
builder.addCase(socketInvocationStarted, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(socketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
}
});
builder.addCase(socketInvocationError, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.FAILED;
node.error = action.payload.data.error;
node.progress = null;
node.progressImage = null;
}
});
builder.addCase(socketGeneratorProgress, (state, action) => {
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
node.progress = (step + 1) / total_steps;
node.progressImage = progress_image ?? null;
}
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = zNodeStatus.enum.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
nes.outputs = [];
});
}
});
},
});
export const {
addNodePopoverClosed,
addNodePopoverOpened,
connectionEnded,
connectionMade,
connectionStarted,
edgeDeleted,
edgeChangeStarted,
edgesChanged,
edgesDeleted,
fieldValueReset,
@@ -483,97 +816,31 @@ export const {
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodeOpacityChanged,
nodesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionModeChanged,
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
viewportChanged,
edgeAdded,
undo,
redo,
nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions;
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({});
export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isUpdatingEdge = atom(false);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
$isAddNodePopoverOpen.set(false);
$pendingConnection.set(null);
};
export const openAddNodePopover = () => {
$isAddNodePopoverOpen.set(true);
};
export const selectNodesSlice = (state: RootState) => state.nodes.present;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialNodesState,
migrate: migrateNodesState,
persistDenylist: [],
};
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
const isSelectionAction = (action: UnknownAction) => {
if (selectionMatcher(action)) {
return true;
}
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
}
return false;
};
const individualGroupByMatcher = isAnyOf(nodesChanged);
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
limit: 64,
undoType: nodesSlice.actions.undo.type,
redoType: nodesSlice.actions.redo.type,
groupBy: (action, state, history) => {
if (isSelectionAction(action)) {
// Changes to selection should never be recorded on their own
return history.group;
}
if (individualGroupByMatcher(action)) {
return action.type;
}
return null;
},
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(nodesSlice.name)) {
return false;
}
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'dimensions')) {
return false;
}
}
return true;
},
};
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionEnded,
connectionMade,
edgeDeleted,
edgesChanged,
@@ -606,3 +873,30 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
selectionPasted,
edgeAdded
);
export const selectNodesSlice = (state: RootState) => state.nodes;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialNodesState,
migrate: migrateNodesState,
persistDenylist: [
'connectionStartParams',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'nodesToCopy',
'edgesToCopy',
'connectionMade',
'modifyingEdge',
'addNewNodePosition',
],
};

View File

@@ -1,23 +1,26 @@
import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { assert } from 'tsafe';
const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
if (!isInvocationNode(node)) {
return null;
}
return node;
};
export const selectInvocationNodeType = (nodesSlice: NodesState, nodeId: string): string => {
const node = selectInvocationNode(nodesSlice, nodeId);
return node.data.type;
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
};
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData => {
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
const node = selectInvocationNode(nodesSlice, nodeId);
return node.data;
if (!node) {
return null;
}
return nodesSlice.templates[node.data.type] ?? null;
};
export const selectFieldInputInstance = (
@@ -29,10 +32,20 @@ export const selectFieldInputInstance = (
return data?.inputs[fieldName] ?? null;
};
export const selectLastSelectedNode = (nodesSlice: NodesState) => {
const selectedNodes = nodesSlice.nodes.filter((n) => n.selected);
if (selectedNodes.length === 1) {
return selectedNodes[0];
}
return null;
export const selectFieldInputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldInputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.inputs[fieldName] ?? null;
};
export const selectFieldOutputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldOutputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.outputs[fieldName] ?? null;
};

View File

@@ -1,31 +1,38 @@
import type {
FieldIdentifier,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field';
import type {
AnyNode,
InvocationNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
node: InvocationNode;
template: InvocationTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
templates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectedNodes: string[];
selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
nodesToCopy: AnyNode[];
edgesToCopy: InvocationNodeEdge[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
};
export type WorkflowMode = 'edit' | 'view';

View File

@@ -1,105 +1,112 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, isEqual, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Connection, Edge, HandleType, Node } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getFirstValidConnection = (
templates: Templates,
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
node: Node,
handle: FieldInputTemplate | FieldOutputTemplate
) => {
let isValidConnection = true;
if (handleCurrentType === 'source') {
if (
edges.find((edge) => {
return edge.target === node.id && edge.targetHandle === handle.name;
})
) {
isValidConnection = false;
}
} else {
if (
edges.find((edge) => {
return edge.source === node.id && edge.sourceHandle === handle.name;
})
) {
isValidConnection = false;
}
}
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
isValidConnection = false;
}
return isValidConnection;
};
export const findConnectionToValidHandle = (
node: AnyNode,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
templates: Record<string, InvocationTemplate>,
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
const template = templates[node.data.type];
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (!template) {
return null;
}
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
//Prioritize handles whos name matches the node we're coming from
const handle = handles[handleCurrentName];
if (handle) {
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
if (isGraphAcyclic && valid) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
for (const handleName in handles) {
const handle = handles[handleName];
if (!handle) {
continue;
}
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
return null;
};

Some files were not shown because too many files have changed in this diff Show More