mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 00:17:56 -05:00
Compare commits
78 Commits
v4.2.1
...
image-capt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59327e827b | ||
|
|
a18d7adad4 | ||
|
|
32dff2c4e3 | ||
|
|
575ecb4028 | ||
|
|
ad8778df6c | ||
|
|
d2f5103f9f | ||
|
|
dd42a56084 | ||
|
|
23ac340a3f | ||
|
|
6791b4eaa8 | ||
|
|
a8b042177d | ||
|
|
76825f4261 | ||
|
|
78cb4d75ad | ||
|
|
a18bbac262 | ||
|
|
9ff5596963 | ||
|
|
8ea596b1e9 | ||
|
|
e3a143eaed | ||
|
|
c359ab6d9b | ||
|
|
dbfaa07e03 | ||
|
|
7f78fe7a36 | ||
|
|
6cf5b402c6 | ||
|
|
b0c7c7cb47 | ||
|
|
4d68cd8dbb | ||
|
|
2c1fa30639 | ||
|
|
708c68413d | ||
|
|
1d884fb794 | ||
|
|
f6a44681a8 | ||
|
|
d4df312300 | ||
|
|
9c0d44b412 | ||
|
|
27826369f0 | ||
|
|
31d8b50276 | ||
|
|
40b4fa7238 | ||
|
|
3b1743b7c2 | ||
|
|
f489c818f1 | ||
|
|
af477fa295 | ||
|
|
0ff0290735 | ||
|
|
67dbe6d949 | ||
|
|
4c3c2297b9 | ||
|
|
cadea55521 | ||
|
|
c8f30b1392 | ||
|
|
3d14a98abf | ||
|
|
77024bfca7 | ||
|
|
4a1c3786a1 | ||
|
|
b239891986 | ||
|
|
9fb03d43ff | ||
|
|
bdc59786bd | ||
|
|
fb6e926500 | ||
|
|
48ccd63dba | ||
|
|
ee647a05dc | ||
|
|
154b52ca4d | ||
|
|
5dd460c3ce | ||
|
|
4897ce2a13 | ||
|
|
5425526d50 | ||
|
|
5a4b050e66 | ||
|
|
8d39520232 | ||
|
|
04d12a1e98 | ||
|
|
39aa70963b | ||
|
|
5743254a41 | ||
|
|
c538ffea26 | ||
|
|
e8d3a7c870 | ||
|
|
2be66b1546 | ||
|
|
76e181fd44 | ||
|
|
b5d42fbc66 | ||
|
|
b463cd763e | ||
|
|
eb320df41d | ||
|
|
de1869773f | ||
|
|
ef89c7e537 | ||
|
|
008645d386 | ||
|
|
f8042ffb41 | ||
|
|
dbe22be598 | ||
|
|
8f6078d007 | ||
|
|
4020bf47e2 | ||
|
|
9d685da759 | ||
|
|
e3289856c0 | ||
|
|
47b8153728 | ||
|
|
7901e4c082 | ||
|
|
18b0977a31 | ||
|
|
fc6b214470 | ||
|
|
e22211dac0 |
@@ -98,7 +98,7 @@ Updating is exactly the same as installing - download the latest installer, choo
|
||||
|
||||
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
|
||||
|
||||
[installation requirements]: INSTALLATION.md#installation-requirements
|
||||
[installation requirements]: INSTALL_REQUIREMENTS.md
|
||||
[FAQ]: ../help/FAQ.md
|
||||
[install some models]: 050_INSTALLING_MODELS.md
|
||||
[configuration docs]: ../features/CONFIGURATION.md
|
||||
|
||||
@@ -164,6 +164,12 @@ def custom_openapi() -> dict[str, Any]:
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
@@ -172,6 +178,8 @@ def custom_openapi() -> dict[str, Any]:
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, List, Union
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
@@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"auto_caption_image",
|
||||
title="Automatically Caption Image",
|
||||
tags=["image", "caption"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
)
|
||||
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Adds a caption to an image"""
|
||||
|
||||
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
|
||||
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
|
||||
|
||||
model_id = "vikhyatk/moondream2"
|
||||
model_revision = "2024-04-02"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
|
||||
moondream_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, trust_remote_code=True, revision=model_revision
|
||||
)
|
||||
output: CaptionImageOutputs = CaptionImageOutputs()
|
||||
try:
|
||||
from PIL.Image import Image
|
||||
images: List[Image] = []
|
||||
image_fields = self.images if isinstance(self.images, list) else [self.images]
|
||||
for image in image_fields:
|
||||
images.append(context.images.get_pil(image.image_name))
|
||||
answers: List[str] = moondream_model.batch_answer(
|
||||
images=images,
|
||||
prompts=[self.prompt] * len(images),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
assert isinstance(answers, list)
|
||||
for i, answer in enumerate(answers):
|
||||
output.images.append(CaptionImageOutput(
|
||||
image=image_fields[i],
|
||||
width=images[i].width,
|
||||
height=images[i].height,
|
||||
caption=answer
|
||||
))
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
del moondream_model
|
||||
del tokenizer
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
@@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
image: List[ImageField] = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
@@ -190,6 +190,75 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation_output("lora_selector_output")
|
||||
class LoRASelectorOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
|
||||
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
|
||||
|
||||
|
||||
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
|
||||
class LoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
loras: LoRAField | list[LoRAField] = InputField(
|
||||
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
output = LoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
for lora in loras:
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.unet is not None:
|
||||
if output.unet is None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(lora)
|
||||
|
||||
if self.clip is not None:
|
||||
if output.clip is None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
@@ -279,6 +348,72 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_collection_loader",
|
||||
title="SDXL LoRA Collection Loader",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
loras: LoRAField | list[LoRAField] = InputField(
|
||||
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
output = SDXLLoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
for lora in loras:
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
assert lora.lora.base is BaseModelType.StableDiffusionXL
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.unet is not None:
|
||||
if output.unet is None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(lora)
|
||||
|
||||
if self.clip is not None:
|
||||
if output.clip is None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
if self.clip2 is not None:
|
||||
if output.clip2 is None:
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("captioned_image_output")
|
||||
class CaptionImageOutput(ImageOutput):
|
||||
caption: str = OutputField(description="Caption for given image")
|
||||
|
||||
|
||||
|
||||
@invocation_output("captioned_image_outputs")
|
||||
class CaptionImageOutputs(BaseInvocationOutput):
|
||||
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
|
||||
|
||||
|
||||
@invocation_output("image_collection_output")
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
@@ -10,6 +10,8 @@ module.exports = {
|
||||
'path/no-relative-imports': ['error', { maxDepth: 0 }],
|
||||
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
|
||||
'i18next/no-literal-string': 'error',
|
||||
// https://eslint.org/docs/latest/rules/no-console
|
||||
'no-console': 'error',
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
|
||||
3
invokeai/frontend/web/.gitignore
vendored
3
invokeai/frontend/web/.gitignore
vendored
@@ -43,4 +43,5 @@ stats.html
|
||||
yalc.lock
|
||||
|
||||
# vitest
|
||||
tsconfig.vitest-temp.json
|
||||
tsconfig.vitest-temp.json
|
||||
coverage/
|
||||
@@ -35,6 +35,7 @@
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
"test": "vitest",
|
||||
"test:ui": "vitest --coverage --ui",
|
||||
"test:no-watch": "vitest --no-watch"
|
||||
},
|
||||
"madge": {
|
||||
@@ -132,6 +133,8 @@
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@vitejs/plugin-react-swc": "^3.6.0",
|
||||
"@vitest/coverage-v8": "^1.5.0",
|
||||
"@vitest/ui": "^1.5.0",
|
||||
"concurrently": "^8.2.2",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.0",
|
||||
|
||||
148
invokeai/frontend/web/pnpm-lock.yaml
generated
148
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -229,6 +229,12 @@ devDependencies:
|
||||
'@vitejs/plugin-react-swc':
|
||||
specifier: ^3.6.0
|
||||
version: 3.6.0(vite@5.2.11)
|
||||
'@vitest/coverage-v8':
|
||||
specifier: ^1.5.0
|
||||
version: 1.6.0(vitest@1.6.0)
|
||||
'@vitest/ui':
|
||||
specifier: ^1.5.0
|
||||
version: 1.6.0(vitest@1.6.0)
|
||||
concurrently:
|
||||
specifier: ^8.2.2
|
||||
version: 8.2.2
|
||||
@@ -288,7 +294,7 @@ devDependencies:
|
||||
version: 4.3.2(typescript@5.4.5)(vite@5.2.11)
|
||||
vitest:
|
||||
specifier: ^1.6.0
|
||||
version: 1.6.0(@types/node@20.12.10)
|
||||
version: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
|
||||
packages:
|
||||
|
||||
@@ -1679,6 +1685,10 @@ packages:
|
||||
resolution: {integrity: sha512-4iri8i1AqYHJE2DstZYkyEprg6Pq6sKx3xn5FpySk9sNhH7qN2LLlHJCfDTZRILNwQNPD7mATWM0TBui7uC1pA==}
|
||||
dev: true
|
||||
|
||||
/@bcoe/v8-coverage@0.2.3:
|
||||
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
|
||||
dev: true
|
||||
|
||||
/@chakra-ui/accordion@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.3.1):
|
||||
resolution: {integrity: sha512-FSXRm8iClFyU+gVaXisOSEw0/4Q+qZbFRiuhIAkVU6Boj0FxAMrlo9a8AV5TuF77rgaHytCdHk0Ng+cyUijrag==}
|
||||
peerDependencies:
|
||||
@@ -3635,6 +3645,11 @@ packages:
|
||||
wrap-ansi-cjs: /wrap-ansi@7.0.0
|
||||
dev: true
|
||||
|
||||
/@istanbuljs/schema@0.1.3:
|
||||
resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==}
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/@jest/schemas@29.6.3:
|
||||
resolution: {integrity: sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==}
|
||||
engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0}
|
||||
@@ -3822,6 +3837,10 @@ packages:
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/@polka/url@1.0.0-next.25:
|
||||
resolution: {integrity: sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==}
|
||||
dev: true
|
||||
|
||||
/@popperjs/core@2.11.8:
|
||||
resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==}
|
||||
dev: false
|
||||
@@ -5146,7 +5165,7 @@ packages:
|
||||
dom-accessibility-api: 0.6.3
|
||||
lodash: 4.17.21
|
||||
redent: 3.0.0
|
||||
vitest: 1.6.0(@types/node@20.12.10)
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
dev: true
|
||||
|
||||
/@testing-library/user-event@14.5.2(@testing-library/dom@9.3.4):
|
||||
@@ -5825,6 +5844,29 @@ packages:
|
||||
- '@swc/helpers'
|
||||
dev: true
|
||||
|
||||
/@vitest/coverage-v8@1.6.0(vitest@1.6.0):
|
||||
resolution: {integrity: sha512-KvapcbMY/8GYIG0rlwwOKCVNRc0OL20rrhFkg/CHNzncV03TE2XWvO5w9uZYoxNiMEBacAJt3unSOiZ7svePew==}
|
||||
peerDependencies:
|
||||
vitest: 1.6.0
|
||||
dependencies:
|
||||
'@ampproject/remapping': 2.3.0
|
||||
'@bcoe/v8-coverage': 0.2.3
|
||||
debug: 4.3.4
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
istanbul-lib-report: 3.0.1
|
||||
istanbul-lib-source-maps: 5.0.4
|
||||
istanbul-reports: 3.1.7
|
||||
magic-string: 0.30.10
|
||||
magicast: 0.3.4
|
||||
picocolors: 1.0.0
|
||||
std-env: 3.7.0
|
||||
strip-literal: 2.1.0
|
||||
test-exclude: 6.0.0
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
dev: true
|
||||
|
||||
/@vitest/expect@1.3.1:
|
||||
resolution: {integrity: sha512-xofQFwIzfdmLLlHa6ag0dPV8YsnKOCP1KdAeVVh34vSjN2dcUiXYCD9htu/9eM7t8Xln4v03U9HLxLpPlsXdZw==}
|
||||
dependencies:
|
||||
@@ -5869,6 +5911,21 @@ packages:
|
||||
tinyspy: 2.2.1
|
||||
dev: true
|
||||
|
||||
/@vitest/ui@1.6.0(vitest@1.6.0):
|
||||
resolution: {integrity: sha512-k3Lyo+ONLOgylctiGovRKy7V4+dIN2yxstX3eY5cWFXH6WP+ooVX79YSyi0GagdTQzLmT43BF27T0s6dOIPBXA==}
|
||||
peerDependencies:
|
||||
vitest: 1.6.0
|
||||
dependencies:
|
||||
'@vitest/utils': 1.6.0
|
||||
fast-glob: 3.3.2
|
||||
fflate: 0.8.2
|
||||
flatted: 3.3.1
|
||||
pathe: 1.1.2
|
||||
picocolors: 1.0.0
|
||||
sirv: 2.0.4
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
dev: true
|
||||
|
||||
/@vitest/utils@1.3.1:
|
||||
resolution: {integrity: sha512-d3Waie/299qqRyHTm2DjADeTaNdNSVsnwHPWrs20JMpjh6eiVq7ggggweO8rc4arhf6rRkWuHKwvxGvejUXZZQ==}
|
||||
dependencies:
|
||||
@@ -8521,6 +8578,10 @@ packages:
|
||||
resolution: {integrity: sha512-3yurQZ2hD9VISAhJJP9bpYFNQrHHBXE2JxxjY5aLEcDi46RmAzJE2OC9FAde0yis5ElW0jTTzs0zfg/Cca4XqQ==}
|
||||
dev: true
|
||||
|
||||
/fflate@0.8.2:
|
||||
resolution: {integrity: sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==}
|
||||
dev: true
|
||||
|
||||
/file-entry-cache@6.0.1:
|
||||
resolution: {integrity: sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==}
|
||||
engines: {node: ^10.12.0 || >=12.0.0}
|
||||
@@ -9084,6 +9145,10 @@ packages:
|
||||
resolution: {integrity: sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==}
|
||||
dev: true
|
||||
|
||||
/html-escaper@2.0.2:
|
||||
resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==}
|
||||
dev: true
|
||||
|
||||
/html-parse-stringify@3.0.1:
|
||||
resolution: {integrity: sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==}
|
||||
dependencies:
|
||||
@@ -9513,6 +9578,39 @@ packages:
|
||||
engines: {node: '>=0.10.0'}
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-coverage@3.2.2:
|
||||
resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==}
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-report@3.0.1:
|
||||
resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
make-dir: 4.0.0
|
||||
supports-color: 7.2.0
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-source-maps@5.0.4:
|
||||
resolution: {integrity: sha512-wHOoEsNJTVltaJp8eVkm8w+GVkVNHT2YDYo53YdzQEL2gWm1hBX5cGFR9hQJtuGLebidVX7et3+dmDZrmclduw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
'@jridgewell/trace-mapping': 0.3.25
|
||||
debug: 4.3.4
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
dev: true
|
||||
|
||||
/istanbul-reports@3.1.7:
|
||||
resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==}
|
||||
engines: {node: '>=8'}
|
||||
dependencies:
|
||||
html-escaper: 2.0.2
|
||||
istanbul-lib-report: 3.0.1
|
||||
dev: true
|
||||
|
||||
/iterable-lookahead@1.0.0:
|
||||
resolution: {integrity: sha512-hJnEP2Xk4+44DDwJqUQGdXal5VbyeWLaPyDl2AQc242Zr7iqz4DgpQOrEzglWVMGHMDCkguLHEKxd1+rOsmgSQ==}
|
||||
engines: {node: '>=4'}
|
||||
@@ -9912,6 +10010,14 @@ packages:
|
||||
'@jridgewell/sourcemap-codec': 1.4.15
|
||||
dev: true
|
||||
|
||||
/magicast@0.3.4:
|
||||
resolution: {integrity: sha512-TyDF/Pn36bBji9rWKHlZe+PZb6Mx5V8IHCSxk7X4aljM4e/vyDvZZYwHewdVaqiA0nb3ghfHU/6AUpDxWoER2Q==}
|
||||
dependencies:
|
||||
'@babel/parser': 7.24.5
|
||||
'@babel/types': 7.24.5
|
||||
source-map-js: 1.2.0
|
||||
dev: true
|
||||
|
||||
/make-dir@2.1.0:
|
||||
resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==}
|
||||
engines: {node: '>=6'}
|
||||
@@ -9927,6 +10033,13 @@ packages:
|
||||
semver: 6.3.1
|
||||
dev: true
|
||||
|
||||
/make-dir@4.0.0:
|
||||
resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
semver: 7.6.0
|
||||
dev: true
|
||||
|
||||
/map-obj@2.0.0:
|
||||
resolution: {integrity: sha512-TzQSV2DiMYgoF5RycneKVUzIa9bQsj/B3tTgsE3dOGqlzHnGIDaC7XBE7grnA+8kZPnfqSGFe95VHc2oc0VFUQ==}
|
||||
engines: {node: '>=4'}
|
||||
@@ -10101,6 +10214,11 @@ packages:
|
||||
resolution: {integrity: sha512-iSAJLHYKnX41mKcJKjqvnAN9sf0LMDTXDEvFv+ffuRR9a1MIuXLjMNL6EsnDHSkKLTWNqQQ5uo61P4EbU4NU+Q==}
|
||||
dev: false
|
||||
|
||||
/mrmime@2.0.0:
|
||||
resolution: {integrity: sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==}
|
||||
engines: {node: '>=10'}
|
||||
dev: true
|
||||
|
||||
/ms@2.0.0:
|
||||
resolution: {integrity: sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==}
|
||||
dev: true
|
||||
@@ -11766,6 +11884,15 @@ packages:
|
||||
engines: {node: '>=14'}
|
||||
dev: true
|
||||
|
||||
/sirv@2.0.4:
|
||||
resolution: {integrity: sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==}
|
||||
engines: {node: '>= 10'}
|
||||
dependencies:
|
||||
'@polka/url': 1.0.0-next.25
|
||||
mrmime: 2.0.0
|
||||
totalist: 3.0.1
|
||||
dev: true
|
||||
|
||||
/sisteransi@1.0.5:
|
||||
resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==}
|
||||
dev: true
|
||||
@@ -12191,6 +12318,15 @@ packages:
|
||||
unique-string: 2.0.0
|
||||
dev: true
|
||||
|
||||
/test-exclude@6.0.0:
|
||||
resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==}
|
||||
engines: {node: '>=8'}
|
||||
dependencies:
|
||||
'@istanbuljs/schema': 0.1.3
|
||||
glob: 7.2.3
|
||||
minimatch: 3.1.2
|
||||
dev: true
|
||||
|
||||
/text-table@0.2.0:
|
||||
resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==}
|
||||
dev: true
|
||||
@@ -12264,6 +12400,11 @@ packages:
|
||||
engines: {node: '>=0.6'}
|
||||
dev: true
|
||||
|
||||
/totalist@3.0.1:
|
||||
resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==}
|
||||
engines: {node: '>=6'}
|
||||
dev: true
|
||||
|
||||
/tr46@0.0.3:
|
||||
resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==}
|
||||
|
||||
@@ -12837,7 +12978,7 @@ packages:
|
||||
fsevents: 2.3.3
|
||||
dev: true
|
||||
|
||||
/vitest@1.6.0(@types/node@20.12.10):
|
||||
/vitest@1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0):
|
||||
resolution: {integrity: sha512-H5r/dN06swuFnzNFhq/dnz37bPXnq8xB2xB5JOVk8K09rUtoeNN+LHWkoQ0A/i3hvbUKKcCei9KpbxqHMLhLLA==}
|
||||
engines: {node: ^18.0.0 || >=20.0.0}
|
||||
hasBin: true
|
||||
@@ -12867,6 +13008,7 @@ packages:
|
||||
'@vitest/runner': 1.6.0
|
||||
'@vitest/snapshot': 1.6.0
|
||||
'@vitest/spy': 1.6.0
|
||||
'@vitest/ui': 1.6.0(vitest@1.6.0)
|
||||
'@vitest/utils': 1.6.0
|
||||
acorn-walk: 8.3.2
|
||||
chai: 4.4.1
|
||||
|
||||
@@ -774,6 +774,7 @@
|
||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||
"nodePack": "Node pack",
|
||||
"collection": "Collection",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
/* eslint-disable no-console */
|
||||
import fs from 'node:fs';
|
||||
|
||||
import openapiTS from 'openapi-typescript';
|
||||
|
||||
@@ -67,6 +67,8 @@ export const useSocketIO = () => {
|
||||
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
window.$socketOptions = $socketOptions;
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
/* eslint-disable-next-line no-console */
|
||||
console.log('Socket initialized', socket);
|
||||
}
|
||||
|
||||
@@ -75,6 +77,8 @@ export const useSocketIO = () => {
|
||||
return () => {
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
window.$socketOptions = undefined;
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
/* eslint-disable-next-line no-console */
|
||||
console.log('Socket teardown', socket);
|
||||
}
|
||||
socket.disconnect();
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
/* eslint-disable no-console */
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
|
||||
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { diff } from 'jsondiffpatch';
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { Graph } from 'services/api/types';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
@@ -25,13 +24,6 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
||||
};
|
||||
}
|
||||
|
||||
if (nodeTemplatesBuilt.match(action)) {
|
||||
return {
|
||||
...action,
|
||||
payload: '<Node templates omitted>',
|
||||
};
|
||||
}
|
||||
|
||||
if (socketGeneratorProgress.match(action)) {
|
||||
const sanitized = deepClone(action);
|
||||
if (sanitized.payload.data.progress_image) {
|
||||
|
||||
@@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
||||
|
||||
const { canvas, nodes, controlAdapters, controlLayers } = getState();
|
||||
deleted_images.forEach((image_name) => {
|
||||
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
|
||||
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
|
||||
|
||||
if (imageUsage.isCanvasImage && !wasCanvasReset) {
|
||||
dispatch(resetCanvas());
|
||||
|
||||
@@ -148,7 +148,6 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
log.trace('Control Adapter preprocessor cancelled');
|
||||
} else {
|
||||
// Some other error condition...
|
||||
console.log(error);
|
||||
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
|
||||
|
||||
if (error instanceof Object) {
|
||||
|
||||
@@ -8,8 +8,8 @@ import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
@@ -18,7 +18,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let graph;
|
||||
|
||||
if (model && model.base === 'sdxl') {
|
||||
if (model?.base === 'sdxl') {
|
||||
graph = await buildGenerationTabSDXLGraph(state);
|
||||
} else {
|
||||
graph = await buildGenerationTabGraph(state);
|
||||
|
||||
@@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { nodes, edges } = state.nodes;
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
const workflow = state.workflow;
|
||||
const graph = buildNodesGraph(state.nodes);
|
||||
const graph = buildNodesGraph(state.nodes.present);
|
||||
const builtWorkflow = buildWorkflowWithValidation({
|
||||
nodes,
|
||||
edges,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
@@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: (action, { getState }) => {
|
||||
const log = logger('system');
|
||||
const schemaJSON = action.payload;
|
||||
|
||||
@@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
|
||||
|
||||
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
|
||||
|
||||
dispatch(nodeTemplatesBuilt(nodeTemplates));
|
||||
$templates.set(nodeTemplates);
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.nodes.forEach((node) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
nes.progressImage = progress_image ?? null;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
@@ -9,7 +10,9 @@ import {
|
||||
isImageViewerOpenChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
|
||||
const { result, node, queue_batch_id } = data;
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
@@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.error = action.payload.data.error;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
|
||||
@@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
dispatch(
|
||||
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
||||
);
|
||||
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
}
|
||||
const clone = deepClone(nes);
|
||||
clone.status = zNodeStatus.enum.PENDING;
|
||||
clone.error = null;
|
||||
clone.progress = null;
|
||||
clone.progressImage = null;
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
|
||||
@@ -14,7 +14,8 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { nodes, templates } = getState().nodes;
|
||||
const { nodes } = getState().nodes.present;
|
||||
const templates = $templates.get();
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
@@ -24,7 +25,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
unableToUpdateCount++;
|
||||
return;
|
||||
}
|
||||
if (!getNeedsUpdate(node, template)) {
|
||||
if (!getNeedsUpdate(node.data, template)) {
|
||||
// No need to increment the count here, since we're not actually updating
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
|
||||
@@ -14,10 +15,10 @@ import { fromZodError } from 'zod-validation-error';
|
||||
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('nodes');
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = getState().nodes.templates;
|
||||
const nodeTemplates = $templates.get();
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
|
||||
@@ -21,7 +21,8 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
@@ -50,7 +51,7 @@ const allReducers = {
|
||||
[canvasSlice.name]: canvasSlice.reducer,
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: nodesSlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
@@ -66,6 +67,7 @@ const allReducers = {
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@@ -111,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
@@ -9,13 +10,16 @@ import { selectControlLayersSlice } from 'features/controlLayers/store/controlLa
|
||||
import type { Layer } from 'features/controlLayers/store/types';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import i18n from 'i18next';
|
||||
import { forEach, upperFirst } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
|
||||
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
|
||||
@@ -25,199 +29,208 @@ const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
|
||||
regional_guidance_layer: 'controlLayers.regionalGuidance',
|
||||
};
|
||||
|
||||
const selector = createMemoizedSelector(
|
||||
[
|
||||
selectControlAdaptersSlice,
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
selectControlLayersSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, dynamicPrompts, controlLayers, activeTabName) => {
|
||||
const { model } = generation;
|
||||
const { size } = controlLayers.present;
|
||||
const { positivePrompt } = controlLayers.present;
|
||||
const createSelector = (templates: Templates) =>
|
||||
createMemoizedSelector(
|
||||
[
|
||||
selectControlAdaptersSlice,
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectWorkflowSettingsSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
selectControlLayersSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
|
||||
const { model } = generation;
|
||||
const { size } = controlLayers.present;
|
||||
const { positivePrompt } = controlLayers.present;
|
||||
|
||||
const { isConnected } = system;
|
||||
const { isConnected } = system;
|
||||
|
||||
const reasons: { prefix?: string; content: string }[] = [];
|
||||
const reasons: { prefix?: string; content: string }[] = [];
|
||||
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
|
||||
}
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
|
||||
}
|
||||
|
||||
if (activeTabName === 'workflows') {
|
||||
if (nodes.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
|
||||
if (activeTabName === 'workflows') {
|
||||
if (workflowSettings.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = templates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) => edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.missingInputForField', {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (!model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
|
||||
const nodeTemplate = nodes.templates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) => edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.missingInputForField', {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
|
||||
if (activeTabName === 'generation') {
|
||||
// Handling for generation tab
|
||||
controlLayers.present.layers
|
||||
.filter((l) => l.isEnabled)
|
||||
.forEach((l, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
if (l.type === 'control_adapter_layer') {
|
||||
// Must have model
|
||||
if (!l.controlAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (l.controlAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have a control image OR, if it has a processor, it must have a processed image
|
||||
if (!l.controlAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
|
||||
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
|
||||
}
|
||||
// T2I Adapters require images have dimensions that are multiples of 64
|
||||
if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
|
||||
}
|
||||
}
|
||||
|
||||
if (l.type === 'ip_adapter_layer') {
|
||||
// Must have model
|
||||
if (!l.ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (l.ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!l.ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
}
|
||||
|
||||
if (l.type === 'initial_image_layer') {
|
||||
// Must have an image
|
||||
if (!l.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
|
||||
}
|
||||
}
|
||||
|
||||
if (l.type === 'regional_guidance_layer') {
|
||||
// Must have a region
|
||||
if (l.maskObjects.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
|
||||
}
|
||||
// Must have at least 1 prompt or IP Adapter
|
||||
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
|
||||
}
|
||||
l.ipAdapters.forEach((ipAdapter) => {
|
||||
if (activeTabName === 'generation') {
|
||||
// Handling for generation tab
|
||||
controlLayers.present.layers
|
||||
.filter((l) => l.isEnabled)
|
||||
.forEach((l, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
if (l.type === 'control_adapter_layer') {
|
||||
// Must have model
|
||||
if (!ipAdapter.model) {
|
||||
if (!l.controlAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (l.controlAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have a control image OR, if it has a processor, it must have a processed image
|
||||
if (!l.controlAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
|
||||
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
|
||||
}
|
||||
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
|
||||
if (l.controlAdapter.type === 't2i_adapter') {
|
||||
const multiple = model?.base === 'sdxl' ? 32 : 64;
|
||||
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (l.type === 'ip_adapter_layer') {
|
||||
// Must have model
|
||||
if (!l.ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (ipAdapter.model?.base !== model?.base) {
|
||||
if (l.ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!ipAdapter.image) {
|
||||
if (!l.ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Handling for all other tabs
|
||||
selectControlAdapterAll(controlAdapters)
|
||||
.filter((ca) => ca.isEnabled)
|
||||
.forEach((ca, i) => {
|
||||
if (!ca.isEnabled) {
|
||||
return;
|
||||
}
|
||||
if (l.type === 'initial_image_layer') {
|
||||
// Must have an image
|
||||
if (!l.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
|
||||
}
|
||||
}
|
||||
|
||||
if (!ca.model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
|
||||
} else if (ca.model.base !== model?.base) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
|
||||
});
|
||||
}
|
||||
if (l.type === 'regional_guidance_layer') {
|
||||
// Must have a region
|
||||
if (l.maskObjects.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
|
||||
}
|
||||
// Must have at least 1 prompt or IP Adapter
|
||||
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
|
||||
}
|
||||
l.ipAdapters.forEach((ipAdapter) => {
|
||||
// Must have model
|
||||
if (!ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
!ca.controlImage ||
|
||||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
|
||||
) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) });
|
||||
}
|
||||
});
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Handling for all other tabs
|
||||
selectControlAdapterAll(controlAdapters)
|
||||
.filter((ca) => ca.isEnabled)
|
||||
.forEach((ca, i) => {
|
||||
if (!ca.isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ca.model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
|
||||
} else if (ca.model.base !== model?.base) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
!ca.controlImage ||
|
||||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
|
||||
) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { isReady: !reasons.length, reasons };
|
||||
}
|
||||
);
|
||||
return { isReady: !reasons.length, reasons };
|
||||
}
|
||||
);
|
||||
|
||||
export const useIsReadyToEnqueue = () => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(() => createSelector(templates), [templates]);
|
||||
const value = useAppSelector(selector);
|
||||
return value;
|
||||
};
|
||||
|
||||
@@ -5,22 +5,7 @@ import type {
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type {
|
||||
CannyImageProcessorInvocation,
|
||||
ColorMapImageProcessorInvocation,
|
||||
ContentShuffleImageProcessorInvocation,
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
DWOpenposeImageProcessorInvocation,
|
||||
HedImageProcessorInvocation,
|
||||
LineartAnimeImageProcessorInvocation,
|
||||
LineartImageProcessorInvocation,
|
||||
MediapipeFaceProcessorInvocation,
|
||||
MidasDepthImageProcessorInvocation,
|
||||
MlsdImageProcessorInvocation,
|
||||
NormalbaeImageProcessorInvocation,
|
||||
PidiImageProcessorInvocation,
|
||||
ZoeDepthImageProcessorInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -28,20 +13,20 @@ import { z } from 'zod';
|
||||
* Any ControlNet processor node
|
||||
*/
|
||||
export type ControlAdapterProcessorNode =
|
||||
| CannyImageProcessorInvocation
|
||||
| ColorMapImageProcessorInvocation
|
||||
| ContentShuffleImageProcessorInvocation
|
||||
| DepthAnythingImageProcessorInvocation
|
||||
| HedImageProcessorInvocation
|
||||
| LineartAnimeImageProcessorInvocation
|
||||
| LineartImageProcessorInvocation
|
||||
| MediapipeFaceProcessorInvocation
|
||||
| MidasDepthImageProcessorInvocation
|
||||
| MlsdImageProcessorInvocation
|
||||
| NormalbaeImageProcessorInvocation
|
||||
| DWOpenposeImageProcessorInvocation
|
||||
| PidiImageProcessorInvocation
|
||||
| ZoeDepthImageProcessorInvocation;
|
||||
| Invocation<'canny_image_processor'>
|
||||
| Invocation<'color_map_image_processor'>
|
||||
| Invocation<'content_shuffle_image_processor'>
|
||||
| Invocation<'depth_anything_image_processor'>
|
||||
| Invocation<'hed_image_processor'>
|
||||
| Invocation<'lineart_anime_image_processor'>
|
||||
| Invocation<'lineart_image_processor'>
|
||||
| Invocation<'mediapipe_face_processor'>
|
||||
| Invocation<'midas_depth_image_processor'>
|
||||
| Invocation<'mlsd_image_processor'>
|
||||
| Invocation<'normalbae_image_processor'>
|
||||
| Invocation<'dw_openpose_image_processor'>
|
||||
| Invocation<'pidi_image_processor'>
|
||||
| Invocation<'zoe_depth_image_processor'>;
|
||||
|
||||
/**
|
||||
* Any ControlNet processor type
|
||||
@@ -71,7 +56,7 @@ export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterPr
|
||||
* The Canny processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredCannyImageProcessorInvocation = O.Required<
|
||||
CannyImageProcessorInvocation,
|
||||
Invocation<'canny_image_processor'>,
|
||||
'type' | 'low_threshold' | 'high_threshold' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@@ -79,7 +64,7 @@ export type RequiredCannyImageProcessorInvocation = O.Required<
|
||||
* The Color Map processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredColorMapImageProcessorInvocation = O.Required<
|
||||
ColorMapImageProcessorInvocation,
|
||||
Invocation<'color_map_image_processor'>,
|
||||
'type' | 'color_map_tile_size'
|
||||
>;
|
||||
|
||||
@@ -87,7 +72,7 @@ export type RequiredColorMapImageProcessorInvocation = O.Required<
|
||||
* The ContentShuffle processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredContentShuffleImageProcessorInvocation = O.Required<
|
||||
ContentShuffleImageProcessorInvocation,
|
||||
Invocation<'content_shuffle_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
|
||||
>;
|
||||
|
||||
@@ -95,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
|
||||
* The DepthAnything processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
Invocation<'depth_anything_image_processor'>,
|
||||
'type' | 'model_size' | 'resolution' | 'offload'
|
||||
>;
|
||||
|
||||
@@ -108,7 +93,7 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz
|
||||
* The HED processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredHedImageProcessorInvocation = O.Required<
|
||||
HedImageProcessorInvocation,
|
||||
Invocation<'hed_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'scribble'
|
||||
>;
|
||||
|
||||
@@ -116,7 +101,7 @@ export type RequiredHedImageProcessorInvocation = O.Required<
|
||||
* The Lineart Anime processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
|
||||
LineartAnimeImageProcessorInvocation,
|
||||
Invocation<'lineart_anime_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution'
|
||||
>;
|
||||
|
||||
@@ -124,7 +109,7 @@ export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
|
||||
* The Lineart processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredLineartImageProcessorInvocation = O.Required<
|
||||
LineartImageProcessorInvocation,
|
||||
Invocation<'lineart_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'coarse'
|
||||
>;
|
||||
|
||||
@@ -132,7 +117,7 @@ export type RequiredLineartImageProcessorInvocation = O.Required<
|
||||
* The MediapipeFace processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMediapipeFaceProcessorInvocation = O.Required<
|
||||
MediapipeFaceProcessorInvocation,
|
||||
Invocation<'mediapipe_face_processor'>,
|
||||
'type' | 'max_faces' | 'min_confidence' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@@ -140,7 +125,7 @@ export type RequiredMediapipeFaceProcessorInvocation = O.Required<
|
||||
* The MidasDepth processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMidasDepthImageProcessorInvocation = O.Required<
|
||||
MidasDepthImageProcessorInvocation,
|
||||
Invocation<'midas_depth_image_processor'>,
|
||||
'type' | 'a_mult' | 'bg_th' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@@ -148,7 +133,7 @@ export type RequiredMidasDepthImageProcessorInvocation = O.Required<
|
||||
* The MLSD processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMlsdImageProcessorInvocation = O.Required<
|
||||
MlsdImageProcessorInvocation,
|
||||
Invocation<'mlsd_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'thr_v' | 'thr_d'
|
||||
>;
|
||||
|
||||
@@ -156,7 +141,7 @@ export type RequiredMlsdImageProcessorInvocation = O.Required<
|
||||
* The NormalBae processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredNormalbaeImageProcessorInvocation = O.Required<
|
||||
NormalbaeImageProcessorInvocation,
|
||||
Invocation<'normalbae_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution'
|
||||
>;
|
||||
|
||||
@@ -164,7 +149,7 @@ export type RequiredNormalbaeImageProcessorInvocation = O.Required<
|
||||
* The DW Openpose processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
|
||||
DWOpenposeImageProcessorInvocation,
|
||||
Invocation<'dw_openpose_image_processor'>,
|
||||
'type' | 'image_resolution' | 'draw_body' | 'draw_face' | 'draw_hands'
|
||||
>;
|
||||
|
||||
@@ -172,14 +157,14 @@ export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
|
||||
* The Pidi processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredPidiImageProcessorInvocation = O.Required<
|
||||
PidiImageProcessorInvocation,
|
||||
Invocation<'pidi_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'safe' | 'scribble'
|
||||
>;
|
||||
|
||||
/**
|
||||
* The ZoeDepth processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredZoeDepthImageProcessorInvocation = O.Required<ZoeDepthImageProcessorInvocation, 'type'>;
|
||||
export type RequiredZoeDepthImageProcessorInvocation = O.Required<Invocation<'zoe_depth_image_processor'>, 'type'>;
|
||||
|
||||
/**
|
||||
* Any ControlNet Processor node, with its parameters flagged as required
|
||||
|
||||
@@ -1,23 +1,9 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
|
||||
import type {
|
||||
_CannyProcessorConfig,
|
||||
_ColorMapProcessorConfig,
|
||||
_ContentShuffleProcessorConfig,
|
||||
_DepthAnythingProcessorConfig,
|
||||
_DWOpenposeProcessorConfig,
|
||||
_HedProcessorConfig,
|
||||
_LineartAnimeProcessorConfig,
|
||||
_LineartProcessorConfig,
|
||||
_MediapipeFaceProcessorConfig,
|
||||
_MidasDepthProcessorConfig,
|
||||
_MlsdProcessorConfig,
|
||||
_NormalbaeProcessorConfig,
|
||||
_PidiProcessorConfig,
|
||||
_ZoeDepthProcessorConfig,
|
||||
CannyProcessorConfig,
|
||||
CLIPVisionModelV2,
|
||||
ColorMapProcessorConfig,
|
||||
@@ -45,16 +31,16 @@ describe('Control Adapter Types', () => {
|
||||
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
|
||||
});
|
||||
test('IP Adapter Method', () => {
|
||||
assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>();
|
||||
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
|
||||
});
|
||||
test('CLIP Vision Model', () => {
|
||||
assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>();
|
||||
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
|
||||
});
|
||||
test('Control Mode', () => {
|
||||
assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>();
|
||||
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
|
||||
});
|
||||
test('DepthAnything Model Size', () => {
|
||||
assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>();
|
||||
assert<Equals<NonNullable<Invocation<'depth_anything_image_processor'>['model_size']>, DepthAnythingModelSize>>();
|
||||
});
|
||||
test('Processor Configs', () => {
|
||||
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
|
||||
@@ -75,3 +61,33 @@ describe('Control Adapter Types', () => {
|
||||
assert<Equals<_ZoeDepthProcessorConfig, ZoeDepthProcessorConfig>>();
|
||||
});
|
||||
});
|
||||
|
||||
// Types derived from OpenAPI
|
||||
type _CannyProcessorConfig = Required<
|
||||
Pick<Invocation<'canny_image_processor'>, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
type _ColorMapProcessorConfig = Required<
|
||||
Pick<Invocation<'color_map_image_processor'>, 'id' | 'type' | 'color_map_tile_size'>
|
||||
>;
|
||||
type _ContentShuffleProcessorConfig = Required<
|
||||
Pick<Invocation<'content_shuffle_image_processor'>, 'id' | 'type' | 'w' | 'h' | 'f'>
|
||||
>;
|
||||
type _DepthAnythingProcessorConfig = Required<
|
||||
Pick<Invocation<'depth_anything_image_processor'>, 'id' | 'type' | 'model_size'>
|
||||
>;
|
||||
type _HedProcessorConfig = Required<Pick<Invocation<'hed_image_processor'>, 'id' | 'type' | 'scribble'>>;
|
||||
type _LineartAnimeProcessorConfig = Required<Pick<Invocation<'lineart_anime_image_processor'>, 'id' | 'type'>>;
|
||||
type _LineartProcessorConfig = Required<Pick<Invocation<'lineart_image_processor'>, 'id' | 'type' | 'coarse'>>;
|
||||
type _MediapipeFaceProcessorConfig = Required<
|
||||
Pick<Invocation<'mediapipe_face_processor'>, 'id' | 'type' | 'max_faces' | 'min_confidence'>
|
||||
>;
|
||||
type _MidasDepthProcessorConfig = Required<
|
||||
Pick<Invocation<'midas_depth_image_processor'>, 'id' | 'type' | 'a_mult' | 'bg_th'>
|
||||
>;
|
||||
type _MlsdProcessorConfig = Required<Pick<Invocation<'mlsd_image_processor'>, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
|
||||
type _NormalbaeProcessorConfig = Required<Pick<Invocation<'normalbae_image_processor'>, 'id' | 'type'>>;
|
||||
type _DWOpenposeProcessorConfig = Required<
|
||||
Pick<Invocation<'dw_openpose_image_processor'>, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
>;
|
||||
type _PidiProcessorConfig = Required<Pick<Invocation<'pidi_image_processor'>, 'id' | 'type' | 'safe' | 'scribble'>>;
|
||||
type _ZoeDepthProcessorConfig = Required<Pick<Invocation<'zoe_depth_image_processor'>, 'id' | 'type'>>;
|
||||
|
||||
@@ -1,27 +1,7 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type {
|
||||
BaseModelType,
|
||||
CannyImageProcessorInvocation,
|
||||
ColorMapImageProcessorInvocation,
|
||||
ContentShuffleImageProcessorInvocation,
|
||||
ControlNetModelConfig,
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
DWOpenposeImageProcessorInvocation,
|
||||
Graph,
|
||||
HedImageProcessorInvocation,
|
||||
ImageDTO,
|
||||
LineartAnimeImageProcessorInvocation,
|
||||
LineartImageProcessorInvocation,
|
||||
MediapipeFaceProcessorInvocation,
|
||||
MidasDepthImageProcessorInvocation,
|
||||
MlsdImageProcessorInvocation,
|
||||
NormalbaeImageProcessorInvocation,
|
||||
PidiImageProcessorInvocation,
|
||||
T2IAdapterModelConfig,
|
||||
ZoeDepthImageProcessorInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zId = z.string().min(1);
|
||||
@@ -32,9 +12,6 @@ const zCannyProcessorConfig = z.object({
|
||||
low_threshold: z.number().int().gte(0).lte(255),
|
||||
high_threshold: z.number().int().gte(0).lte(255),
|
||||
});
|
||||
export type _CannyProcessorConfig = Required<
|
||||
Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
export type CannyProcessorConfig = z.infer<typeof zCannyProcessorConfig>;
|
||||
|
||||
const zColorMapProcessorConfig = z.object({
|
||||
@@ -42,9 +19,6 @@ const zColorMapProcessorConfig = z.object({
|
||||
type: z.literal('color_map_image_processor'),
|
||||
color_map_tile_size: z.number().int().gte(1),
|
||||
});
|
||||
export type _ColorMapProcessorConfig = Required<
|
||||
Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
|
||||
>;
|
||||
export type ColorMapProcessorConfig = z.infer<typeof zColorMapProcessorConfig>;
|
||||
|
||||
const zContentShuffleProcessorConfig = z.object({
|
||||
@@ -54,9 +28,6 @@ const zContentShuffleProcessorConfig = z.object({
|
||||
h: z.number().int().gte(0),
|
||||
f: z.number().int().gte(0),
|
||||
});
|
||||
export type _ContentShuffleProcessorConfig = Required<
|
||||
Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
|
||||
>;
|
||||
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
|
||||
|
||||
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
|
||||
@@ -68,9 +39,6 @@ const zDepthAnythingProcessorConfig = z.object({
|
||||
type: z.literal('depth_anything_image_processor'),
|
||||
model_size: zDepthAnythingModelSize,
|
||||
});
|
||||
export type _DepthAnythingProcessorConfig = Required<
|
||||
Pick<DepthAnythingImageProcessorInvocation, 'id' | 'type' | 'model_size'>
|
||||
>;
|
||||
export type DepthAnythingProcessorConfig = z.infer<typeof zDepthAnythingProcessorConfig>;
|
||||
|
||||
const zHedProcessorConfig = z.object({
|
||||
@@ -78,14 +46,12 @@ const zHedProcessorConfig = z.object({
|
||||
type: z.literal('hed_image_processor'),
|
||||
scribble: z.boolean(),
|
||||
});
|
||||
export type _HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
|
||||
export type HedProcessorConfig = z.infer<typeof zHedProcessorConfig>;
|
||||
|
||||
const zLineartAnimeProcessorConfig = z.object({
|
||||
id: zId,
|
||||
type: z.literal('lineart_anime_image_processor'),
|
||||
});
|
||||
export type _LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
|
||||
export type LineartAnimeProcessorConfig = z.infer<typeof zLineartAnimeProcessorConfig>;
|
||||
|
||||
const zLineartProcessorConfig = z.object({
|
||||
@@ -93,7 +59,6 @@ const zLineartProcessorConfig = z.object({
|
||||
type: z.literal('lineart_image_processor'),
|
||||
coarse: z.boolean(),
|
||||
});
|
||||
export type _LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
|
||||
export type LineartProcessorConfig = z.infer<typeof zLineartProcessorConfig>;
|
||||
|
||||
const zMediapipeFaceProcessorConfig = z.object({
|
||||
@@ -102,9 +67,6 @@ const zMediapipeFaceProcessorConfig = z.object({
|
||||
max_faces: z.number().int().gte(1),
|
||||
min_confidence: z.number().gte(0).lte(1),
|
||||
});
|
||||
export type _MediapipeFaceProcessorConfig = Required<
|
||||
Pick<MediapipeFaceProcessorInvocation, 'id' | 'type' | 'max_faces' | 'min_confidence'>
|
||||
>;
|
||||
export type MediapipeFaceProcessorConfig = z.infer<typeof zMediapipeFaceProcessorConfig>;
|
||||
|
||||
const zMidasDepthProcessorConfig = z.object({
|
||||
@@ -113,9 +75,6 @@ const zMidasDepthProcessorConfig = z.object({
|
||||
a_mult: z.number().gte(0),
|
||||
bg_th: z.number().gte(0),
|
||||
});
|
||||
export type _MidasDepthProcessorConfig = Required<
|
||||
Pick<MidasDepthImageProcessorInvocation, 'id' | 'type' | 'a_mult' | 'bg_th'>
|
||||
>;
|
||||
export type MidasDepthProcessorConfig = z.infer<typeof zMidasDepthProcessorConfig>;
|
||||
|
||||
const zMlsdProcessorConfig = z.object({
|
||||
@@ -124,14 +83,12 @@ const zMlsdProcessorConfig = z.object({
|
||||
thr_v: z.number().gte(0),
|
||||
thr_d: z.number().gte(0),
|
||||
});
|
||||
export type _MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
|
||||
export type MlsdProcessorConfig = z.infer<typeof zMlsdProcessorConfig>;
|
||||
|
||||
const zNormalbaeProcessorConfig = z.object({
|
||||
id: zId,
|
||||
type: z.literal('normalbae_image_processor'),
|
||||
});
|
||||
export type _NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
|
||||
export type NormalbaeProcessorConfig = z.infer<typeof zNormalbaeProcessorConfig>;
|
||||
|
||||
const zDWOpenposeProcessorConfig = z.object({
|
||||
@@ -141,9 +98,6 @@ const zDWOpenposeProcessorConfig = z.object({
|
||||
draw_face: z.boolean(),
|
||||
draw_hands: z.boolean(),
|
||||
});
|
||||
export type _DWOpenposeProcessorConfig = Required<
|
||||
Pick<DWOpenposeImageProcessorInvocation, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
>;
|
||||
export type DWOpenposeProcessorConfig = z.infer<typeof zDWOpenposeProcessorConfig>;
|
||||
|
||||
const zPidiProcessorConfig = z.object({
|
||||
@@ -152,14 +106,12 @@ const zPidiProcessorConfig = z.object({
|
||||
safe: z.boolean(),
|
||||
scribble: z.boolean(),
|
||||
});
|
||||
export type _PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
|
||||
export type PidiProcessorConfig = z.infer<typeof zPidiProcessorConfig>;
|
||||
|
||||
const zZoeDepthProcessorConfig = z.object({
|
||||
id: zId,
|
||||
type: z.literal('zoe_depth_image_processor'),
|
||||
});
|
||||
export type _ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
|
||||
export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
|
||||
|
||||
const zProcessorConfig = z.discriminatedUnion('type', [
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
import type { Modifier } from '@dnd-kit/core';
|
||||
import { getEventCoordinates } from '@dnd-kit/utilities';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $viewport } from 'features/nodes/store/nodesSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const selectZoom = createSelector([selectNodesSlice, activeTabNameSelector], (nodes, activeTabName) =>
|
||||
activeTabName === 'workflows' ? nodes.viewport.zoom : 1
|
||||
);
|
||||
|
||||
/**
|
||||
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
|
||||
*/
|
||||
export const useScaledModifer = () => {
|
||||
const zoom = useAppSelector(selectZoom);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const workflowsViewport = useStore($viewport);
|
||||
const modifier: Modifier = useCallback(
|
||||
({ activatorEvent, draggingNodeRect, transform }) => {
|
||||
if (draggingNodeRect && activatorEvent) {
|
||||
const zoom = activeTabName === 'workflows' ? workflowsViewport.zoom : 1;
|
||||
const activatorCoordinates = getEventCoordinates(activatorEvent);
|
||||
|
||||
if (!activatorCoordinates) {
|
||||
@@ -42,7 +40,7 @@ export const useScaledModifer = () => {
|
||||
|
||||
return transform;
|
||||
},
|
||||
[zoom]
|
||||
[activeTabName, workflowsViewport.zoom]
|
||||
);
|
||||
|
||||
return modifier;
|
||||
|
||||
@@ -75,8 +75,8 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
<CompositeNumberInput
|
||||
value={lora.weight}
|
||||
onChange={handleChange}
|
||||
min={-5}
|
||||
max={5}
|
||||
min={-10}
|
||||
max={10}
|
||||
step={0.01}
|
||||
w={20}
|
||||
flexShrink={0}
|
||||
|
||||
@@ -2,26 +2,34 @@ import 'reactflow/dist/style.css';
|
||||
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
$cursorPos,
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
closeAddNodePopover,
|
||||
connectionMade,
|
||||
nodeAdded,
|
||||
selectNodesSlice,
|
||||
openAddNodePopover,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import type { KeyboardEventHandler } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
@@ -54,26 +62,30 @@ const AddNodePopover = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const templates = useStore($templates);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const isOpen = useStore($isAddNodePopoverOpen);
|
||||
const store = useAppStore();
|
||||
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const filteredTemplates = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const filteredNodeTemplates = fieldFilter
|
||||
? filter(nodes.templates, (template) => {
|
||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
||||
if (!pendingConnection) {
|
||||
return map(templates);
|
||||
}
|
||||
|
||||
return some(handles, (handle) => {
|
||||
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
|
||||
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
|
||||
return filter(templates, (template) => {
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
||||
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
||||
return some(fields, (field) => {
|
||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
});
|
||||
}, [templates, pendingConnection]);
|
||||
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
})
|
||||
: map(nodes.templates);
|
||||
|
||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
||||
const options = useMemo(() => {
|
||||
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
@@ -83,15 +95,15 @@ const AddNodePopover = () => {
|
||||
});
|
||||
|
||||
//We only want these nodes if we're not filtered
|
||||
if (fieldFilter === null) {
|
||||
options.push({
|
||||
if (!pendingConnection) {
|
||||
_options.push({
|
||||
label: t('nodes.currentImage'),
|
||||
value: 'current_image',
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress'],
|
||||
});
|
||||
|
||||
options.push({
|
||||
_options.push({
|
||||
label: t('nodes.notes'),
|
||||
value: 'notes',
|
||||
description: t('nodes.notesDescription'),
|
||||
@@ -99,18 +111,15 @@ const AddNodePopover = () => {
|
||||
});
|
||||
}
|
||||
|
||||
options.sort((a, b) => a.label.localeCompare(b.label));
|
||||
_options.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return { options };
|
||||
});
|
||||
|
||||
const { options } = useAppSelector(selector);
|
||||
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen);
|
||||
return _options;
|
||||
}, [filteredTemplates, pendingConnection, t]);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
if (!invocation) {
|
||||
(nodeType: string): AnyNode | null => {
|
||||
const node = buildInvocation(nodeType);
|
||||
if (!node) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
nodeType: nodeType,
|
||||
});
|
||||
@@ -118,10 +127,11 @@ const AddNodePopover = () => {
|
||||
status: 'error',
|
||||
title: errorMessage,
|
||||
});
|
||||
return;
|
||||
return null;
|
||||
}
|
||||
|
||||
dispatch(nodeAdded(invocation));
|
||||
const cursorPos = $cursorPos.get();
|
||||
dispatch(nodeAdded({ node, cursorPos }));
|
||||
return node;
|
||||
},
|
||||
[dispatch, buildInvocation, toaster, t]
|
||||
);
|
||||
@@ -131,52 +141,50 @@ const AddNodePopover = () => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
addNode(v.value);
|
||||
dispatch(addNodePopoverClosed());
|
||||
const node = addNode(v.value);
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
}
|
||||
|
||||
closeAddNodePopover();
|
||||
},
|
||||
[addNode, dispatch]
|
||||
[addNode, dispatch, pendingConnection, store, templates]
|
||||
);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
dispatch(addNodePopoverClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(addNodePopoverOpened());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback(
|
||||
(e) => {
|
||||
e.preventDefault();
|
||||
onOpen();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
|
||||
e.preventDefault();
|
||||
openAddNodePopover();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
closeAddNodePopover();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
|
||||
useHotkeys(['escape'], handleHotkeyClose);
|
||||
const onKeyDown: KeyboardEventHandler = useCallback(
|
||||
(e) => {
|
||||
if (e.key === 'Escape') {
|
||||
onClose();
|
||||
}
|
||||
},
|
||||
[onClose]
|
||||
);
|
||||
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
|
||||
if (e.key === 'Escape') {
|
||||
closeAddNodePopover();
|
||||
}
|
||||
}, []);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onClose={closeAddNodePopover}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
@@ -206,7 +214,7 @@ const AddNodePopover = () => {
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
filterOption={filterOption}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
onMenuClose={closeAddNodePopover}
|
||||
onKeyDown={onKeyDown}
|
||||
inputRef={inputRef}
|
||||
closeMenuOnSelect={false}
|
||||
|
||||
@@ -1,34 +1,33 @@
|
||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
connectionEnded,
|
||||
$cursorPos,
|
||||
$isAddNodePopoverOpen,
|
||||
$isUpdatingEdge,
|
||||
$pendingConnection,
|
||||
$viewport,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeAdded,
|
||||
edgeChangeStarted,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
redo,
|
||||
selectedAll,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectionCopied,
|
||||
selectionPasted,
|
||||
viewportChanged,
|
||||
undo,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type {
|
||||
OnConnect,
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnEdgeUpdateFunc,
|
||||
@@ -36,12 +35,11 @@ import type {
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
OnSelectionChangeFunc,
|
||||
ProOptions,
|
||||
ReactFlowProps,
|
||||
XYPosition,
|
||||
ReactFlowState,
|
||||
} from 'reactflow';
|
||||
import { Background, ReactFlow } from 'reactflow';
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
|
||||
|
||||
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
|
||||
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
|
||||
@@ -68,17 +66,23 @@ const proOptions: ProOptions = { hideAttribution: true };
|
||||
|
||||
const snapGrid: [number, number] = [25, 25];
|
||||
|
||||
const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection;
|
||||
|
||||
export const Flow = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((s) => s.nodes.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.edges);
|
||||
const viewport = useAppSelector((s) => s.nodes.viewport);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid);
|
||||
const selectionMode = useAppSelector((s) => s.nodes.selectionMode);
|
||||
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||
const viewport = useStore($viewport);
|
||||
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
|
||||
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
|
||||
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const cursorPosition = useRef<XYPosition | null>(null);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
const cancelConnection = useReactFlowStore(selectCancelConnection);
|
||||
useWorkflowWatcher();
|
||||
useSyncExecutionState();
|
||||
const [borderRadius] = useToken('radii', ['base']);
|
||||
|
||||
const flowStyles = useMemo<CSSProperties>(
|
||||
@@ -102,32 +106,6 @@ export const Flow = memo(() => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectStart: OnConnectStart = useCallback(
|
||||
(event, params) => {
|
||||
dispatch(connectionStarted(params));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(connection) => {
|
||||
dispatch(connectionMade(connection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||
if (!cursorPosition.current) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
connectionEnded({
|
||||
cursorPosition: cursorPosition.current,
|
||||
mouseOverNodeId: $mouseOverNode.get(),
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
(edges) => {
|
||||
dispatch(edgesDeleted(edges));
|
||||
@@ -142,20 +120,9 @@ export const Flow = memo(() => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
|
||||
({ nodes, edges }) => {
|
||||
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
|
||||
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleMoveEnd: OnMoveEnd = useCallback(
|
||||
(e, viewport) => {
|
||||
dispatch(viewportChanged(viewport));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
|
||||
$viewport.set(viewport);
|
||||
}, []);
|
||||
|
||||
const { onCloseGlobal } = useGlobalMenuClose();
|
||||
const handlePaneClick = useCallback(() => {
|
||||
@@ -169,11 +136,12 @@ export const Flow = memo(() => {
|
||||
|
||||
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
|
||||
if (flowWrapper.current?.getBoundingClientRect()) {
|
||||
cursorPosition.current =
|
||||
$cursorPos.set(
|
||||
$flow.get()?.screenToFlowPosition({
|
||||
x: event.clientX,
|
||||
y: event.clientY,
|
||||
}) ?? null;
|
||||
}) ?? null
|
||||
);
|
||||
}
|
||||
}, []);
|
||||
|
||||
@@ -195,19 +163,18 @@ export const Flow = memo(() => {
|
||||
|
||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
$isUpdatingEdge.set(true);
|
||||
// update mouse event
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
dispatch(edgeChangeStarted());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||
(_oldEdge, newConnection) => {
|
||||
// instead of updating the edge (we deleted it earlier), we instead create
|
||||
// a new one.
|
||||
// Because we deleted the edge when the update started, we must create a new edge from the connection
|
||||
dispatch(connectionMade(newConnection));
|
||||
},
|
||||
[dispatch]
|
||||
@@ -215,8 +182,10 @@ export const Flow = memo(() => {
|
||||
|
||||
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
// Handle the case where user begins a drag but didn't move the cursor -
|
||||
// bc we deleted the edge, we need to add it back
|
||||
$isUpdatingEdge.set(false);
|
||||
$pendingConnection.set(null);
|
||||
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
|
||||
// the edge update - we need to add it back
|
||||
if (
|
||||
// ignore touch events
|
||||
!('touches' in e) &&
|
||||
@@ -233,23 +202,64 @@ export const Flow = memo(() => {
|
||||
|
||||
// #endregion
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
});
|
||||
const { copySelection, pasteSelection } = useCopyPaste();
|
||||
|
||||
useHotkeys(['Ctrl+a', 'Meta+a'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectedAll());
|
||||
});
|
||||
const onCopyHotkey = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
e.preventDefault();
|
||||
copySelection();
|
||||
},
|
||||
[copySelection]
|
||||
);
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], onCopyHotkey);
|
||||
|
||||
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
|
||||
if (!cursorPosition.current) {
|
||||
return;
|
||||
const onSelectAllHotkey = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectedAll());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
|
||||
|
||||
const onPasteHotkey = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
e.preventDefault();
|
||||
pasteSelection();
|
||||
},
|
||||
[pasteSelection]
|
||||
);
|
||||
useHotkeys(['Ctrl+v', 'Meta+v'], onPasteHotkey);
|
||||
|
||||
const onPasteWithEdgesToNodesHotkey = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
e.preventDefault();
|
||||
pasteSelection(true);
|
||||
},
|
||||
[pasteSelection]
|
||||
);
|
||||
useHotkeys(['Ctrl+shift+v', 'Meta+shift+v'], onPasteWithEdgesToNodesHotkey);
|
||||
|
||||
const onUndoHotkey = useCallback(() => {
|
||||
if (mayUndo) {
|
||||
dispatch(undo());
|
||||
}
|
||||
e.preventDefault();
|
||||
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
|
||||
});
|
||||
}, [dispatch, mayUndo]);
|
||||
useHotkeys(['meta+z', 'ctrl+z'], onUndoHotkey);
|
||||
|
||||
const onRedoHotkey = useCallback(() => {
|
||||
if (mayRedo) {
|
||||
dispatch(redo());
|
||||
}
|
||||
}, [dispatch, mayRedo]);
|
||||
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
|
||||
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
$pendingConnection.set(null);
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
cancelConnection();
|
||||
}, [cancelConnection]);
|
||||
useHotkeys('esc', onEscapeHotkey);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
@@ -274,7 +284,6 @@ export const Flow = memo(() => {
|
||||
onConnectEnd={onConnectEnd}
|
||||
onMoveEnd={handleMoveEnd}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
isValidConnection={isValidConnection}
|
||||
minZoom={0.1}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
@@ -285,6 +294,7 @@ export const Flow = memo(() => {
|
||||
onPaneClick={handlePaneClick}
|
||||
deleteKeyCode={DELETE_KEYS}
|
||||
selectionMode={selectionMode}
|
||||
elevateEdgesOnSelect
|
||||
>
|
||||
<Background />
|
||||
</ReactFlow>
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $pendingConnection } from 'features/nodes/store/nodesSlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { ConnectionLineComponentProps } from 'reactflow';
|
||||
import { getBezierPath } from 'reactflow';
|
||||
|
||||
const selectStroke = createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.shouldColorEdges ? getFieldColor(nodes.connectionStartFieldType) : colorTokenToCssVar('base.500')
|
||||
);
|
||||
|
||||
const selectClassName = createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.shouldAnimateEdges ? 'react-flow__custom_connection-path animated' : 'react-flow__custom_connection-path'
|
||||
);
|
||||
|
||||
const pathStyles: CSSProperties = { opacity: 0.8 };
|
||||
|
||||
const CustomConnectionLine = ({ fromX, fromY, fromPosition, toX, toY, toPosition }: ConnectionLineComponentProps) => {
|
||||
const stroke = useAppSelector(selectStroke);
|
||||
const className = useAppSelector(selectClassName);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const shouldColorEdges = useAppSelector((state) => state.workflowSettings.shouldColorEdges);
|
||||
const shouldAnimateEdges = useAppSelector((state) => state.workflowSettings.shouldAnimateEdges);
|
||||
const stroke = useMemo(() => {
|
||||
if (shouldColorEdges && pendingConnection) {
|
||||
return getFieldColor(pendingConnection.fieldTemplate.type);
|
||||
} else {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
}, [pendingConnection, shouldColorEdges]);
|
||||
const className = useMemo(() => {
|
||||
if (shouldAnimateEdges) {
|
||||
return 'react-flow__custom_connection-path animated';
|
||||
} else {
|
||||
return 'react-flow__custom_connection-path';
|
||||
}
|
||||
}, [shouldAnimateEdges]);
|
||||
|
||||
const pathParams = {
|
||||
sourceX: fromX,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Badge, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { EdgeProps } from 'reactflow';
|
||||
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||
@@ -22,9 +24,10 @@ const InvocationCollapsedEdge = ({
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps<{ count: number }>) => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
|
||||
[selected, source, sourceHandleId, target, targetHandleId]
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
|
||||
[templates, selected, source, sourceHandleId, target, targetHandleId]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { EdgeProps } from 'reactflow';
|
||||
@@ -21,13 +23,14 @@ const InvocationDefaultEdge = ({
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps) => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
|
||||
[source, sourceHandleId, target, targetHandleId, selected]
|
||||
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
|
||||
[templates, source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
@@ -14,6 +15,7 @@ const defaultReturnValue = {
|
||||
};
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
templates: Templates,
|
||||
source: string,
|
||||
sourceHandleId: string | null | undefined,
|
||||
target: string,
|
||||
@@ -22,7 +24,8 @@ export const makeEdgeSelector = (
|
||||
) =>
|
||||
createMemoizedSelector(
|
||||
selectNodesSlice,
|
||||
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
|
||||
selectWorkflowSettingsSlice,
|
||||
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
@@ -33,19 +36,20 @@ export const makeEdgeSelector = (
|
||||
return defaultReturnValue;
|
||||
}
|
||||
|
||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||
const sourceNodeTemplate = templates[sourceNode.data.type];
|
||||
const targetNodeTemplate = templates[targetNode.data.type];
|
||||
|
||||
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
|
||||
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
|
||||
const stroke =
|
||||
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
label,
|
||||
};
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
@@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
|
||||
const selectNodeExecutionState = useMemo(
|
||||
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
|
||||
const nodeExecutionState = useExecutionState(nodeId);
|
||||
|
||||
if (!nodeExecutionState) {
|
||||
return null;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
@@ -11,13 +11,13 @@ import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback';
|
||||
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
|
||||
const { data, selected } = props;
|
||||
const { id: nodeId, type, isOpen, label } = data;
|
||||
const templates = useStore($templates);
|
||||
const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]);
|
||||
const nodeExists = useAppSelector((s) => Boolean(s.nodes.present.nodes.find((n) => n.id === nodeId)));
|
||||
|
||||
const hasTemplateSelector = useMemo(
|
||||
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
|
||||
[type]
|
||||
);
|
||||
|
||||
const hasTemplate = useAppSelector(hasTemplateSelector);
|
||||
if (!nodeExists) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!hasTemplate) {
|
||||
return (
|
||||
|
||||
@@ -37,7 +37,8 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
const [localTitle, setLocalTitle] = useState(label || fieldTemplateTitle || t('nodes.unknownField'));
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
async (newTitleRaw: string) => {
|
||||
const newTitle = newTitleRaw.trim();
|
||||
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
||||
return;
|
||||
}
|
||||
@@ -57,22 +58,22 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
}, [label, fieldTemplateTitle, t]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
<Editable
|
||||
value={localTitle}
|
||||
onChange={handleChange}
|
||||
onSubmit={handleSubmit}
|
||||
as={Flex}
|
||||
ref={ref}
|
||||
position="relative"
|
||||
overflow="hidden"
|
||||
alignItems="center"
|
||||
justifyContent="flex-start"
|
||||
gap={1}
|
||||
w="full"
|
||||
>
|
||||
<Editable
|
||||
value={localTitle}
|
||||
onChange={handleChange}
|
||||
onSubmit={handleSubmit}
|
||||
as={Flex}
|
||||
ref={ref}
|
||||
position="relative"
|
||||
overflow="hidden"
|
||||
alignItems="center"
|
||||
justifyContent="flex-start"
|
||||
gap={1}
|
||||
w="full"
|
||||
<Tooltip
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<EditablePreview
|
||||
fontWeight="semibold"
|
||||
@@ -80,10 +81,10 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
noOfLines={1}
|
||||
color={isMissingInput ? 'error.300' : 'base.300'}
|
||||
/>
|
||||
<EditableInput className="nodrag" sx={editableInputStyles} />
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
</Tooltip>
|
||||
</Tooltip>
|
||||
<EditableInput className="nodrag" sx={editableInputStyles} />
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -127,7 +128,15 @@ const EditableControls = memo(() => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex onClick={handleClick} position="absolute" w="full" h="full" top={0} insetInlineStart={0} cursor="text" />
|
||||
<Flex
|
||||
onClick={handleClick}
|
||||
position="absolute"
|
||||
w="min-content"
|
||||
h="full"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
cursor="text"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldTemplate.input === 'connection') {
|
||||
if (fieldTemplate.input === 'connection' || isConnected) {
|
||||
return (
|
||||
<InputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} px={2}>
|
||||
@@ -95,7 +95,15 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
|
||||
return (
|
||||
<InputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} orientation="vertical" px={2}>
|
||||
<FormControl
|
||||
isInvalid={isMissingInput}
|
||||
isDisabled={isConnected}
|
||||
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
|
||||
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
|
||||
pointerEvents={isConnected ? 'none' : 'auto'}
|
||||
orientation="vertical"
|
||||
px={2}
|
||||
>
|
||||
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Flex>
|
||||
<EditableFieldTitle
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import type { MouseEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
@@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||
|
||||
const selectIsInProgress = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectNodesSlice,
|
||||
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const isInProgress = useAppSelector(selectIsInProgress);
|
||||
const executionState = useExecutionState(nodeId);
|
||||
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||
|
||||
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
|
||||
'nodeInProgress',
|
||||
@@ -39,7 +31,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const opacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
||||
const opacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
|
||||
const { onCloseGlobal } = useGlobalMenuClose();
|
||||
|
||||
const handleClick = useCallback(
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { CompositeSlider, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { nodeOpacityChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { nodeOpacityChanged } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const NodeOpacitySlider = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity);
|
||||
const nodeOpacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
// shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { shouldShowMinimapPanelChanged } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -19,9 +16,9 @@ const ViewportControls = () => {
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const dispatch = useAppDispatch();
|
||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
||||
// (s) => s.nodes.shouldShowFieldTypeLegend
|
||||
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
||||
// );
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
||||
|
||||
const handleClickedZoomIn = useCallback(() => {
|
||||
zoomIn();
|
||||
|
||||
@@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const MinimapPanel = () => {
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
||||
|
||||
return (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
const AddNodeButton = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const handleOpenAddNodePopover = useCallback(() => {
|
||||
dispatch(addNodePopoverOpened());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
tooltip={t('nodes.addNodeToolTip')}
|
||||
aria-label={t('nodes.addNode')}
|
||||
icon={<PiPlusBold />}
|
||||
onClick={handleOpenAddNodePopover}
|
||||
onClick={openAddNodePopover}
|
||||
pointerEvents="auto"
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -21,13 +21,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import ReloadNodeTemplatesButton from 'features/nodes/components/flow/panels/TopRightPanel/ReloadSchemaButton';
|
||||
import {
|
||||
selectionModeChanged,
|
||||
selectNodesSlice,
|
||||
selectWorkflowSettingsSlice,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowEdgeLabelsChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
} from 'features/nodes/store/workflowSettingsSlice';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -35,7 +35,7 @@ import { SelectionMode } from 'reactflow';
|
||||
|
||||
const formLabelProps: FormLabelProps = { flexGrow: 1 };
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const selector = createMemoizedSelector(selectWorkflowSettingsSlice, (workflowSettings) => {
|
||||
const {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
@@ -43,7 +43,7 @@ const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
shouldColorEdges,
|
||||
shouldShowEdgeLabels,
|
||||
selectionMode,
|
||||
} = nodes;
|
||||
} = workflowSettings;
|
||||
return {
|
||||
shouldAnimateEdges,
|
||||
shouldValidateGraph,
|
||||
|
||||
@@ -3,27 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
return {
|
||||
data: lastSelectedNode?.data,
|
||||
};
|
||||
});
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => selectLastSelectedNode(nodes));
|
||||
|
||||
const InspectorDataTab = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data } = useAppSelector(selector);
|
||||
const lastSelectedNode = useAppSelector(selector);
|
||||
|
||||
if (!data) {
|
||||
if (!lastSelectedNode) {
|
||||
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
|
||||
}
|
||||
|
||||
return <DataViewer data={data} label="Node Data" />;
|
||||
return <DataViewer data={lastSelectedNode.data} label="Node Data" />;
|
||||
};
|
||||
|
||||
export default memo(InspectorDataTab);
|
||||
|
||||
@@ -1,36 +1,39 @@
|
||||
import { Box, Flex, FormControl, FormLabel, HStack, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import EditableNodeTitle from './details/EditableNodeTitle';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
nodeId: lastSelectedNode.data.id,
|
||||
nodeVersion: lastSelectedNode.data.version,
|
||||
templateTitle: lastSelectedNodeTemplate.title,
|
||||
};
|
||||
});
|
||||
|
||||
const InspectorDetailsTab = () => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNode = selectLastSelectedNode(nodes);
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
nodeId: lastSelectedNode.data.id,
|
||||
nodeVersion: lastSelectedNode.data.version,
|
||||
templateTitle: lastSelectedNodeTemplate.title,
|
||||
};
|
||||
}),
|
||||
[templates]
|
||||
);
|
||||
const data = useAppSelector(selector);
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
||||
@@ -1,46 +1,49 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ImageOutput } from 'services/api/types';
|
||||
import type { AnyResult } from 'services/events/types';
|
||||
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
outputs: nes.outputs,
|
||||
outputType: lastSelectedNodeTemplate.outputType,
|
||||
};
|
||||
});
|
||||
|
||||
const InspectorOutputsTab = () => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNode = selectLastSelectedNode(nodes);
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
nodeId: lastSelectedNode.id,
|
||||
outputType: lastSelectedNodeTemplate.outputType,
|
||||
};
|
||||
}),
|
||||
[templates]
|
||||
);
|
||||
const data = useAppSelector(selector);
|
||||
const nes = useExecutionState(data?.nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!data) {
|
||||
if (!data || !nes) {
|
||||
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
|
||||
}
|
||||
|
||||
if (data.outputs.length === 0) {
|
||||
if (nes.outputs.length === 0) {
|
||||
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
|
||||
}
|
||||
|
||||
@@ -49,11 +52,11 @@ const InspectorOutputsTab = () => {
|
||||
<ScrollableContent>
|
||||
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
||||
{data.outputType === 'image_output' ? (
|
||||
data.outputs.map((result, i) => (
|
||||
nes.outputs.map((result, i) => (
|
||||
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
|
||||
))
|
||||
) : (
|
||||
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
|
||||
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
|
||||
@@ -1,25 +1,26 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { memo } from 'react';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
return {
|
||||
template: lastSelectedNodeTemplate,
|
||||
};
|
||||
});
|
||||
|
||||
const NodeTemplateInspector = () => {
|
||||
const { template } = useAppSelector(selector);
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNode = selectLastSelectedNode(nodes);
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
return lastSelectedNodeTemplate;
|
||||
}),
|
||||
[templates]
|
||||
);
|
||||
const template = useAppSelector(selector);
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!template) {
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const selectConnectedFieldNames = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}),
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
|
||||
nodesSlice.edges
|
||||
.filter((e) => e.target === nodeId)
|
||||
.map((e) => e.targetHandle)
|
||||
.filter(Boolean)
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
|
||||
|
||||
const fieldNames = useMemo(() => {
|
||||
const fields = map(template.inputs).filter((field) => {
|
||||
if (connectedFieldNames.includes(field.name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
});
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}, [connectedFieldNames, template.inputs]);
|
||||
|
||||
const fieldNames = useAppSelector(selector);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
|
||||
@@ -8,8 +9,7 @@ import { useCallback } from 'react';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
export const useBuildNode = () => {
|
||||
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
||||
|
||||
const templates = useStore($templates);
|
||||
const flow = useReactFlow();
|
||||
|
||||
return useCallback(
|
||||
@@ -41,10 +41,10 @@ export const useBuildNode = () => {
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = nodeTemplates[type] as InvocationTemplate;
|
||||
const template = templates[type] as InvocationTemplate;
|
||||
|
||||
return buildInvocationNode(position, template);
|
||||
},
|
||||
[nodeTemplates, flow]
|
||||
[templates, flow]
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import {
|
||||
$isAddNodePopoverOpen,
|
||||
$isUpdatingEdge,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
connectionMade,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useConnection = () => {
|
||||
const store = useAppStore();
|
||||
const templates = useStore($templates);
|
||||
|
||||
const onConnectStart = useCallback<OnConnectStart>(
|
||||
(event, params) => {
|
||||
const nodes = store.getState().nodes.present.nodes;
|
||||
const { nodeId, handleId, handleType } = params;
|
||||
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
|
||||
const template = templates[node.data.type];
|
||||
assert(template, `Template not found for node type: ${node.data.type}`);
|
||||
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
|
||||
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
|
||||
$pendingConnection.set({
|
||||
node,
|
||||
template,
|
||||
fieldTemplate,
|
||||
});
|
||||
},
|
||||
[store, templates]
|
||||
);
|
||||
const onConnect = useCallback<OnConnect>(
|
||||
(connection) => {
|
||||
const { dispatch } = store;
|
||||
dispatch(connectionMade(connection));
|
||||
$pendingConnection.set(null);
|
||||
},
|
||||
[store]
|
||||
);
|
||||
const onConnectEnd = useCallback<OnConnectEnd>(() => {
|
||||
const { dispatch } = store;
|
||||
const pendingConnection = $pendingConnection.get();
|
||||
const isUpdatingEdge = $isUpdatingEdge.get();
|
||||
const mouseOverNodeId = $mouseOverNode.get();
|
||||
|
||||
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
|
||||
// update logic can finish up
|
||||
if (isUpdatingEdge && !mouseOverNodeId) {
|
||||
$pendingConnection.set(null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!pendingConnection) {
|
||||
return;
|
||||
}
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
if (mouseOverNodeId) {
|
||||
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
|
||||
if (!candidateNode) {
|
||||
// The mouse is over a non-invocation node - bail
|
||||
return;
|
||||
}
|
||||
const candidateTemplate = templates[candidateNode.data.type];
|
||||
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
|
||||
const connection = getFirstValidConnection(
|
||||
templates,
|
||||
nodes,
|
||||
edges,
|
||||
pendingConnection,
|
||||
candidateNode,
|
||||
candidateTemplate
|
||||
);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
} else {
|
||||
// The mouse is not over a node - we should open the add node popover
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
}
|
||||
}, [store, templates]);
|
||||
|
||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||
return api;
|
||||
};
|
||||
@@ -1,34 +1,39 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const selectConnectedFieldNames = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
// get the visible fields
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}),
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
|
||||
nodesSlice.edges
|
||||
.filter((e) => e.target === nodeId)
|
||||
.map((e) => e.targetHandle)
|
||||
.filter(Boolean)
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
|
||||
|
||||
const fieldNames = useAppSelector(selector);
|
||||
const fieldNames = useMemo(() => {
|
||||
// get the visible fields
|
||||
const fields = map(template.inputs).filter((field) => {
|
||||
if (connectedFieldNames.includes(field.name)) {
|
||||
return true;
|
||||
}
|
||||
return (
|
||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
);
|
||||
});
|
||||
|
||||
return getSortedFilteredFieldNames(fields);
|
||||
}, [connectedFieldNames, template.inputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useFieldType } from './useFieldType.ts';
|
||||
|
||||
const selectIsConnectionInProgress = createSelector(
|
||||
selectNodesSlice,
|
||||
(nodes) => nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null
|
||||
);
|
||||
|
||||
type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
@@ -18,6 +14,8 @@ type UseConnectionStateProps = {
|
||||
};
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const fieldType = useFieldType(nodeId, fieldName, kind);
|
||||
|
||||
const selectIsConnected = useMemo(
|
||||
@@ -36,25 +34,30 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const selectIsConnectionStartField = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) =>
|
||||
Boolean(
|
||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||
nodes.connectionStartParams?.handleId === fieldName &&
|
||||
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
|
||||
)
|
||||
makeConnectionErrorSelector(
|
||||
templates,
|
||||
pendingConnection,
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind === 'inputs' ? 'target' : 'source',
|
||||
fieldType
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(selectIsConnected);
|
||||
const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
|
||||
const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
|
||||
const isConnectionInProgress = useMemo(() => Boolean(pendingConnection), [pendingConnection]);
|
||||
const isConnectionStartField = useMemo(() => {
|
||||
if (!pendingConnection) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
pendingConnection.node.id === nodeId &&
|
||||
pendingConnection.fieldTemplate.name === fieldName &&
|
||||
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
|
||||
);
|
||||
}, [fieldName, kind, nodeId, pendingConnection]);
|
||||
const connectionError = useAppSelector(selectConnectionError);
|
||||
|
||||
const shouldDim = useMemo(
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
$copiedEdges,
|
||||
$copiedNodes,
|
||||
$cursorPos,
|
||||
$edgesToCopiedNodes,
|
||||
selectionPasted,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
|
||||
import { isEqual, uniqWith } from 'lodash-es';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const copySelection = () => {
|
||||
// Use the imperative API here so we don't have to pass the whole slice around
|
||||
const { getState } = getStore();
|
||||
const { nodes, edges } = selectNodesSlice(getState());
|
||||
const selectedNodes = nodes.filter((node) => node.selected);
|
||||
const selectedEdges = edges.filter((edge) => edge.selected);
|
||||
const edgesToSelectedNodes = edges.filter((edge) => selectedNodes.some((node) => node.id === edge.target));
|
||||
$copiedNodes.set(selectedNodes);
|
||||
$copiedEdges.set(selectedEdges);
|
||||
$edgesToCopiedNodes.set(edgesToSelectedNodes);
|
||||
};
|
||||
|
||||
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
const { getState, dispatch } = getStore();
|
||||
const currentNodes = selectNodesSlice(getState()).nodes;
|
||||
const cursorPos = $cursorPos.get();
|
||||
|
||||
const copiedNodes = deepClone($copiedNodes.get());
|
||||
let copiedEdges = deepClone($copiedEdges.get());
|
||||
|
||||
if (withEdgesToCopiedNodes) {
|
||||
const edgesToCopiedNodes = deepClone($edgesToCopiedNodes.get());
|
||||
copiedEdges = uniqWith([...copiedEdges, ...edgesToCopiedNodes], isEqual);
|
||||
}
|
||||
|
||||
// Calculate an offset to reposition nodes to surround the cursor position, maintaining relative positioning
|
||||
const xCoords = copiedNodes.map((node) => node.position.x);
|
||||
const yCoords = copiedNodes.map((node) => node.position.y);
|
||||
const minX = Math.min(...xCoords);
|
||||
const minY = Math.min(...yCoords);
|
||||
const offsetX = cursorPos ? cursorPos.x - minX : 50;
|
||||
const offsetY = cursorPos ? cursorPos.y - minY : 50;
|
||||
|
||||
copiedNodes.forEach((node) => {
|
||||
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
|
||||
node.position.x = x;
|
||||
node.position.y = y;
|
||||
// Pasted nodes are selected
|
||||
node.selected = true;
|
||||
// Also give em a fresh id
|
||||
const id = uuidv4();
|
||||
// Update the edges to point to the new node id
|
||||
for (const edge of copiedEdges) {
|
||||
if (edge.source === node.id) {
|
||||
edge.source = id;
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
if (edge.target === node.id) {
|
||||
edge.target = id;
|
||||
edge.id = edge.id.replace(node.data.id, id);
|
||||
}
|
||||
}
|
||||
node.id = id;
|
||||
node.data.id = id;
|
||||
});
|
||||
|
||||
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
|
||||
};
|
||||
|
||||
const api = { copySelection, pasteSelection };
|
||||
|
||||
export const useCopyPaste = () => {
|
||||
return api;
|
||||
};
|
||||
@@ -0,0 +1,56 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeExecutionStates } from 'features/nodes/store/types';
|
||||
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
|
||||
export const $nodeExecutionStates = map<NodeExecutionStates>({});
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
export const useExecutionState = (nodeId?: string) => {
|
||||
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
|
||||
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
|
||||
return executionState;
|
||||
};
|
||||
|
||||
const removeNodeExecutionState = (nodeId: string) => {
|
||||
$nodeExecutionStates.setKey(nodeId, undefined);
|
||||
};
|
||||
|
||||
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
|
||||
const state = $nodeExecutionStates.get()[nodeId];
|
||||
if (!state) {
|
||||
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
|
||||
} else {
|
||||
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
|
||||
}
|
||||
};
|
||||
|
||||
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
|
||||
|
||||
export const useSyncExecutionState = () => {
|
||||
const nodeIds = useAppSelector(selectNodeIds);
|
||||
useEffect(() => {
|
||||
const nodeExecutionStates = $nodeExecutionStates.get();
|
||||
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
|
||||
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
|
||||
for (const id of nodeIdsToAdd) {
|
||||
upsertExecutionState(id);
|
||||
}
|
||||
for (const id of nodeIdsToRemove) {
|
||||
removeNodeExecutionState(id);
|
||||
}
|
||||
}, [nodeIds]);
|
||||
};
|
||||
@@ -1,20 +1,9 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
@@ -1,20 +1,9 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => template.outputs[fieldName] ?? null, [fieldName, template.outputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
@@ -1,27 +1,36 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
kind: 'inputs' | 'outputs'
|
||||
): FieldInputTemplate | FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
): FieldInputTemplate | FieldOutputTemplate => {
|
||||
const templates = useStore($templates);
|
||||
const selectNodeType = useMemo(
|
||||
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
const nodeType = useAppSelector(selectNodeType);
|
||||
const fieldTemplate = useMemo(() => {
|
||||
const template = templates[nodeType];
|
||||
assert(template, `Template for node type ${nodeType} not found`);
|
||||
if (kind === 'inputs') {
|
||||
const fieldTemplate = template.inputs[fieldName];
|
||||
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
|
||||
return fieldTemplate;
|
||||
} else {
|
||||
const fieldTemplate = template.outputs[fieldName];
|
||||
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
|
||||
return fieldTemplate;
|
||||
}
|
||||
}, [fieldName, kind, nodeType, templates]);
|
||||
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
@@ -1,22 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplateTitle = useAppSelector(selector);
|
||||
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const fieldTemplateTitle = useMemo(() => fieldTemplate.title, [fieldTemplate]);
|
||||
return fieldTemplateTitle;
|
||||
};
|
||||
|
||||
@@ -1,23 +1,9 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
const fieldType = useAppSelector(selector);
|
||||
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
|
||||
const selector = createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = nodes.templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
return getNeedsUpdate(node, template);
|
||||
})
|
||||
);
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useGetNodesNeedUpdate = () => {
|
||||
const getNeedsUpdate = useAppSelector(selector);
|
||||
return getNeedsUpdate;
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
return getNeedsUpdate(node.data, template);
|
||||
})
|
||||
),
|
||||
[templates]
|
||||
);
|
||||
const needsUpdate = useAppSelector(selector);
|
||||
return needsUpdate;
|
||||
};
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useHasImageOutput = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const hasImageOutput = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
return some(
|
||||
template?.outputs,
|
||||
(output) =>
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
template?.type !== 'image'
|
||||
);
|
||||
}),
|
||||
[nodeId]
|
||||
some(
|
||||
template?.outputs,
|
||||
(output) =>
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
template?.type !== 'image'
|
||||
),
|
||||
[template]
|
||||
);
|
||||
|
||||
const hasImageOutput = useAppSelector(selector);
|
||||
return hasImageOutput;
|
||||
};
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
// TODO: enable this at some point
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
|
||||
@@ -13,7 +17,8 @@ import type { Connection, Node } from 'reactflow';
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const store = useAppStore();
|
||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
|
||||
const templates = useStore($templates);
|
||||
const shouldValidateGraph = useAppSelector((s) => s.workflowSettings.shouldValidateGraph);
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
// Connection must have valid targets
|
||||
@@ -27,7 +32,7 @@ export const useIsValidConnection = () => {
|
||||
}
|
||||
|
||||
const state = store.getState();
|
||||
const { nodes, edges, templates } = state.nodes;
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
@@ -40,6 +45,10 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shouldValidateGraph) {
|
||||
// manual override!
|
||||
return true;
|
||||
@@ -57,6 +66,14 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
return isEqual(sourceFieldTemplate.type, collectItemType);
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
@@ -76,7 +93,7 @@ export const useIsValidConnection = () => {
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
[shouldValidateGraph, store]
|
||||
[shouldValidateGraph, templates, store]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
|
||||
@@ -1,19 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { Classification } from 'features/nodes/types/common';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeClassification = (nodeId: string): Classification | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
return title;
|
||||
export const useNodeClassification = (nodeId: string): Classification => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const classification = useMemo(() => template.classification, [template]);
|
||||
return classification;
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeData = (nodeId: string): InvocationNodeData | null => {
|
||||
export const useNodeData = (nodeId: string): InvocationNodeData => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
|
||||
@@ -1,25 +1,11 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = selectInvocationNode(nodes, nodeId);
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!node || !template) {
|
||||
return false;
|
||||
}
|
||||
return getNeedsUpdate(node, template);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const needsUpdate = useAppSelector(selector);
|
||||
|
||||
const data = useNodeData(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const needsUpdate = useMemo(() => getNeedsUpdate(data, template), [data, template]);
|
||||
return needsUpdate;
|
||||
};
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId);
|
||||
}),
|
||||
export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
|
||||
const templates = useStore($templates);
|
||||
const selectNodeType = useMemo(
|
||||
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeTemplate = useAppSelector(selector);
|
||||
|
||||
return nodeTemplate;
|
||||
const nodeType = useAppSelector(selectNodeType);
|
||||
const template = useMemo(() => {
|
||||
const t = templates[nodeType];
|
||||
assert(t, `Template for node type ${nodeType} not found`);
|
||||
return t;
|
||||
}, [nodeType, templates]);
|
||||
return template;
|
||||
};
|
||||
|
||||
@@ -1,18 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const title = useMemo(() => template.title, [template.title]);
|
||||
return title;
|
||||
};
|
||||
|
||||
@@ -1,26 +1,10 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useOutputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return getSortedFilteredFieldNames(map(template.outputs));
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const fieldNames = useAppSelector(selector);
|
||||
export const useOutputFieldNames = (nodeId: string): string[] => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldNames = useMemo(() => getSortedFilteredFieldNames(map(template.outputs)), [template.outputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
@@ -43,76 +42,21 @@ import {
|
||||
zT2IAdapterModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type {
|
||||
Connection,
|
||||
Edge,
|
||||
EdgeChange,
|
||||
EdgeRemoveChange,
|
||||
Node,
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
applyNodeChanges,
|
||||
getConnectedEdges,
|
||||
getIncomers,
|
||||
getOutgoers,
|
||||
SelectionMode,
|
||||
} from 'reactflow';
|
||||
import {
|
||||
socketGeneratorProgress,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { NodesState } from './types';
|
||||
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
|
||||
import type { NodesState, PendingConnection, Templates } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
templates: {},
|
||||
connectionStartParams: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
modifyingEdge: false,
|
||||
addNewNodePosition: null,
|
||||
shouldShowMinimapPanel: true,
|
||||
shouldValidateGraph: true,
|
||||
shouldAnimateEdges: true,
|
||||
shouldSnapToGrid: false,
|
||||
shouldColorEdges: true,
|
||||
shouldShowEdgeLabels: false,
|
||||
isAddNodePopoverOpen: false,
|
||||
nodeOpacity: 1,
|
||||
selectedNodes: [],
|
||||
selectedEdges: [],
|
||||
nodeExecutionStates: {},
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
nodesToCopy: [],
|
||||
edgesToCopy: [],
|
||||
selectionMode: SelectionMode.Partial,
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
@@ -154,12 +98,12 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
state.nodes[nodeIndex] = action.payload.node;
|
||||
},
|
||||
nodeAdded: (state, action: PayloadAction<AnyNode>) => {
|
||||
const node = action.payload;
|
||||
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
|
||||
const { node, cursorPos } = action.payload;
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
state.addNewNodePosition?.x ?? node.position.x,
|
||||
state.addNewNodePosition?.y ?? node.position.y
|
||||
cursorPos?.x ?? node.position.x,
|
||||
cursorPos?.y ?? node.position.y
|
||||
);
|
||||
node.position = position;
|
||||
node.selected = true;
|
||||
@@ -175,40 +119,6 @@ export const nodesSlice = createSlice({
|
||||
);
|
||||
|
||||
state.nodes.push(node);
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.nodeExecutionStates[node.id] = {
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
|
||||
if (state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
node,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
edgeChangeStarted: (state) => {
|
||||
state.modifyingEdge = true;
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
@@ -216,71 +126,8 @@ export const nodesSlice = createSlice({
|
||||
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
state.connectionMade = state.modifyingEdge;
|
||||
const { nodeId, handleId, handleType } = action.payload;
|
||||
if (!nodeId || !handleId) {
|
||||
return;
|
||||
}
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const template = state.templates[node.data.type];
|
||||
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
const fieldType = state.connectionStartFieldType;
|
||||
if (!fieldType) {
|
||||
return;
|
||||
}
|
||||
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
|
||||
|
||||
state.connectionMade = true;
|
||||
},
|
||||
connectionEnded: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
cursorPosition: XYPosition;
|
||||
mouseOverNodeId: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { cursorPosition, mouseOverNodeId } = action.payload;
|
||||
if (!state.connectionMade) {
|
||||
if (mouseOverNodeId) {
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === mouseOverNodeId);
|
||||
const mouseOverNode = state.nodes?.[nodeIndex];
|
||||
if (mouseOverNode && state.connectionStartParams) {
|
||||
const { nodeId, handleId, handleType } = state.connectionStartParams;
|
||||
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
|
||||
const newConnection = findConnectionToValidHandle(
|
||||
mouseOverNode,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
state.connectionStartFieldType
|
||||
);
|
||||
if (newConnection) {
|
||||
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
|
||||
}
|
||||
}
|
||||
}
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
} else {
|
||||
state.addNewNodePosition = cursorPosition;
|
||||
state.isAddNodePopoverOpen = true;
|
||||
}
|
||||
} else {
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
}
|
||||
state.modifyingEdge = false;
|
||||
},
|
||||
fieldLabelChanged: (
|
||||
state,
|
||||
@@ -442,7 +289,6 @@ export const nodesSlice = createSlice({
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
delete state.nodeExecutionStates[node.id];
|
||||
});
|
||||
},
|
||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||
@@ -474,12 +320,6 @@ export const nodesSlice = createSlice({
|
||||
state.nodes
|
||||
);
|
||||
},
|
||||
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.selectedNodes = action.payload;
|
||||
},
|
||||
selectedEdgesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.selectedEdges = action.payload;
|
||||
},
|
||||
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStatefulFieldValue);
|
||||
},
|
||||
@@ -537,34 +377,10 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowMinimapPanel = action.payload;
|
||||
},
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
},
|
||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldValidateGraph = action.payload;
|
||||
},
|
||||
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAnimateEdges = action.payload;
|
||||
},
|
||||
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowEdgeLabels = action.payload;
|
||||
},
|
||||
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldSnapToGrid = action.payload;
|
||||
},
|
||||
shouldColorEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldColorEdges = action.payload;
|
||||
},
|
||||
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
|
||||
state.nodeOpacity = action.payload;
|
||||
},
|
||||
viewportChanged: (state, action: PayloadAction<Viewport>) => {
|
||||
state.viewport = action.payload;
|
||||
},
|
||||
selectedAll: (state) => {
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
|
||||
@@ -575,136 +391,49 @@ export const nodesSlice = createSlice({
|
||||
state.edges
|
||||
);
|
||||
},
|
||||
selectionCopied: (state) => {
|
||||
const nodesToCopy: AnyNode[] = [];
|
||||
const edgesToCopy: Edge[] = [];
|
||||
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
|
||||
const { nodes, edges } = action.payload;
|
||||
|
||||
for (const node of state.nodes) {
|
||||
if (node.selected) {
|
||||
nodesToCopy.push(deepClone(node));
|
||||
}
|
||||
}
|
||||
const nodeChanges: NodeChange[] = [];
|
||||
|
||||
for (const edge of state.edges) {
|
||||
if (edge.selected) {
|
||||
edgesToCopy.push(deepClone(edge));
|
||||
}
|
||||
}
|
||||
|
||||
state.nodesToCopy = nodesToCopy;
|
||||
state.edgesToCopy = edgesToCopy;
|
||||
|
||||
if (state.nodesToCopy.length > 0) {
|
||||
const averagePosition = { x: 0, y: 0 };
|
||||
state.nodesToCopy.forEach((e) => {
|
||||
const xOffset = 0.15 * (e.width ?? 0);
|
||||
const yOffset = 0.5 * (e.height ?? 0);
|
||||
averagePosition.x += e.position.x + xOffset;
|
||||
averagePosition.y += e.position.y + yOffset;
|
||||
// Deselect existing nodes
|
||||
state.nodes.forEach((n) => {
|
||||
nodeChanges.push({
|
||||
id: n.data.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
});
|
||||
|
||||
averagePosition.x /= state.nodesToCopy.length;
|
||||
averagePosition.y /= state.nodesToCopy.length;
|
||||
|
||||
state.nodesToCopy.forEach((e) => {
|
||||
e.position.x -= averagePosition.x;
|
||||
e.position.y -= averagePosition.y;
|
||||
});
|
||||
// Add new nodes
|
||||
nodes.forEach((n) => {
|
||||
nodeChanges.push({
|
||||
item: n,
|
||||
type: 'add',
|
||||
});
|
||||
}
|
||||
},
|
||||
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
|
||||
const { cursorPosition } = action.payload;
|
||||
const newNodes: AnyNode[] = [];
|
||||
|
||||
for (const node of state.nodesToCopy) {
|
||||
newNodes.push(deepClone(node));
|
||||
}
|
||||
|
||||
const oldNodeIds = newNodes.map((n) => n.data.id);
|
||||
|
||||
const newEdges: Edge[] = [];
|
||||
|
||||
for (const edge of state.edgesToCopy) {
|
||||
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
|
||||
newEdges.push(deepClone(edge));
|
||||
}
|
||||
}
|
||||
|
||||
newEdges.forEach((e) => (e.selected = true));
|
||||
|
||||
newNodes.forEach((node) => {
|
||||
const newNodeId = uuidv4();
|
||||
newEdges.forEach((edge) => {
|
||||
if (edge.source === node.data.id) {
|
||||
edge.source = newNodeId;
|
||||
edge.id = edge.id.replace(node.data.id, newNodeId);
|
||||
}
|
||||
if (edge.target === node.data.id) {
|
||||
edge.target = newNodeId;
|
||||
edge.id = edge.id.replace(node.data.id, newNodeId);
|
||||
}
|
||||
});
|
||||
node.selected = true;
|
||||
node.id = newNodeId;
|
||||
node.data.id = newNodeId;
|
||||
|
||||
const position = findUnoccupiedPosition(
|
||||
state.nodes,
|
||||
node.position.x + (cursorPosition?.x ?? 0),
|
||||
node.position.y + (cursorPosition?.y ?? 0)
|
||||
);
|
||||
|
||||
node.position = position;
|
||||
});
|
||||
|
||||
const nodeAdditions: NodeChange[] = newNodes.map((n) => ({
|
||||
item: n,
|
||||
type: 'add',
|
||||
}));
|
||||
const nodeSelectionChanges: NodeChange[] = state.nodes.map((n) => ({
|
||||
id: n.data.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
}));
|
||||
|
||||
const edgeAdditions: EdgeChange[] = newEdges.map((e) => ({
|
||||
item: e,
|
||||
type: 'add',
|
||||
}));
|
||||
const edgeSelectionChanges: EdgeChange[] = state.edges.map((e) => ({
|
||||
id: e.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
}));
|
||||
|
||||
state.nodes = applyNodeChanges(nodeAdditions.concat(nodeSelectionChanges), state.nodes);
|
||||
|
||||
state.edges = applyEdgeChanges(edgeAdditions.concat(edgeSelectionChanges), state.edges);
|
||||
|
||||
newNodes.forEach((node) => {
|
||||
state.nodeExecutionStates[node.id] = {
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
const edgeChanges: EdgeChange[] = [];
|
||||
// Deselect existing edges
|
||||
state.edges.forEach((e) => {
|
||||
edgeChanges.push({
|
||||
id: e.id,
|
||||
type: 'select',
|
||||
selected: false,
|
||||
});
|
||||
});
|
||||
// Add new edges
|
||||
edges.forEach((e) => {
|
||||
edgeChanges.push({
|
||||
item: e,
|
||||
type: 'add',
|
||||
});
|
||||
});
|
||||
},
|
||||
addNodePopoverOpened: (state) => {
|
||||
state.addNewNodePosition = null; //Create the node in viewport center by default
|
||||
state.isAddNodePopoverOpen = true;
|
||||
},
|
||||
addNodePopoverClosed: (state) => {
|
||||
state.isAddNodePopoverOpen = false;
|
||||
|
||||
//Make sure these get reset if we close the popover and haven't selected a node
|
||||
state.connectionStartParams = null;
|
||||
state.connectionStartFieldType = null;
|
||||
},
|
||||
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
|
||||
},
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
},
|
||||
undo: (state) => state,
|
||||
redo: (state) => state,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
@@ -720,75 +449,13 @@ export const nodesSlice = createSlice({
|
||||
edges.map((edge) => ({ item: edge, type: 'add' })),
|
||||
[]
|
||||
);
|
||||
|
||||
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
|
||||
acc[node.id] = {
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
return acc;
|
||||
}, {});
|
||||
});
|
||||
|
||||
builder.addCase(socketInvocationStarted, (state, action) => {
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketInvocationComplete, (state, action) => {
|
||||
const { source_node_id, result } = action.payload.data;
|
||||
const nes = state.nodeExecutionStates[source_node_id];
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
}
|
||||
});
|
||||
builder.addCase(socketInvocationError, (state, action) => {
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.FAILED;
|
||||
node.error = action.payload.data.error;
|
||||
node.progress = null;
|
||||
node.progressImage = null;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
node.progress = (step + 1) / total_steps;
|
||||
node.progressImage = progress_image ?? null;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = zNodeStatus.enum.PENDING;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
nes.outputs = [];
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeDeleted,
|
||||
edgeChangeStarted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
fieldValueReset,
|
||||
@@ -816,31 +483,97 @@ export const {
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
nodeOpacityChanged,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
nodeUseCacheChanged,
|
||||
notesNodeValueChanged,
|
||||
selectedAll,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectionCopied,
|
||||
selectionModeChanged,
|
||||
selectionPasted,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
viewportChanged,
|
||||
edgeAdded,
|
||||
nodeTemplatesBuilt,
|
||||
shouldShowEdgeLabelsChanged,
|
||||
undo,
|
||||
redo,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export const $cursorPos = atom<XYPosition | null>(null);
|
||||
export const $templates = atom<Templates>({});
|
||||
export const $copiedNodes = atom<AnyNode[]>([]);
|
||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
|
||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||
export const $isUpdatingEdge = atom(false);
|
||||
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
|
||||
export const $isAddNodePopoverOpen = atom(false);
|
||||
export const closeAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
$pendingConnection.set(null);
|
||||
};
|
||||
export const openAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
};
|
||||
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
name: nodesSlice.name,
|
||||
initialState: initialNodesState,
|
||||
migrate: migrateNodesState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
|
||||
|
||||
const isSelectionAction = (action: UnknownAction) => {
|
||||
if (selectionMatcher(action)) {
|
||||
return true;
|
||||
}
|
||||
if (nodesChanged.match(action)) {
|
||||
if (action.payload.every((change) => change.type === 'select')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const individualGroupByMatcher = isAnyOf(nodesChanged);
|
||||
|
||||
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
|
||||
limit: 64,
|
||||
undoType: nodesSlice.actions.undo.type,
|
||||
redoType: nodesSlice.actions.redo.type,
|
||||
groupBy: (action, state, history) => {
|
||||
if (isSelectionAction(action)) {
|
||||
// Changes to selection should never be recorded on their own
|
||||
return history.group;
|
||||
}
|
||||
if (individualGroupByMatcher(action)) {
|
||||
return action.type;
|
||||
}
|
||||
return null;
|
||||
},
|
||||
filter: (action, _state, _history) => {
|
||||
// Ignore all actions from other slices
|
||||
if (!action.type.startsWith(nodesSlice.name)) {
|
||||
return false;
|
||||
}
|
||||
if (nodesChanged.match(action)) {
|
||||
if (action.payload.every((change) => change.type === 'dimensions')) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
};
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
@@ -873,30 +606,3 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
selectionPasted,
|
||||
edgeAdded
|
||||
);
|
||||
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const nodesPersistConfig: PersistConfig<NodesState> = {
|
||||
name: nodesSlice.name,
|
||||
initialState: initialNodesState,
|
||||
migrate: migrateNodesState,
|
||||
persistDenylist: [
|
||||
'connectionStartParams',
|
||||
'connectionStartFieldType',
|
||||
'selectedNodes',
|
||||
'selectedEdges',
|
||||
'nodesToCopy',
|
||||
'edgesToCopy',
|
||||
'connectionMade',
|
||||
'modifyingEdge',
|
||||
'addNewNodePosition',
|
||||
],
|
||||
};
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
|
||||
const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
|
||||
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
|
||||
export const selectInvocationNodeType = (nodesSlice: NodesState, nodeId: string): string => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
return node.data.type;
|
||||
};
|
||||
|
||||
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
|
||||
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
return nodesSlice.templates[node.data.type] ?? null;
|
||||
return node.data;
|
||||
};
|
||||
|
||||
export const selectFieldInputInstance = (
|
||||
@@ -32,20 +29,10 @@ export const selectFieldInputInstance = (
|
||||
return data?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldInputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldInputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldOutputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldOutputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.outputs[fieldName] ?? null;
|
||||
export const selectLastSelectedNode = (nodesSlice: NodesState) => {
|
||||
const selectedNodes = nodesSlice.nodes.filter((n) => n.selected);
|
||||
if (selectedNodes.length === 1) {
|
||||
return selectedNodes[0];
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -1,38 +1,31 @@
|
||||
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type {
|
||||
FieldIdentifier,
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
StatefulFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
InvocationNode,
|
||||
InvocationNodeEdge,
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||
|
||||
export type PendingConnection = {
|
||||
node: InvocationNode;
|
||||
template: InvocationTemplate;
|
||||
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
|
||||
};
|
||||
|
||||
export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: InvocationNodeEdge[];
|
||||
templates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
modifyingEdge: boolean;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
shouldAnimateEdges: boolean;
|
||||
nodeOpacity: number;
|
||||
shouldSnapToGrid: boolean;
|
||||
shouldColorEdges: boolean;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
viewport: Viewport;
|
||||
nodesToCopy: AnyNode[];
|
||||
edgesToCopy: InvocationNodeEdge[];
|
||||
isAddNodePopoverOpen: boolean;
|
||||
addNewNodePosition: XYPosition | null;
|
||||
selectionMode: SelectionMode;
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
|
||||
@@ -1,112 +1,105 @@
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { differenceWith, isEqual, map } from 'lodash-es';
|
||||
import type { Connection } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
const isValidConnection = (
|
||||
edges: Edge[],
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: FieldInputTemplate | FieldOutputTemplate
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === node.id && edge.targetHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
} else {
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.source === node.id && edge.sourceHandle === handle.name;
|
||||
})
|
||||
) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
|
||||
isValidConnection = false;
|
||||
}
|
||||
|
||||
return isValidConnection;
|
||||
};
|
||||
|
||||
export const findConnectionToValidHandle = (
|
||||
node: AnyNode,
|
||||
export const getFirstValidConnection = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
templates: Record<string, InvocationTemplate>,
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
pendingConnection: PendingConnection,
|
||||
candidateNode: InvocationNode,
|
||||
candidateTemplate: InvocationTemplate
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
|
||||
if (pendingConnection.node.id === candidateNode.id) {
|
||||
// Cannot connect to self
|
||||
return null;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
//Prioritize handles whos name matches the node we're coming from
|
||||
const handle = handles[handleCurrentName];
|
||||
|
||||
if (handle) {
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
if (pendingFieldKind === 'source') {
|
||||
// Connecting from a source to a target
|
||||
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
if (candidateNode.data.type === 'collect') {
|
||||
// Special handling for collect node - the `item` field takes any number of connections
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: 'item',
|
||||
};
|
||||
}
|
||||
// Only one connection per target field is allowed - look for an unconnected target field
|
||||
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
|
||||
const candidateConnectedFields = edges
|
||||
.filter((edge) => edge.target === candidateNode.id)
|
||||
.map((edge) => {
|
||||
// Edges must always have a targetHandle, safe to assert here
|
||||
assert(edge.targetHandle);
|
||||
return edge.targetHandle;
|
||||
});
|
||||
const candidateUnconnectedFields = differenceWith(
|
||||
candidateFields,
|
||||
candidateConnectedFields,
|
||||
(field, connectedFieldName) => field.name === connectedFieldName
|
||||
);
|
||||
const candidateField = candidateUnconnectedFields.find((field) =>
|
||||
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
|
||||
);
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: pendingConnection.node.id,
|
||||
sourceHandle: pendingConnection.fieldTemplate.name,
|
||||
target: candidateNode.id,
|
||||
targetHandle: candidateField.name,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// Connecting from a target to a source
|
||||
// Ensure we there is not already an edge to the target, except for collect nodes
|
||||
const isCollect = pendingConnection.node.data.type === 'collect';
|
||||
const isTargetAlreadyConnected = edges.some(
|
||||
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
|
||||
);
|
||||
if (!isCollect && isTargetAlreadyConnected) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Sources/outputs can have any number of edges, we can take the first matching output field
|
||||
let candidateFields = map(candidateTemplate.outputs);
|
||||
if (isCollect) {
|
||||
// Narrow candidates to same field type as already is connected to the collect node
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
|
||||
if (collectItemType) {
|
||||
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
|
||||
}
|
||||
}
|
||||
const candidateField = candidateFields.find((field) => {
|
||||
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
|
||||
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
|
||||
return isValid && !isAlreadyConnected;
|
||||
});
|
||||
if (candidateField) {
|
||||
return {
|
||||
source: candidateNode.id,
|
||||
sourceHandle: candidateField.name,
|
||||
target: pendingConnection.node.id,
|
||||
targetHandle: pendingConnection.fieldTemplate.name,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
for (const handleName in handles) {
|
||||
const handle = handles[handleName];
|
||||
if (!handle) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
|
||||
|
||||
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
|
||||
|
||||
if (isGraphAcyclic && valid) {
|
||||
return {
|
||||
source: sourceID,
|
||||
sourceHandle: sourceHandle,
|
||||
target: targetID,
|
||||
targetHandle: targetHandle,
|
||||
};
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -4,8 +4,8 @@ export const findUnoccupiedPosition = (nodes: Node[], x: number, y: number) => {
|
||||
let newX = x;
|
||||
let newY = y;
|
||||
while (nodes.find((n) => n.position.x === newX && n.position.y === newY)) {
|
||||
newX = newX + 50;
|
||||
newY = newY + 50;
|
||||
newX = Math.floor(newX + 50);
|
||||
newY = Math.floor(newY + 50);
|
||||
}
|
||||
return { x: newX, y: newY };
|
||||
};
|
||||
|
||||
@@ -1,39 +1,66 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import i18n from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { HandleType } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
|
||||
|
||||
export const getCollectItemType = (
|
||||
templates: Templates,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
nodeId: string
|
||||
): FieldType | null => {
|
||||
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
|
||||
if (!firstEdgeToCollect?.sourceHandle) {
|
||||
return null;
|
||||
}
|
||||
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
templates: Templates,
|
||||
pendingConnection: PendingConnection | null,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType | null
|
||||
) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!fieldType) {
|
||||
return i18n.t('nodes.noFieldType');
|
||||
}
|
||||
|
||||
const { connectionStartFieldType, connectionStartParams, nodes, edges } = nodesSlice;
|
||||
|
||||
if (!connectionStartParams || !connectionStartFieldType) {
|
||||
if (!pendingConnection) {
|
||||
return i18n.t('nodes.noConnectionInProgress');
|
||||
}
|
||||
|
||||
const {
|
||||
handleType: connectionHandleType,
|
||||
nodeId: connectionNodeId,
|
||||
handleId: connectionFieldName,
|
||||
} = connectionStartParams;
|
||||
const connectionNodeId = pendingConnection.node.id;
|
||||
const connectionFieldName = pendingConnection.fieldTemplate.name;
|
||||
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
|
||||
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
|
||||
|
||||
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
|
||||
return i18n.t('nodes.noConnectionData');
|
||||
@@ -54,26 +81,45 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const target = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetHandle = handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const source = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName;
|
||||
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
edge.target === targetNodeId &&
|
||||
edge.targetHandle === targetFieldName &&
|
||||
edge.source === sourceNodeId &&
|
||||
edge.sourceHandle === sourceFieldName;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === targetNodeId);
|
||||
assert(targetNode, `Target node not found: ${targetNodeId}`);
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
|
||||
|
||||
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
|
||||
return i18n.t('nodes.cannotConnectToDirectInput');
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType) {
|
||||
if (!isEqual(sourceType, collectItemType)) {
|
||||
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType.name !== 'CollectionItemField'
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { SelectionMode } from 'reactflow';
|
||||
|
||||
type WorkflowSettingsState = {
|
||||
_version: 1;
|
||||
shouldShowMinimapPanel: boolean;
|
||||
shouldValidateGraph: boolean;
|
||||
shouldAnimateEdges: boolean;
|
||||
nodeOpacity: number;
|
||||
shouldSnapToGrid: boolean;
|
||||
shouldColorEdges: boolean;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
selectionMode: SelectionMode;
|
||||
};
|
||||
|
||||
const initialState: WorkflowSettingsState = {
|
||||
_version: 1,
|
||||
shouldShowMinimapPanel: true,
|
||||
shouldValidateGraph: true,
|
||||
shouldAnimateEdges: true,
|
||||
shouldSnapToGrid: false,
|
||||
shouldColorEdges: true,
|
||||
shouldShowEdgeLabels: false,
|
||||
nodeOpacity: 1,
|
||||
selectionMode: SelectionMode.Partial,
|
||||
};
|
||||
|
||||
export const workflowSettingsSlice = createSlice({
|
||||
name: 'workflowSettings',
|
||||
initialState,
|
||||
reducers: {
|
||||
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowMinimapPanel = action.payload;
|
||||
},
|
||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldValidateGraph = action.payload;
|
||||
},
|
||||
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAnimateEdges = action.payload;
|
||||
},
|
||||
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowEdgeLabels = action.payload;
|
||||
},
|
||||
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldSnapToGrid = action.payload;
|
||||
},
|
||||
shouldColorEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldColorEdges = action.payload;
|
||||
},
|
||||
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
|
||||
state.nodeOpacity = action.payload;
|
||||
},
|
||||
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldShowEdgeLabelsChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
nodeOpacityChanged,
|
||||
selectionModeChanged,
|
||||
} = workflowSettingsSlice.actions;
|
||||
|
||||
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateWorkflowSettingsState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = {
|
||||
name: workflowSettingsSlice.name,
|
||||
initialState,
|
||||
migrate: migrateWorkflowSettingsState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
@@ -11,7 +11,7 @@ import type {
|
||||
SchedulerField,
|
||||
T2IAdapterField,
|
||||
} from 'features/nodes/types/common';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
import type { Equals, Extends } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { describe, test } from 'vitest';
|
||||
@@ -26,7 +26,7 @@ describe('Common types', () => {
|
||||
test('ImageField', () => assert<Equals<ImageField, S['ImageField']>>());
|
||||
test('BoardField', () => assert<Equals<BoardField, S['BoardField']>>());
|
||||
test('ColorField', () => assert<Equals<ColorField, S['ColorField']>>());
|
||||
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<S['SchedulerInvocation']['scheduler']>>>());
|
||||
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<Invocation<'scheduler'>['scheduler']>>>());
|
||||
test('ControlField', () => assert<Equals<ControlField, S['ControlField']>>());
|
||||
// @ts-expect-error TODO(psyche): fix types
|
||||
test('IPAdapterField', () => assert<Extends<IPAdapterField, S['IPAdapterField']>>());
|
||||
|
||||
@@ -1,707 +0,0 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
isControlAdapterLayer,
|
||||
isInitialImageLayer,
|
||||
isIPAdapterLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
rgLayerMaskImageUploaded,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||
import type {
|
||||
ControlNetConfigV2,
|
||||
ImageWithDims,
|
||||
IPAdapterConfigV2,
|
||||
ProcessorConfig,
|
||||
T2IAdapterConfigV2,
|
||||
} from 'features/controlLayers/util/controlAdapters';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
CONTROL_NET_COLLECT,
|
||||
IMAGE_TO_LATENTS,
|
||||
IP_ADAPTER_COLLECT,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
RESIZE,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
||||
import { size } from 'lodash-es';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type {
|
||||
BaseModelType,
|
||||
CollectInvocation,
|
||||
ControlNetInvocation,
|
||||
Edge,
|
||||
ImageDTO,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IPAdapterInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const addControlLayersToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string
|
||||
): Promise<Layer[]> => {
|
||||
const mainModel = state.generation.model;
|
||||
assert(mainModel, 'Missing main model when building graph');
|
||||
const isSDXL = mainModel.base === 'sdxl';
|
||||
|
||||
// Filter out layers with incompatible base model, missing control image
|
||||
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
|
||||
|
||||
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
|
||||
for (const ca of validControlAdapters) {
|
||||
addGlobalControlAdapterToGraph(ca, graph, denoiseNodeId);
|
||||
}
|
||||
|
||||
const validIPAdapters = validLayers.filter(isIPAdapterLayer).map((l) => l.ipAdapter);
|
||||
for (const ipAdapter of validIPAdapters) {
|
||||
addGlobalIPAdapterToGraph(ipAdapter, graph, denoiseNodeId);
|
||||
}
|
||||
|
||||
const initialImageLayers = validLayers.filter(isInitialImageLayer);
|
||||
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
|
||||
if (initialImageLayers[0]) {
|
||||
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayers[0]);
|
||||
}
|
||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||
// the existing conditioning nodes.
|
||||
|
||||
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
|
||||
const posCondCollectNode: CollectInvocation = {
|
||||
id: POSITIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
||||
const negCondCollectNode: CollectInvocation = {
|
||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
||||
|
||||
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
||||
const newEdges: Edge[] = [];
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
newEdges.push(edge);
|
||||
}
|
||||
}
|
||||
graph.edges = newEdges;
|
||||
|
||||
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
|
||||
const layerIds = validRGLayers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
for (const layer of validRGLayers) {
|
||||
const blob = blobs[layer.id];
|
||||
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||
// Upload the mask image, or get the cached image if it exists
|
||||
const { image_name } = await getMaskImage(layer, blob);
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
image: {
|
||||
image_name,
|
||||
},
|
||||
};
|
||||
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
||||
|
||||
if (layer.positivePrompt) {
|
||||
// The main positive conditioning node
|
||||
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
// Connect the conditioning to the collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.negativePrompt) {
|
||||
// The main negative conditioning node
|
||||
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
style: layer.negativePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
};
|
||||
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
// Connect the conditioning to the collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
|
||||
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
|
||||
// We re-use the mask image, but invert it when converting to tensor
|
||||
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
|
||||
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
|
||||
type: 'invert_tensor_mask',
|
||||
};
|
||||
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
|
||||
|
||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: maskToTensorNode.id,
|
||||
field: 'mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: invertTensorMaskNode.id,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
|
||||
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
|
||||
// positive prompt
|
||||
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
||||
// Connect the inverted mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
||||
});
|
||||
// Connect the conditioning to the negative collector
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
|
||||
isValidIPAdapter(ipa, mainModel.base)
|
||||
);
|
||||
|
||||
for (const ipAdapter of validRegionalIPAdapters) {
|
||||
addIPAdapterCollectorSafe(graph, denoiseNodeId);
|
||||
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
|
||||
const ipAdapterNode: IPAdapterInvocation = {
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.name,
|
||||
},
|
||||
};
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
graph.edges.push({
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: ipAdapterNode.id, field: 'mask' },
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
destination: {
|
||||
node_id: IP_ADAPTER_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
|
||||
return validLayers;
|
||||
};
|
||||
|
||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||
if (layer.uploadedMaskImage) {
|
||||
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
const buildControlImage = (
|
||||
image: ImageWithDims | null,
|
||||
processedImage: ImageWithDims | null,
|
||||
processorConfig: ProcessorConfig | null
|
||||
): ImageField => {
|
||||
if (processedImage && processorConfig) {
|
||||
// We've processed the image in the app - use it for the control image.
|
||||
return {
|
||||
image_name: processedImage.name,
|
||||
};
|
||||
} else if (image) {
|
||||
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
|
||||
return {
|
||||
image_name: image.name,
|
||||
};
|
||||
}
|
||||
assert(false, 'Attempted to add unprocessed control image');
|
||||
};
|
||||
|
||||
const addGlobalControlAdapterToGraph = (
|
||||
controlAdapter: ControlNetConfigV2 | T2IAdapterConfigV2,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string
|
||||
) => {
|
||||
if (controlAdapter.type === 'controlnet') {
|
||||
addGlobalControlNetToGraph(controlAdapter, graph, denoiseNodeId);
|
||||
}
|
||||
if (controlAdapter.type === 't2i_adapter') {
|
||||
addGlobalT2IAdapterToGraph(controlAdapter, graph, denoiseNodeId);
|
||||
}
|
||||
};
|
||||
|
||||
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
if (graph.nodes[CONTROL_NET_COLLECT]) {
|
||||
// You see, we've already got one!
|
||||
return;
|
||||
}
|
||||
// Add the ControlNet collector
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'control',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addGlobalControlNetToGraph = (controlNet: ControlNetConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
|
||||
assert(model, 'ControlNet model is required');
|
||||
const controlImage = buildControlImage(image, processedImage, processorConfig);
|
||||
addControlNetCollectorSafe(graph, denoiseNodeId);
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
id: `control_net_${id}`,
|
||||
type: 'controlnet',
|
||||
is_intermediate: true,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
control_mode: controlMode,
|
||||
resize_mode: 'just_resize',
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
image: controlImage,
|
||||
};
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
node_id: CONTROL_NET_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
if (graph.nodes[T2I_ADAPTER_COLLECT]) {
|
||||
// You see, we've already got one!
|
||||
return;
|
||||
}
|
||||
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const t2iAdapterCollectNode: CollectInvocation = {
|
||||
id: T2I_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 't2i_adapter',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addGlobalT2IAdapterToGraph = (t2iAdapter: T2IAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
|
||||
assert(model, 'T2I Adapter model is required');
|
||||
const controlImage = buildControlImage(image, processedImage, processorConfig);
|
||||
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
|
||||
|
||||
const t2iAdapterNode: T2IAdapterInvocation = {
|
||||
id: `t2i_adapter_${id}`,
|
||||
type: 't2i_adapter',
|
||||
is_intermediate: true,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
t2i_adapter_model: model,
|
||||
weight: weight,
|
||||
image: controlImage,
|
||||
};
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
destination: {
|
||||
node_id: T2I_ADAPTER_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
if (graph.nodes[IP_ADAPTER_COLLECT]) {
|
||||
// You see, we've already got one!
|
||||
return;
|
||||
}
|
||||
|
||||
const ipAdapterCollectNode: CollectInvocation = {
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'ip_adapter',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addGlobalIPAdapterToGraph = (ipAdapter: IPAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
addIPAdapterCollectorSafe(graph, denoiseNodeId);
|
||||
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(image, 'IP Adapter image is required');
|
||||
assert(model, 'IP Adapter model is required');
|
||||
|
||||
const ipAdapterNode: IPAdapterInvocation = {
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.name,
|
||||
},
|
||||
};
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
destination: {
|
||||
node_id: IP_ADAPTER_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const addInitialImageLayerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string,
|
||||
layer: InitialImageLayer
|
||||
) => {
|
||||
const { vaePrecision, model } = state.generation;
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
||||
assert(layer.image, 'Initial image layer has no image');
|
||||
|
||||
const isSDXL = model?.base === 'sdxl';
|
||||
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
||||
|
||||
const denoiseNode = graph.nodes[denoiseNodeId];
|
||||
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
|
||||
|
||||
const { denoisingStrength } = layer;
|
||||
denoiseNode.denoising_start = useRefinerStartEnd
|
||||
? Math.min(refinerStart, 1 - denoisingStrength)
|
||||
: 1 - denoisingStrength;
|
||||
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
|
||||
|
||||
const i2lNode: ImageToLatentsInvocation = {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
is_intermediate: true,
|
||||
use_cache: true,
|
||||
fp32: vaePrecision === 'fp32',
|
||||
};
|
||||
|
||||
graph.nodes[i2lNode.id] = i2lNode;
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNode.id,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
if (layer.image.width !== width || layer.image.height !== height) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: layer.image.name,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
i2lNode.image = {
|
||||
image_name: layer.image.name,
|
||||
};
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
|
||||
};
|
||||
|
||||
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
|
||||
// Must be have a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ca.model);
|
||||
const modelMatchesBase = ca.model?.base === base;
|
||||
const hasControlImage = Boolean(ca.image || (ca.processedImage && ca.processorConfig));
|
||||
return hasModel && modelMatchesBase && hasControlImage;
|
||||
};
|
||||
|
||||
const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean => {
|
||||
// Must be have a model that matches the current base and must have a control image
|
||||
const hasModel = Boolean(ipa.model);
|
||||
const modelMatchesBase = ipa.model?.base === base;
|
||||
const hasImage = Boolean(ipa.image);
|
||||
return hasModel && modelMatchesBase && hasImage;
|
||||
};
|
||||
|
||||
const isValidLayer = (layer: Layer, base: BaseModelType) => {
|
||||
if (!layer.isEnabled) {
|
||||
return false;
|
||||
}
|
||||
if (isControlAdapterLayer(layer)) {
|
||||
return isValidControlAdapter(layer.controlAdapter, base);
|
||||
}
|
||||
if (isIPAdapterLayer(layer)) {
|
||||
return isValidIPAdapter(layer.ipAdapter, base);
|
||||
}
|
||||
if (isInitialImageLayer(layer)) {
|
||||
if (!layer.image) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (isRegionalGuidanceLayer(layer)) {
|
||||
if (layer.maskObjects.length === 0) {
|
||||
// Layer has no mask, meaning any guidance would be applied to an empty region.
|
||||
return false;
|
||||
}
|
||||
const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt);
|
||||
const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -1,356 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import type {
|
||||
DenoiseLatentsInvocation,
|
||||
Edge,
|
||||
ESRGANInvocation,
|
||||
LatentsToImageInvocation,
|
||||
NoiseInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import {
|
||||
DENOISE_LATENTS,
|
||||
DENOISE_LATENTS_HRF,
|
||||
ESRGAN_HRF,
|
||||
IMAGE_TO_LATENTS_HRF,
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_IMAGE_HRF_HR,
|
||||
LATENTS_TO_IMAGE_HRF_LR,
|
||||
MAIN_MODEL_LOADER,
|
||||
NOISE,
|
||||
NOISE_HRF,
|
||||
RESIZE_HRF,
|
||||
SEAMLESS,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { setMetadataReceivingNode, upsertMetadata } from './metadata';
|
||||
|
||||
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
|
||||
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
|
||||
const destinationFields = [
|
||||
'control',
|
||||
'ip_adapter',
|
||||
'metadata',
|
||||
'unet',
|
||||
'positive_conditioning',
|
||||
'negative_conditioning',
|
||||
];
|
||||
const newEdges: Edge[] = [];
|
||||
|
||||
// Loop through the existing edges connected to DENOISE_LATENTS
|
||||
graph.edges.forEach((edge: Edge) => {
|
||||
if (edge.destination.node_id === DENOISE_LATENTS && destinationFields.includes(edge.destination.field)) {
|
||||
// Add a similar connection to DENOISE_LATENTS_HRF
|
||||
newEdges.push({
|
||||
source: {
|
||||
node_id: edge.source.node_id,
|
||||
field: edge.source.field,
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS_HRF,
|
||||
field: edge.destination.field,
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
graph.edges = graph.edges.concat(newEdges);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the new resolution for high-resolution features (HRF) based on base model type.
|
||||
* Adjusts the width and height to maintain the aspect ratio and constrains them by the model's dimension limits,
|
||||
* rounding down to the nearest multiple of 8.
|
||||
*
|
||||
* @param {number} optimalDimension The optimal dimension for the base model.
|
||||
* @param {number} width The current width to be adjusted for HRF.
|
||||
* @param {number} height The current height to be adjusted for HRF.
|
||||
* @return {{newWidth: number, newHeight: number}} The new width and height, adjusted and rounded as needed.
|
||||
*/
|
||||
function calculateHrfRes(
|
||||
optimalDimension: number,
|
||||
width: number,
|
||||
height: number
|
||||
): { newWidth: number; newHeight: number } {
|
||||
const aspect = width / height;
|
||||
|
||||
const minDimension = Math.floor(optimalDimension * 0.5);
|
||||
const modelArea = optimalDimension * optimalDimension; // Assuming square images for model_area
|
||||
|
||||
let initWidth;
|
||||
let initHeight;
|
||||
|
||||
if (aspect > 1.0) {
|
||||
initHeight = Math.max(minDimension, Math.sqrt(modelArea / aspect));
|
||||
initWidth = initHeight * aspect;
|
||||
} else {
|
||||
initWidth = Math.max(minDimension, Math.sqrt(modelArea * aspect));
|
||||
initHeight = initWidth / aspect;
|
||||
}
|
||||
// Cap initial height and width to final height and width.
|
||||
initWidth = Math.min(width, initWidth);
|
||||
initHeight = Math.min(height, initHeight);
|
||||
|
||||
const newWidth = roundToMultiple(Math.floor(initWidth), 8);
|
||||
const newHeight = roundToMultiple(Math.floor(initHeight), 8);
|
||||
|
||||
return { newWidth, newHeight };
|
||||
}
|
||||
|
||||
// Adds the high-res fix feature to the given graph.
|
||||
export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void => {
|
||||
// Double check hrf is enabled.
|
||||
if (!state.hrf.hrfEnabled || state.config.disabledSDFeatures.includes('hrf')) {
|
||||
return;
|
||||
}
|
||||
const log = logger('generation');
|
||||
|
||||
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
|
||||
const { hrfStrength, hrfEnabled, hrfMethod } = state.hrf;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
const isAutoVae = !vae;
|
||||
const isSeamlessEnabled = seamlessXAxis || seamlessYAxis;
|
||||
const optimalDimension = selectOptimalDimension(state);
|
||||
const { newWidth: hrfWidth, newHeight: hrfHeight } = calculateHrfRes(optimalDimension, width, height);
|
||||
|
||||
// Pre-existing (original) graph nodes.
|
||||
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as DenoiseLatentsInvocation | undefined;
|
||||
const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined;
|
||||
const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as LatentsToImageInvocation | undefined;
|
||||
if (!originalDenoiseLatentsNode) {
|
||||
log.error('originalDenoiseLatentsNode is undefined');
|
||||
return;
|
||||
}
|
||||
if (!originalNoiseNode) {
|
||||
log.error('originalNoiseNode is undefined');
|
||||
return;
|
||||
}
|
||||
if (!originalLatentsToImageNode) {
|
||||
log.error('originalLatentsToImageNode is undefined');
|
||||
return;
|
||||
}
|
||||
|
||||
// Change height and width of original noise node to initial resolution.
|
||||
if (originalNoiseNode) {
|
||||
originalNoiseNode.width = hrfWidth;
|
||||
originalNoiseNode.height = hrfHeight;
|
||||
}
|
||||
|
||||
// Define new nodes and their connections, roughly in order of operations.
|
||||
graph.nodes[LATENTS_TO_IMAGE_HRF_LR] = {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE_HRF_LR,
|
||||
fp32: originalLatentsToImageNode?.fp32,
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_LR,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_LR,
|
||||
field: 'vae',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
graph.nodes[RESIZE_HRF] = {
|
||||
id: RESIZE_HRF,
|
||||
type: 'img_resize',
|
||||
is_intermediate: true,
|
||||
width: width,
|
||||
height: height,
|
||||
};
|
||||
if (hrfMethod === 'ESRGAN') {
|
||||
let model_name: ESRGANInvocation['model_name'] = 'RealESRGAN_x2plus.pth';
|
||||
if ((width * height) / (hrfWidth * hrfHeight) > 2) {
|
||||
model_name = 'RealESRGAN_x4plus.pth';
|
||||
}
|
||||
graph.nodes[ESRGAN_HRF] = {
|
||||
id: ESRGAN_HRF,
|
||||
type: 'esrgan',
|
||||
model_name,
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_LR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: ESRGAN_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ESRGAN_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: RESIZE_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
}
|
||||
);
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_LR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: RESIZE_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.nodes[NOISE_HRF] = {
|
||||
type: 'noise',
|
||||
id: NOISE_HRF,
|
||||
seed: originalNoiseNode?.seed,
|
||||
use_cpu: originalNoiseNode?.use_cpu,
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: RESIZE_HRF,
|
||||
field: 'height',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE_HRF,
|
||||
field: 'height',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RESIZE_HRF,
|
||||
field: 'width',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE_HRF,
|
||||
field: 'width',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
graph.nodes[IMAGE_TO_LATENTS_HRF] = {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS_HRF,
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS_HRF,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RESIZE_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS_HRF,
|
||||
field: 'image',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
graph.nodes[DENOISE_LATENTS_HRF] = {
|
||||
type: 'denoise_latents',
|
||||
id: DENOISE_LATENTS_HRF,
|
||||
is_intermediate: true,
|
||||
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
|
||||
scheduler: originalDenoiseLatentsNode?.scheduler,
|
||||
steps: originalDenoiseLatentsNode?.steps,
|
||||
denoising_start: 1 - hrfStrength,
|
||||
denoising_end: 1,
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS_HRF,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS_HRF,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE_HRF,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS_HRF,
|
||||
field: 'noise',
|
||||
},
|
||||
}
|
||||
);
|
||||
copyConnectionsToDenoiseLatentsHrf(graph);
|
||||
|
||||
// The original l2i node is unnecessary now, remove it
|
||||
graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE);
|
||||
delete graph.nodes[LATENTS_TO_IMAGE];
|
||||
|
||||
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE_HRF_HR,
|
||||
fp32: originalLatentsToImageNode?.fp32,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_HR,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: DENOISE_LATENTS_HRF,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_HR,
|
||||
field: 'latents',
|
||||
},
|
||||
}
|
||||
);
|
||||
upsertMetadata(graph, {
|
||||
hrf_strength: hrfStrength,
|
||||
hrf_enabled: hrfEnabled,
|
||||
hrf_method: hrfMethod,
|
||||
});
|
||||
setMetadataReceivingNode(graph, LATENTS_TO_IMAGE_HRF_HR);
|
||||
};
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';
|
||||
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
|
||||
import { ESRGAN } from './constants';
|
||||
import { addCoreMetadataNode, upsertMetadata } from './metadata';
|
||||
|
||||
type Arg = {
|
||||
image_name: string;
|
||||
@@ -13,7 +13,7 @@ type Arg = {
|
||||
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
|
||||
const { esrganModelName } = state.postprocessing;
|
||||
|
||||
const realesrganNode: ESRGANInvocation = {
|
||||
const realesrganNode: Invocation<'esrgan'> = {
|
||||
id: ESRGAN,
|
||||
type: 'esrgan',
|
||||
image: { image_name },
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
CONTROL_LAYERS_GRAPH,
|
||||
DENOISE_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildGenerationTabGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
seed,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
const is_intermediate = true;
|
||||
|
||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: CONTROL_LAYERS_GRAPH,
|
||||
nodes: {
|
||||
[modelLoaderNodeId]: {
|
||||
type: 'main_model_loader',
|
||||
id: modelLoaderNodeId,
|
||||
is_intermediate,
|
||||
model,
|
||||
},
|
||||
[CLIP_SKIP]: {
|
||||
type: 'clip_skip',
|
||||
id: CLIP_SKIP,
|
||||
skipped_layers: clipSkip,
|
||||
is_intermediate,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
is_intermediate,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
is_intermediate,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
use_cpu,
|
||||
is_intermediate,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: DENOISE_LATENTS,
|
||||
is_intermediate,
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: 0,
|
||||
denoising_end: 1,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
// Connect Model Loader to UNet and CLIP Skip
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: CLIP_SKIP,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
// Connect CLIP Skip to Conditioning
|
||||
{
|
||||
source: {
|
||||
node_id: CLIP_SKIP,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: CLIP_SKIP,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
// Connect everything to Denoise Latents
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
// Decode Denoised Latents To Image
|
||||
{
|
||||
source: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
clip_skip: clipSkip,
|
||||
},
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
modelLoaderNodeId = SEAMLESS;
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled && shouldUseHRF) {
|
||||
addHrfToGraph(state, graph);
|
||||
}
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
addNSFWCheckerToGraph(state, graph);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseWatermarker) {
|
||||
// must add after nsfw checker!
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
return graph;
|
||||
};
|
||||
@@ -1,278 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SDXL_CONTROL_LAYERS_GRAPH,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
const is_intermediate = true;
|
||||
|
||||
// Construct Style Prompt
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
// Model Loader ID
|
||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_CONTROL_LAYERS_GRAPH,
|
||||
nodes: {
|
||||
[modelLoaderNodeId]: {
|
||||
type: 'sdxl_model_loader',
|
||||
id: modelLoaderNodeId,
|
||||
model,
|
||||
is_intermediate,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
is_intermediate,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
is_intermediate,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
use_cpu,
|
||||
is_intermediate,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: SDXL_DENOISE_LATENTS,
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: 0,
|
||||
denoising_end: refinerModel ? refinerStart : 1,
|
||||
is_intermediate,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
// Connect Model Loader to UNet, VAE & CLIP
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
// Connect everything to Denoise Latents
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
// Decode Denoised Latents To Image
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
},
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
modelLoaderNodeId = SEAMLESS;
|
||||
}
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
addNSFWCheckerToGraph(state, graph);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseWatermarker) {
|
||||
// must add after nsfw checker!
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
return graph;
|
||||
};
|
||||
@@ -5,8 +5,8 @@ import { range } from 'lodash-es';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { getHasMetadata, removeMetadata } from './canvas/metadata';
|
||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
|
||||
import { getHasMetadata, removeMetadata } from './metadata';
|
||||
|
||||
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
|
||||
const { iterations, model, shouldRandomizeSeed, seed } = state.generation;
|
||||
|
||||
@@ -2,25 +2,18 @@ import type { RootState } from 'app/store/store';
|
||||
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import { CONTROL_NET_COLLECT } from 'features/nodes/util/graph/constants';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
} from 'services/api/types';
|
||||
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { CONTROL_NET_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
||||
const controlNetMetadata: S['CoreMetadataInvocation']['controlnets'] = [];
|
||||
const controlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
|
||||
const hasModel = Boolean(model);
|
||||
@@ -37,7 +30,7 @@ export const addControlNetToLinearGraph = async (
|
||||
|
||||
if (controlNets.length) {
|
||||
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
const controlNetIterateNode: Invocation<'collect'> = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
@@ -68,7 +61,7 @@ export const addControlNetToLinearGraph = async (
|
||||
weight,
|
||||
} = controlNet;
|
||||
|
||||
const controlNetNode: ControlNetInvocation = {
|
||||
const controlNetNode: Invocation<'controlnet'> = {
|
||||
id: `control_net_${id}`,
|
||||
type: 'controlnet',
|
||||
is_intermediate: true,
|
||||
@@ -2,19 +2,12 @@ import type { RootState } from 'app/store/store';
|
||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
IPAdapterInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
} from 'services/api/types';
|
||||
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { IP_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
@@ -33,7 +26,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
|
||||
if (ipAdapters.length) {
|
||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const ipAdapterCollectNode: CollectInvocation = {
|
||||
const ipAdapterCollectNode: Invocation<'collect'> = {
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
@@ -47,7 +40,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
},
|
||||
});
|
||||
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
const ipAdapterMetdata: S['CoreMetadataInvocation']['ipAdapters'] = [];
|
||||
|
||||
for (const ipAdapter of ipAdapters) {
|
||||
if (!ipAdapter.model) {
|
||||
@@ -57,7 +50,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
|
||||
assert(controlImage, 'IP Adapter image is required');
|
||||
|
||||
const ipAdapterNode: IPAdapterInvocation = {
|
||||
const ipAdapterNode: Invocation<'ip_adapter'> = {
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
@@ -1,10 +1,15 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
LORA_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
|
||||
|
||||
export const addLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
@@ -38,7 +43,7 @@ export const addLoRAsToGraph = async (
|
||||
// we need to remember the last lora so we can chain from it
|
||||
let lastLoraNodeId = '';
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
|
||||
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
@@ -46,7 +51,7 @@ export const addLoRAsToGraph = async (
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraLoaderNode: LoRALoaderInvocation = {
|
||||
const loraLoaderNode: Invocation<'lora_loader'> = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
@@ -1,15 +1,14 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
|
||||
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from 'features/nodes/util/graph/constants';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
export const addNSFWCheckerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
nodeIdToAddTo = LATENTS_TO_IMAGE
|
||||
): void => {
|
||||
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
|
||||
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
|
||||
|
||||
if (!nodeToAddTo) {
|
||||
// something has gone terribly awry
|
||||
@@ -19,14 +18,14 @@ export const addNSFWCheckerToGraph = (
|
||||
nodeToAddTo.is_intermediate = true;
|
||||
nodeToAddTo.use_cache = true;
|
||||
|
||||
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
|
||||
const nsfwCheckerNode: Invocation<'img_nsfw'> = {
|
||||
id: NSFW_CHECKER,
|
||||
type: 'img_nsfw',
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
|
||||
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
|
||||
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: nodeIdToAddTo,
|
||||
@@ -1,8 +1,6 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
|
||||
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
LORA_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
@@ -10,8 +8,9 @@ import {
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
|
||||
|
||||
export const addSDXLLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
@@ -35,7 +34,7 @@ export const addSDXLLoRAsToGraph = async (
|
||||
return;
|
||||
}
|
||||
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
|
||||
|
||||
// Handle Seamless Plugs
|
||||
const unetLoaderId = modelLoaderNodeId;
|
||||
@@ -61,7 +60,7 @@ export const addSDXLLoRAsToGraph = async (
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraLoaderNode: SDXLLoRALoaderInvocation = {
|
||||
const loraLoaderNode: Invocation<'sdxl_lora_loader'> = {
|
||||
type: 'sdxl_lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
@@ -1,8 +1,6 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
|
||||
import { getModelMetadataField, upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_CREATE_MASK,
|
||||
@@ -17,9 +15,10 @@ import {
|
||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
|
||||
export const addSDXLRefinerToGraph = async (
|
||||
state: RootState,
|
||||
@@ -101,7 +100,7 @@ export const addSDXLRefinerToGraph = async (
|
||||
type: 'seamless',
|
||||
seamless_x: seamlessXAxis,
|
||||
seamless_y: seamlessYAxis,
|
||||
} as SeamlessModeInvocation;
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
{
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
|
||||
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
DENOISE_LATENTS,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
@@ -11,8 +10,8 @@ import {
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SEAMLESS,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
export const addSeamlessToLinearGraph = (
|
||||
state: RootState,
|
||||
@@ -28,7 +27,7 @@ export const addSeamlessToLinearGraph = (
|
||||
type: 'seamless',
|
||||
seamless_x: seamlessXAxis,
|
||||
seamless_y: seamlessYAxis,
|
||||
} as SeamlessModeInvocation;
|
||||
};
|
||||
|
||||
if (!isAutoVae) {
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
@@ -2,19 +2,12 @@ import type { RootState } from 'app/store/store';
|
||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import { T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { T2I_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
@@ -36,7 +29,7 @@ export const addT2IAdaptersToLinearGraph = async (
|
||||
|
||||
if (t2iAdapters.length) {
|
||||
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
|
||||
const t2iAdapterCollectNode: CollectInvocation = {
|
||||
const t2iAdapterCollectNode: Invocation<'collect'> = {
|
||||
id: T2I_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
@@ -50,7 +43,7 @@ export const addT2IAdaptersToLinearGraph = async (
|
||||
},
|
||||
});
|
||||
|
||||
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
const t2iAdapterMetadata: S['CoreMetadataInvocation']['t2iAdapters'] = [];
|
||||
|
||||
for (const t2iAdapter of t2iAdapters) {
|
||||
if (!t2iAdapter.model) {
|
||||
@@ -68,7 +61,7 @@ export const addT2IAdaptersToLinearGraph = async (
|
||||
weight,
|
||||
} = t2iAdapter;
|
||||
|
||||
const t2iAdapterNode: T2IAdapterInvocation = {
|
||||
const t2iAdapterNode: Invocation<'t2i_adapter'> = {
|
||||
id: `t2i_adapter_${id}`,
|
||||
type: 't2i_adapter',
|
||||
is_intermediate: true,
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
@@ -21,8 +20,8 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
export const addVAEToGraph = async (
|
||||
state: RootState,
|
||||
@@ -1,29 +1,23 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type {
|
||||
ImageNSFWBlurInvocation,
|
||||
ImageWatermarkInvocation,
|
||||
LatentsToImageInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
|
||||
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from 'features/nodes/util/graph/constants';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
export const addWatermarkerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
nodeIdToAddTo = LATENTS_TO_IMAGE
|
||||
): void => {
|
||||
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
|
||||
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
|
||||
|
||||
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
|
||||
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as Invocation<'img_nsfw'> | undefined;
|
||||
|
||||
if (!nodeToAddTo) {
|
||||
// something has gone terribly awry
|
||||
return;
|
||||
}
|
||||
|
||||
const watermarkerNode: ImageWatermarkInvocation = {
|
||||
const watermarkerNode: Invocation<'img_watermark'> = {
|
||||
id: WATERMARKER,
|
||||
type: 'img_watermark',
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user