Compare commits

...

78 Commits

Author SHA1 Message Date
brandonrising
59327e827b Create new data structures for captioned images, and a list of captioned images. Create auto_caption_image node which can take a single image or list of images to caption 2024-05-17 14:31:33 -04:00
psychedelicious
a18d7adad4 fix(ui): allow image dims multiple of 32 with SDXL and T2I adapter
See https://github.com/invoke-ai/InvokeAI/pull/6342#issuecomment-2109912452 for discussion.
2024-05-17 23:38:54 +10:00
psychedelicious
32dff2c4e3 feat(ui): copy/paste input edges when copying node
- Copy edges to selected nodes on copy
- If pasted with `ctrl/meta-shift-v`, also paste the input edges
2024-05-17 23:12:29 +10:00
psychedelicious
575ecb4028 feat(ui): prevent connections to direct-only inputs 2024-05-17 22:08:40 +10:00
psychedelicious
ad8778df6c feat(ui): extract node execution state from nodesSlice
This state is ephemeral and not undoable.
2024-05-17 13:24:23 +10:00
psychedelicious
d2f5103f9f fix(ui): ignore actions from other slices in nodesSlice history 2024-05-17 13:24:23 +10:00
psychedelicious
dd42a56084 tests(ui): fix parseSchema test fixture
The schema fixture wasn't formatted quite right - doesn't affect the test but still.
2024-05-17 13:24:23 +10:00
psychedelicious
23ac340a3f tests(ui): add test for parseSchema 2024-05-17 13:24:23 +10:00
psychedelicious
6791b4eaa8 chore(ui): lint 2024-05-17 13:24:23 +10:00
psychedelicious
a8b042177d feat(ui): connection validation for collection items types 2024-05-17 13:24:23 +10:00
psychedelicious
76825f4261 fix(ui): allow collect node inputs to connect to multiple fields when using lazy connect 2024-05-17 13:24:23 +10:00
psychedelicious
78cb4d75ad fix(ui): use elevateEdgesOnSelect so last-selected edge is the interactable one when updating edges 2024-05-17 13:24:23 +10:00
psychedelicious
a18bbac262 fix(ui): jank interaction between edge update and autoconnect 2024-05-17 13:24:23 +10:00
psychedelicious
9ff5596963 feat(ui): hide values for connected fields 2024-05-17 13:24:23 +10:00
psychedelicious
8ea596b1e9 fix(ui): janky editable field title
- Do not allow whitespace-only field titles
- Make only preview text trigger editable
- Tooltip over the preview, not the whole "row"
2024-05-17 13:24:23 +10:00
psychedelicious
e3a143eaed fix(ui): fix jank w/ stale connections 2024-05-17 13:24:23 +10:00
psychedelicious
c359ab6d9b fix(ui): fix dependency tracking for copy/paste hotkeys 2024-05-17 13:24:23 +10:00
psychedelicious
dbfaa07e03 feat(ui): add checks for undo/redo actions 2024-05-17 13:24:23 +10:00
psychedelicious
7f78fe7a36 feat(ui): move viewport state to nanostores 2024-05-17 13:24:23 +10:00
psychedelicious
6cf5b402c6 feat(ui): remove extraneous selectedEdges and selectedNodes state 2024-05-17 13:24:23 +10:00
psychedelicious
b0c7c7cb47 feat(ui): remove remaining extraneous state from nodes slice 2024-05-17 13:24:23 +10:00
psychedelicious
4d68cd8dbb feat(ui): recreate edge auto-add-node logic 2024-05-17 13:24:23 +10:00
psychedelicious
2c1fa30639 feat(ui): recreate edge autoconnect logic 2024-05-17 13:24:23 +10:00
psychedelicious
708c68413d tidy(ui): add type for templates 2024-05-17 13:24:23 +10:00
psychedelicious
1d884fb794 feat(ui): move invocation templates out of redux
Templates are stored in nanostores. All hooks, selectors, etc are reworked to reference the nanostore.
2024-05-17 13:24:23 +10:00
psychedelicious
f6a44681a8 feat(ui): move invocation templates out of redux (wip) 2024-05-17 13:24:23 +10:00
psychedelicious
d4df312300 feat(ui): move nodes copy/paste out of slice 2024-05-17 13:24:23 +10:00
psychedelicious
9c0d44b412 feat(ui): split workflow editor settings to separate slice
We need the undoable slice to be only undoable state - settings are not undoable.
2024-05-17 13:24:23 +10:00
psychedelicious
27826369f0 feat(ui): make nodesSlice undoable 2024-05-17 13:24:23 +10:00
H0onnn
31d8b50276 [Refactor] Update min and max values for LoRACard weight input 2024-05-17 10:38:26 +10:00
psychedelicious
40b4fa7238 feat(ui): SDXL clip skip
Uses the same CLIP Skip value for both CLIP1 and CLIP2.

Adjusted SDXL CLIP Skip min/max/markers to be within the valid range (0 to 11).

Closes #4583
2024-05-16 07:49:30 -04:00
psychedelicious
3b1743b7c2 docs: fix install reqs link 2024-05-16 10:37:42 +10:00
psychedelicious
f489c818f1 docs(ui): add comments to nsfw & watermarker helpers 2024-05-15 14:09:44 +10:00
psychedelicious
af477fa295 tidy(ui): remove unused modelLoader from refiner helper 2024-05-15 14:09:44 +10:00
psychedelicious
0ff0290735 tidy(ui): use Invocation<> helper type in canvas graph builders, elsewhere 2024-05-15 14:09:44 +10:00
psychedelicious
67dbe6d949 tidy(ui): use Invocation<> helper type in OG control adapters 2024-05-15 14:09:44 +10:00
psychedelicious
4c3c2297b9 tidy(ui): organise graph builder files 2024-05-15 14:09:44 +10:00
psychedelicious
cadea55521 tidy(ui): organise graph builder files 2024-05-15 14:09:44 +10:00
psychedelicious
c8f30b1392 tidy(ui): move testing-only types to test file 2024-05-15 14:09:44 +10:00
psychedelicious
3d14a98abf tidy(ui): use Invocation<> type in control layers types 2024-05-15 14:09:44 +10:00
psychedelicious
77024bfca7 fix(ui): fix sdxl generation mode metadata 2024-05-15 14:09:44 +10:00
psychedelicious
4a1c3786a1 tidy(ui): organise CL graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
b239891986 tidy(ui): clean up base model handling in graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
9fb03d43ff tests(ui): get coverage to 100% for graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
bdc59786bd tidy(ui): clean up graph builder helper functions 2024-05-15 14:09:44 +10:00
psychedelicious
fb6e926500 tidy(ui): remove extraneous graph validate calls 2024-05-15 14:09:44 +10:00
psychedelicious
48ccd63dba feat(ui): use integrated metadata helper 2024-05-15 14:09:44 +10:00
psychedelicious
ee647a05dc feat(ui): move metadata util to graph class
No good reason to have it be separate. A bit cleaner this way.
2024-05-15 14:09:44 +10:00
psychedelicious
154b52ca4d docs(ui): update docstrings for Graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
5dd460c3ce chore(ui): knip 2024-05-15 14:09:44 +10:00
psychedelicious
4897ce2a13 tidy(ui): remove unused files 2024-05-15 14:09:44 +10:00
psychedelicious
5425526d50 feat(ui): use graph builder for generation tab sdxl 2024-05-15 14:09:44 +10:00
psychedelicious
5a4b050e66 feat(ui): use asserts in graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
8d39520232 feat(ui): port NSFW and watermark nodes to graph builder 2024-05-15 14:09:44 +10:00
psychedelicious
04d12a1e98 feat(ui): add HRF graph builder helper 2024-05-15 14:09:44 +10:00
psychedelicious
39aa70963b docs(ui): update docstrings for addGenerationTabSeamless 2024-05-15 14:09:44 +10:00
psychedelicious
5743254a41 fix(ui): use arrays for edge methods 2024-05-15 14:09:44 +10:00
psychedelicious
c538ffea26 tidy(ui): remove console.log 2024-05-15 14:09:44 +10:00
psychedelicious
e8d3a7c870 feat(ui): support multiple fields for getEdgesTo, getEdgesFrom, deleteEdgesTo, deleteEdgesFrom 2024-05-15 14:09:44 +10:00
psychedelicious
2be66b1546 feat(ui): add deleteNode and getEdges to graph util 2024-05-15 14:09:44 +10:00
psychedelicious
76e181fd44 build(ui): add eslint no-console rule 2024-05-15 14:09:44 +10:00
psychedelicious
b5d42fbc66 tidy(ui): remove unused graph helper 2024-05-15 14:09:44 +10:00
psychedelicious
b463cd763e tidy(ui): remove extraneous is_intermediate node fields 2024-05-15 14:09:44 +10:00
psychedelicious
eb320df41d feat(ui): use new lora loaders, simplify VAE loader, seamless 2024-05-15 14:09:44 +10:00
psychedelicious
de1869773f chore(ui): typegen 2024-05-15 14:09:44 +10:00
psychedelicious
ef89c7e537 feat(nodes): add LoRASelectorInvocation, LoRACollectionLoader, SDXLLoRACollectionLoader
These simplify loading multiple LoRAs. Instead of requiring chained lora loader nodes, configure each LoRA (model & weight) with a selector, collect them, then send the collection to the collection loader to apply all of the LoRAs to the UNet/CLIP models.

The collection loaders accept a single lora or collection of loras.
2024-05-15 14:09:44 +10:00
psychedelicious
008645d386 fix(ui): work through merge conflicts (wip) 2024-05-15 14:09:44 +10:00
psychedelicious
f8042ffb41 WIP, sd1.5 works 2024-05-15 14:09:44 +10:00
psychedelicious
dbe22be598 feat(ui): use graph utils in builders (wip) 2024-05-15 14:09:44 +10:00
psychedelicious
8f6078d007 feat(ui): refine graph building util
Simpler types and API surface.
2024-05-15 14:09:44 +10:00
psychedelicious
4020bf47e2 feat(ui): add MetadataUtil class
Provides methods for manipulating a graph's metadata.
2024-05-15 14:09:44 +10:00
psychedelicious
9d685da759 feat(ui): add stateful Graph class
This stateful class provides abstractions for building a graph. It exposes graph methods like adding and removing nodes and edges.

The methods are documented, tested, and strongly typed.
2024-05-15 14:09:44 +10:00
psychedelicious
e3289856c0 feat(ui): add and use type helpers for invocations and invocation outputs 2024-05-15 14:09:44 +10:00
psychedelicious
47b8153728 build(ui): enable TS strictPropertyInitialization
https://www.typescriptlang.org/tsconfig/#strictPropertyInitialization
2024-05-15 14:09:44 +10:00
psychedelicious
7901e4c082 chore(ui): typegen 2024-05-15 14:09:44 +10:00
psychedelicious
18b0977a31 feat(api): add InvocationOutputMap to OpenAPI schema
This dynamically generated schema object maps node types to their pydantic schemas. This makes it much simpler to infer node types in the UI.
2024-05-15 14:09:44 +10:00
psychedelicious
fc6b214470 tests(ui): set up vitest coverage 2024-05-15 14:09:44 +10:00
blessedcoolant
e22211dac0 fix: Fix Outpaint not applying the expanded mask correctly
In unscaled situations
2024-05-15 13:59:01 +10:00
134 changed files with 5687 additions and 3465 deletions

View File

@@ -98,7 +98,7 @@ Updating is exactly the same as installing - download the latest installer, choo
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
[installation requirements]: INSTALLATION.md#installation-requirements
[installation requirements]: INSTALL_REQUIREMENTS.md
[FAQ]: ../help/FAQ.md
[install some models]: 050_INSTALLING_MODELS.md
[configuration docs]: ../features/CONFIGURATION.md

View File

@@ -164,6 +164,12 @@ def custom_openapi() -> dict[str, Any]:
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
@@ -172,6 +178,8 @@ def custom_openapi() -> dict[str, Any]:
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from typing import Literal, Optional, List, Union
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from transformers import AutoModelForCausalLM, AutoTokenizer
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
@@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
@@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation(
"auto_caption_image",
title="Automatically Caption Image",
tags=["image", "caption"],
category="image",
version="1.2.2",
)
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adds a caption to an image"""
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
model_id = "vikhyatk/moondream2"
model_revision = "2024-04-02"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
moondream_model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=model_revision
)
output: CaptionImageOutputs = CaptionImageOutputs()
try:
from PIL.Image import Image
images: List[Image] = []
image_fields = self.images if isinstance(self.images, list) else [self.images]
for image in image_fields:
images.append(context.images.get_pil(image.image_name))
answers: List[str] = moondream_model.batch_answer(
images=images,
prompts=[self.prompt] * len(images),
tokenizer=tokenizer,
)
assert isinstance(answers, list)
for i, answer in enumerate(answers):
output.images.append(CaptionImageOutput(
image=image_fields[i],
width=images[i].width,
height=images[i].height,
caption=answer
))
except:
raise
finally:
del moondream_model
del tokenizer
return output
@invocation(
"img_crop",
title="Crop Image",
@@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from")
image: List[ImageField] = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -190,6 +190,75 @@ class LoRALoaderInvocation(BaseInvocation):
return output
@invocation_output("lora_selector_output")
class LoRASelectorOutput(BaseInvocationOutput):
"""Model loader output"""
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
output = LoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
@@ -279,6 +348,72 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output
@invocation(
"sdxl_lora_collection_loader",
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
output = SDXLLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base is BaseModelType.StableDiffusionXL
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from typing import Optional, List
import torch
@@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
)
@invocation_output("captioned_image_output")
class CaptionImageOutput(ImageOutput):
caption: str = OutputField(description="Caption for given image")
@invocation_output("captioned_image_outputs")
class CaptionImageOutputs(BaseInvocationOutput):
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
@invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""

View File

@@ -10,6 +10,8 @@ module.exports = {
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
},
overrides: [
/**

View File

@@ -43,4 +43,5 @@ stats.html
yalc.lock
# vitest
tsconfig.vitest-temp.json
tsconfig.vitest-temp.json
coverage/

View File

@@ -35,6 +35,7 @@
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"test": "vitest",
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
@@ -132,6 +133,8 @@
"@types/react-dom": "^18.3.0",
"@types/uuid": "^9.0.8",
"@vitejs/plugin-react-swc": "^3.6.0",
"@vitest/coverage-v8": "^1.5.0",
"@vitest/ui": "^1.5.0",
"concurrently": "^8.2.2",
"dpdm": "^3.14.0",
"eslint": "^8.57.0",

View File

@@ -229,6 +229,12 @@ devDependencies:
'@vitejs/plugin-react-swc':
specifier: ^3.6.0
version: 3.6.0(vite@5.2.11)
'@vitest/coverage-v8':
specifier: ^1.5.0
version: 1.6.0(vitest@1.6.0)
'@vitest/ui':
specifier: ^1.5.0
version: 1.6.0(vitest@1.6.0)
concurrently:
specifier: ^8.2.2
version: 8.2.2
@@ -288,7 +294,7 @@ devDependencies:
version: 4.3.2(typescript@5.4.5)(vite@5.2.11)
vitest:
specifier: ^1.6.0
version: 1.6.0(@types/node@20.12.10)
version: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
packages:
@@ -1679,6 +1685,10 @@ packages:
resolution: {integrity: sha512-4iri8i1AqYHJE2DstZYkyEprg6Pq6sKx3xn5FpySk9sNhH7qN2LLlHJCfDTZRILNwQNPD7mATWM0TBui7uC1pA==}
dev: true
/@bcoe/v8-coverage@0.2.3:
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
dev: true
/@chakra-ui/accordion@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.3.1):
resolution: {integrity: sha512-FSXRm8iClFyU+gVaXisOSEw0/4Q+qZbFRiuhIAkVU6Boj0FxAMrlo9a8AV5TuF77rgaHytCdHk0Ng+cyUijrag==}
peerDependencies:
@@ -3635,6 +3645,11 @@ packages:
wrap-ansi-cjs: /wrap-ansi@7.0.0
dev: true
/@istanbuljs/schema@0.1.3:
resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==}
engines: {node: '>=8'}
dev: true
/@jest/schemas@29.6.3:
resolution: {integrity: sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==}
engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0}
@@ -3822,6 +3837,10 @@ packages:
dev: true
optional: true
/@polka/url@1.0.0-next.25:
resolution: {integrity: sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==}
dev: true
/@popperjs/core@2.11.8:
resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==}
dev: false
@@ -5146,7 +5165,7 @@ packages:
dom-accessibility-api: 0.6.3
lodash: 4.17.21
redent: 3.0.0
vitest: 1.6.0(@types/node@20.12.10)
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
dev: true
/@testing-library/user-event@14.5.2(@testing-library/dom@9.3.4):
@@ -5825,6 +5844,29 @@ packages:
- '@swc/helpers'
dev: true
/@vitest/coverage-v8@1.6.0(vitest@1.6.0):
resolution: {integrity: sha512-KvapcbMY/8GYIG0rlwwOKCVNRc0OL20rrhFkg/CHNzncV03TE2XWvO5w9uZYoxNiMEBacAJt3unSOiZ7svePew==}
peerDependencies:
vitest: 1.6.0
dependencies:
'@ampproject/remapping': 2.3.0
'@bcoe/v8-coverage': 0.2.3
debug: 4.3.4
istanbul-lib-coverage: 3.2.2
istanbul-lib-report: 3.0.1
istanbul-lib-source-maps: 5.0.4
istanbul-reports: 3.1.7
magic-string: 0.30.10
magicast: 0.3.4
picocolors: 1.0.0
std-env: 3.7.0
strip-literal: 2.1.0
test-exclude: 6.0.0
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
transitivePeerDependencies:
- supports-color
dev: true
/@vitest/expect@1.3.1:
resolution: {integrity: sha512-xofQFwIzfdmLLlHa6ag0dPV8YsnKOCP1KdAeVVh34vSjN2dcUiXYCD9htu/9eM7t8Xln4v03U9HLxLpPlsXdZw==}
dependencies:
@@ -5869,6 +5911,21 @@ packages:
tinyspy: 2.2.1
dev: true
/@vitest/ui@1.6.0(vitest@1.6.0):
resolution: {integrity: sha512-k3Lyo+ONLOgylctiGovRKy7V4+dIN2yxstX3eY5cWFXH6WP+ooVX79YSyi0GagdTQzLmT43BF27T0s6dOIPBXA==}
peerDependencies:
vitest: 1.6.0
dependencies:
'@vitest/utils': 1.6.0
fast-glob: 3.3.2
fflate: 0.8.2
flatted: 3.3.1
pathe: 1.1.2
picocolors: 1.0.0
sirv: 2.0.4
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
dev: true
/@vitest/utils@1.3.1:
resolution: {integrity: sha512-d3Waie/299qqRyHTm2DjADeTaNdNSVsnwHPWrs20JMpjh6eiVq7ggggweO8rc4arhf6rRkWuHKwvxGvejUXZZQ==}
dependencies:
@@ -8521,6 +8578,10 @@ packages:
resolution: {integrity: sha512-3yurQZ2hD9VISAhJJP9bpYFNQrHHBXE2JxxjY5aLEcDi46RmAzJE2OC9FAde0yis5ElW0jTTzs0zfg/Cca4XqQ==}
dev: true
/fflate@0.8.2:
resolution: {integrity: sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==}
dev: true
/file-entry-cache@6.0.1:
resolution: {integrity: sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==}
engines: {node: ^10.12.0 || >=12.0.0}
@@ -9084,6 +9145,10 @@ packages:
resolution: {integrity: sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==}
dev: true
/html-escaper@2.0.2:
resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==}
dev: true
/html-parse-stringify@3.0.1:
resolution: {integrity: sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==}
dependencies:
@@ -9513,6 +9578,39 @@ packages:
engines: {node: '>=0.10.0'}
dev: true
/istanbul-lib-coverage@3.2.2:
resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==}
engines: {node: '>=8'}
dev: true
/istanbul-lib-report@3.0.1:
resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==}
engines: {node: '>=10'}
dependencies:
istanbul-lib-coverage: 3.2.2
make-dir: 4.0.0
supports-color: 7.2.0
dev: true
/istanbul-lib-source-maps@5.0.4:
resolution: {integrity: sha512-wHOoEsNJTVltaJp8eVkm8w+GVkVNHT2YDYo53YdzQEL2gWm1hBX5cGFR9hQJtuGLebidVX7et3+dmDZrmclduw==}
engines: {node: '>=10'}
dependencies:
'@jridgewell/trace-mapping': 0.3.25
debug: 4.3.4
istanbul-lib-coverage: 3.2.2
transitivePeerDependencies:
- supports-color
dev: true
/istanbul-reports@3.1.7:
resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==}
engines: {node: '>=8'}
dependencies:
html-escaper: 2.0.2
istanbul-lib-report: 3.0.1
dev: true
/iterable-lookahead@1.0.0:
resolution: {integrity: sha512-hJnEP2Xk4+44DDwJqUQGdXal5VbyeWLaPyDl2AQc242Zr7iqz4DgpQOrEzglWVMGHMDCkguLHEKxd1+rOsmgSQ==}
engines: {node: '>=4'}
@@ -9912,6 +10010,14 @@ packages:
'@jridgewell/sourcemap-codec': 1.4.15
dev: true
/magicast@0.3.4:
resolution: {integrity: sha512-TyDF/Pn36bBji9rWKHlZe+PZb6Mx5V8IHCSxk7X4aljM4e/vyDvZZYwHewdVaqiA0nb3ghfHU/6AUpDxWoER2Q==}
dependencies:
'@babel/parser': 7.24.5
'@babel/types': 7.24.5
source-map-js: 1.2.0
dev: true
/make-dir@2.1.0:
resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==}
engines: {node: '>=6'}
@@ -9927,6 +10033,13 @@ packages:
semver: 6.3.1
dev: true
/make-dir@4.0.0:
resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==}
engines: {node: '>=10'}
dependencies:
semver: 7.6.0
dev: true
/map-obj@2.0.0:
resolution: {integrity: sha512-TzQSV2DiMYgoF5RycneKVUzIa9bQsj/B3tTgsE3dOGqlzHnGIDaC7XBE7grnA+8kZPnfqSGFe95VHc2oc0VFUQ==}
engines: {node: '>=4'}
@@ -10101,6 +10214,11 @@ packages:
resolution: {integrity: sha512-iSAJLHYKnX41mKcJKjqvnAN9sf0LMDTXDEvFv+ffuRR9a1MIuXLjMNL6EsnDHSkKLTWNqQQ5uo61P4EbU4NU+Q==}
dev: false
/mrmime@2.0.0:
resolution: {integrity: sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==}
engines: {node: '>=10'}
dev: true
/ms@2.0.0:
resolution: {integrity: sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==}
dev: true
@@ -11766,6 +11884,15 @@ packages:
engines: {node: '>=14'}
dev: true
/sirv@2.0.4:
resolution: {integrity: sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==}
engines: {node: '>= 10'}
dependencies:
'@polka/url': 1.0.0-next.25
mrmime: 2.0.0
totalist: 3.0.1
dev: true
/sisteransi@1.0.5:
resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==}
dev: true
@@ -12191,6 +12318,15 @@ packages:
unique-string: 2.0.0
dev: true
/test-exclude@6.0.0:
resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==}
engines: {node: '>=8'}
dependencies:
'@istanbuljs/schema': 0.1.3
glob: 7.2.3
minimatch: 3.1.2
dev: true
/text-table@0.2.0:
resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==}
dev: true
@@ -12264,6 +12400,11 @@ packages:
engines: {node: '>=0.6'}
dev: true
/totalist@3.0.1:
resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==}
engines: {node: '>=6'}
dev: true
/tr46@0.0.3:
resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==}
@@ -12837,7 +12978,7 @@ packages:
fsevents: 2.3.3
dev: true
/vitest@1.6.0(@types/node@20.12.10):
/vitest@1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0):
resolution: {integrity: sha512-H5r/dN06swuFnzNFhq/dnz37bPXnq8xB2xB5JOVk8K09rUtoeNN+LHWkoQ0A/i3hvbUKKcCei9KpbxqHMLhLLA==}
engines: {node: ^18.0.0 || >=20.0.0}
hasBin: true
@@ -12867,6 +13008,7 @@ packages:
'@vitest/runner': 1.6.0
'@vitest/snapshot': 1.6.0
'@vitest/spy': 1.6.0
'@vitest/ui': 1.6.0(vitest@1.6.0)
'@vitest/utils': 1.6.0
acorn-walk: 8.3.2
chai: 4.4.1

View File

@@ -774,6 +774,7 @@
"cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",

View File

@@ -1,3 +1,4 @@
/* eslint-disable no-console */
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';

View File

@@ -67,6 +67,8 @@ export const useSocketIO = () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = $socketOptions;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket initialized', socket);
}
@@ -75,6 +77,8 @@ export const useSocketIO = () => {
return () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = undefined;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket teardown', socket);
}
socket.disconnect();

View File

@@ -1,3 +1,6 @@
/* eslint-disable no-console */
// This is only enabled manually for debugging, console is allowed.
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { diff } from 'jsondiffpatch';

View File

@@ -1,7 +1,6 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -25,13 +24,6 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {

View File

@@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const { canvas, nodes, controlAdapters, controlLayers } = getState();
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas());

View File

@@ -148,7 +148,6 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
console.log(error);
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {

View File

@@ -8,8 +8,8 @@ import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';

View File

@@ -1,9 +1,9 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue';
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
@@ -18,7 +18,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph;
if (model && model.base === 'sdxl') {
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);

View File

@@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { nodes, edges } = state.nodes;
const { nodes, edges } = state.nodes.present;
const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes);
const graph = buildNodesGraph(state.nodes.present);
const builtWorkflow = buildWorkflowWithValidation({
nodes,
edges,

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
@@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: (action, { getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
@@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
dispatch(nodeTemplatesBuilt(nodeTemplates));
$templates.set(nodeTemplates);
},
});

View File

@@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.nodes.forEach((node) => {
state.nodes.present.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}

View File

@@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
const log = logger('socketio');
@@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
}
},
});
};

View File

@@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
@@ -9,7 +10,9 @@ import {
isImageViewerOpenChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
@@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, node, queue_batch_id } = data;
const { result, node, queue_batch_id, source_node_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
@@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
}
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
@@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;
nes.progress = null;
nes.progressImage = null;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
const log = logger('socketio');
@@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@@ -1,5 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
@@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
dispatch(
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
);
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
}
},
});
};

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
@@ -14,7 +14,8 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
actionCreator: updateAllNodesRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const { nodes, templates } = getState().nodes;
const { nodes } = getState().nodes.present;
const templates = $templates.get();
let unableToUpdateCount = 0;
@@ -24,7 +25,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
unableToUpdateCount++;
return;
}
if (!getNeedsUpdate(node, template)) {
if (!getNeedsUpdate(node.data, template)) {
// No need to increment the count here, since we're not actually updating
return;
}

View File

@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
@@ -14,10 +15,10 @@ import { fromZodError } from 'zod-validation-error';
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const nodeTemplates = getState().nodes.templates;
const nodeTemplates = $templates.get();
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);

View File

@@ -21,7 +21,8 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
@@ -50,7 +51,7 @@ const allReducers = {
[canvasSlice.name]: canvasSlice.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: nodesSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
@@ -66,6 +67,7 @@ const allReducers = {
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
};
@@ -111,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -1,3 +1,4 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@@ -9,13 +10,16 @@ import { selectControlLayersSlice } from 'features/controlLayers/store/controlLa
import type { Layer } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
@@ -25,199 +29,208 @@ const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
regional_guidance_layer: 'controlLayers.regionalGuidance',
};
const selector = createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const createSelector = (templates: Templates) =>
createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectWorkflowSettingsSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const { isConnected } = system;
const { isConnected } = system;
const reasons: { prefix?: string; content: string }[] = [];
const reasons: { prefix?: string; content: string }[] = [];
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
if (activeTabName === 'workflows') {
if (nodes.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
if (activeTabName === 'workflows') {
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
});
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
const nodeTemplate = nodes.templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
});
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64
if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
if (l.type === 'ip_adapter_layer') {
// Must have model
if (!l.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (l.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!l.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!ipAdapter.model) {
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
if (l.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
}
if (l.type === 'ip_adapter_layer') {
// Must have model
if (!l.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
if (l.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
if (!l.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
}
}
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) });
}
});
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({
content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }),
});
}
});
}
}
}
return { isReady: !reasons.length, reasons };
}
);
return { isReady: !reasons.length, reasons };
}
);
export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const selector = useMemo(() => createSelector(templates), [templates]);
const value = useAppSelector(selector);
return value;
};

View File

@@ -5,22 +5,7 @@ import type {
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import type { components } from 'services/api/schema';
import type {
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation,
HedImageProcessorInvocation,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
PidiImageProcessorInvocation,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import type { Invocation } from 'services/api/types';
import type { O } from 'ts-toolbelt';
import { z } from 'zod';
@@ -28,20 +13,20 @@ import { z } from 'zod';
* Any ControlNet processor node
*/
export type ControlAdapterProcessorNode =
| CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
| DepthAnythingImageProcessorInvocation
| HedImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| LineartImageProcessorInvocation
| MediapipeFaceProcessorInvocation
| MidasDepthImageProcessorInvocation
| MlsdImageProcessorInvocation
| NormalbaeImageProcessorInvocation
| DWOpenposeImageProcessorInvocation
| PidiImageProcessorInvocation
| ZoeDepthImageProcessorInvocation;
| Invocation<'canny_image_processor'>
| Invocation<'color_map_image_processor'>
| Invocation<'content_shuffle_image_processor'>
| Invocation<'depth_anything_image_processor'>
| Invocation<'hed_image_processor'>
| Invocation<'lineart_anime_image_processor'>
| Invocation<'lineart_image_processor'>
| Invocation<'mediapipe_face_processor'>
| Invocation<'midas_depth_image_processor'>
| Invocation<'mlsd_image_processor'>
| Invocation<'normalbae_image_processor'>
| Invocation<'dw_openpose_image_processor'>
| Invocation<'pidi_image_processor'>
| Invocation<'zoe_depth_image_processor'>;
/**
* Any ControlNet processor type
@@ -71,7 +56,7 @@ export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterPr
* The Canny processor node, with parameters flagged as required
*/
export type RequiredCannyImageProcessorInvocation = O.Required<
CannyImageProcessorInvocation,
Invocation<'canny_image_processor'>,
'type' | 'low_threshold' | 'high_threshold' | 'image_resolution' | 'detect_resolution'
>;
@@ -79,7 +64,7 @@ export type RequiredCannyImageProcessorInvocation = O.Required<
* The Color Map processor node, with parameters flagged as required
*/
export type RequiredColorMapImageProcessorInvocation = O.Required<
ColorMapImageProcessorInvocation,
Invocation<'color_map_image_processor'>,
'type' | 'color_map_tile_size'
>;
@@ -87,7 +72,7 @@ export type RequiredColorMapImageProcessorInvocation = O.Required<
* The ContentShuffle processor node, with parameters flagged as required
*/
export type RequiredContentShuffleImageProcessorInvocation = O.Required<
ContentShuffleImageProcessorInvocation,
Invocation<'content_shuffle_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
>;
@@ -95,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
* The DepthAnything processor node, with parameters flagged as required
*/
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
DepthAnythingImageProcessorInvocation,
Invocation<'depth_anything_image_processor'>,
'type' | 'model_size' | 'resolution' | 'offload'
>;
@@ -108,7 +93,7 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz
* The HED processor node, with parameters flagged as required
*/
export type RequiredHedImageProcessorInvocation = O.Required<
HedImageProcessorInvocation,
Invocation<'hed_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution' | 'scribble'
>;
@@ -116,7 +101,7 @@ export type RequiredHedImageProcessorInvocation = O.Required<
* The Lineart Anime processor node, with parameters flagged as required
*/
export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
LineartAnimeImageProcessorInvocation,
Invocation<'lineart_anime_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution'
>;
@@ -124,7 +109,7 @@ export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
* The Lineart processor node, with parameters flagged as required
*/
export type RequiredLineartImageProcessorInvocation = O.Required<
LineartImageProcessorInvocation,
Invocation<'lineart_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution' | 'coarse'
>;
@@ -132,7 +117,7 @@ export type RequiredLineartImageProcessorInvocation = O.Required<
* The MediapipeFace processor node, with parameters flagged as required
*/
export type RequiredMediapipeFaceProcessorInvocation = O.Required<
MediapipeFaceProcessorInvocation,
Invocation<'mediapipe_face_processor'>,
'type' | 'max_faces' | 'min_confidence' | 'image_resolution' | 'detect_resolution'
>;
@@ -140,7 +125,7 @@ export type RequiredMediapipeFaceProcessorInvocation = O.Required<
* The MidasDepth processor node, with parameters flagged as required
*/
export type RequiredMidasDepthImageProcessorInvocation = O.Required<
MidasDepthImageProcessorInvocation,
Invocation<'midas_depth_image_processor'>,
'type' | 'a_mult' | 'bg_th' | 'image_resolution' | 'detect_resolution'
>;
@@ -148,7 +133,7 @@ export type RequiredMidasDepthImageProcessorInvocation = O.Required<
* The MLSD processor node, with parameters flagged as required
*/
export type RequiredMlsdImageProcessorInvocation = O.Required<
MlsdImageProcessorInvocation,
Invocation<'mlsd_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution' | 'thr_v' | 'thr_d'
>;
@@ -156,7 +141,7 @@ export type RequiredMlsdImageProcessorInvocation = O.Required<
* The NormalBae processor node, with parameters flagged as required
*/
export type RequiredNormalbaeImageProcessorInvocation = O.Required<
NormalbaeImageProcessorInvocation,
Invocation<'normalbae_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution'
>;
@@ -164,7 +149,7 @@ export type RequiredNormalbaeImageProcessorInvocation = O.Required<
* The DW Openpose processor node, with parameters flagged as required
*/
export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
DWOpenposeImageProcessorInvocation,
Invocation<'dw_openpose_image_processor'>,
'type' | 'image_resolution' | 'draw_body' | 'draw_face' | 'draw_hands'
>;
@@ -172,14 +157,14 @@ export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
* The Pidi processor node, with parameters flagged as required
*/
export type RequiredPidiImageProcessorInvocation = O.Required<
PidiImageProcessorInvocation,
Invocation<'pidi_image_processor'>,
'type' | 'detect_resolution' | 'image_resolution' | 'safe' | 'scribble'
>;
/**
* The ZoeDepth processor node, with parameters flagged as required
*/
export type RequiredZoeDepthImageProcessorInvocation = O.Required<ZoeDepthImageProcessorInvocation, 'type'>;
export type RequiredZoeDepthImageProcessorInvocation = O.Required<Invocation<'zoe_depth_image_processor'>, 'type'>;
/**
* Any ControlNet Processor node, with its parameters flagged as required

View File

@@ -1,23 +1,9 @@
import type { S } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import type {
_CannyProcessorConfig,
_ColorMapProcessorConfig,
_ContentShuffleProcessorConfig,
_DepthAnythingProcessorConfig,
_DWOpenposeProcessorConfig,
_HedProcessorConfig,
_LineartAnimeProcessorConfig,
_LineartProcessorConfig,
_MediapipeFaceProcessorConfig,
_MidasDepthProcessorConfig,
_MlsdProcessorConfig,
_NormalbaeProcessorConfig,
_PidiProcessorConfig,
_ZoeDepthProcessorConfig,
CannyProcessorConfig,
CLIPVisionModelV2,
ColorMapProcessorConfig,
@@ -45,16 +31,16 @@ describe('Control Adapter Types', () => {
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
});
test('IP Adapter Method', () => {
assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>();
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
});
test('CLIP Vision Model', () => {
assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>();
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
});
test('Control Mode', () => {
assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>();
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
});
test('DepthAnything Model Size', () => {
assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>();
assert<Equals<NonNullable<Invocation<'depth_anything_image_processor'>['model_size']>, DepthAnythingModelSize>>();
});
test('Processor Configs', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
@@ -75,3 +61,33 @@ describe('Control Adapter Types', () => {
assert<Equals<_ZoeDepthProcessorConfig, ZoeDepthProcessorConfig>>();
});
});
// Types derived from OpenAPI
type _CannyProcessorConfig = Required<
Pick<Invocation<'canny_image_processor'>, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
>;
type _ColorMapProcessorConfig = Required<
Pick<Invocation<'color_map_image_processor'>, 'id' | 'type' | 'color_map_tile_size'>
>;
type _ContentShuffleProcessorConfig = Required<
Pick<Invocation<'content_shuffle_image_processor'>, 'id' | 'type' | 'w' | 'h' | 'f'>
>;
type _DepthAnythingProcessorConfig = Required<
Pick<Invocation<'depth_anything_image_processor'>, 'id' | 'type' | 'model_size'>
>;
type _HedProcessorConfig = Required<Pick<Invocation<'hed_image_processor'>, 'id' | 'type' | 'scribble'>>;
type _LineartAnimeProcessorConfig = Required<Pick<Invocation<'lineart_anime_image_processor'>, 'id' | 'type'>>;
type _LineartProcessorConfig = Required<Pick<Invocation<'lineart_image_processor'>, 'id' | 'type' | 'coarse'>>;
type _MediapipeFaceProcessorConfig = Required<
Pick<Invocation<'mediapipe_face_processor'>, 'id' | 'type' | 'max_faces' | 'min_confidence'>
>;
type _MidasDepthProcessorConfig = Required<
Pick<Invocation<'midas_depth_image_processor'>, 'id' | 'type' | 'a_mult' | 'bg_th'>
>;
type _MlsdProcessorConfig = Required<Pick<Invocation<'mlsd_image_processor'>, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
type _NormalbaeProcessorConfig = Required<Pick<Invocation<'normalbae_image_processor'>, 'id' | 'type'>>;
type _DWOpenposeProcessorConfig = Required<
Pick<Invocation<'dw_openpose_image_processor'>, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
type _PidiProcessorConfig = Required<Pick<Invocation<'pidi_image_processor'>, 'id' | 'type' | 'safe' | 'scribble'>>;
type _ZoeDepthProcessorConfig = Required<Pick<Invocation<'zoe_depth_image_processor'>, 'id' | 'type'>>;

View File

@@ -1,27 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type {
BaseModelType,
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
ControlNetModelConfig,
DepthAnythingImageProcessorInvocation,
DWOpenposeImageProcessorInvocation,
Graph,
HedImageProcessorInvocation,
ImageDTO,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
MediapipeFaceProcessorInvocation,
MidasDepthImageProcessorInvocation,
MlsdImageProcessorInvocation,
NormalbaeImageProcessorInvocation,
PidiImageProcessorInvocation,
T2IAdapterModelConfig,
ZoeDepthImageProcessorInvocation,
} from 'services/api/types';
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
@@ -32,9 +12,6 @@ const zCannyProcessorConfig = z.object({
low_threshold: z.number().int().gte(0).lte(255),
high_threshold: z.number().int().gte(0).lte(255),
});
export type _CannyProcessorConfig = Required<
Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
>;
export type CannyProcessorConfig = z.infer<typeof zCannyProcessorConfig>;
const zColorMapProcessorConfig = z.object({
@@ -42,9 +19,6 @@ const zColorMapProcessorConfig = z.object({
type: z.literal('color_map_image_processor'),
color_map_tile_size: z.number().int().gte(1),
});
export type _ColorMapProcessorConfig = Required<
Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
>;
export type ColorMapProcessorConfig = z.infer<typeof zColorMapProcessorConfig>;
const zContentShuffleProcessorConfig = z.object({
@@ -54,9 +28,6 @@ const zContentShuffleProcessorConfig = z.object({
h: z.number().int().gte(0),
f: z.number().int().gte(0),
});
export type _ContentShuffleProcessorConfig = Required<
Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
>;
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
@@ -68,9 +39,6 @@ const zDepthAnythingProcessorConfig = z.object({
type: z.literal('depth_anything_image_processor'),
model_size: zDepthAnythingModelSize,
});
export type _DepthAnythingProcessorConfig = Required<
Pick<DepthAnythingImageProcessorInvocation, 'id' | 'type' | 'model_size'>
>;
export type DepthAnythingProcessorConfig = z.infer<typeof zDepthAnythingProcessorConfig>;
const zHedProcessorConfig = z.object({
@@ -78,14 +46,12 @@ const zHedProcessorConfig = z.object({
type: z.literal('hed_image_processor'),
scribble: z.boolean(),
});
export type _HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
export type HedProcessorConfig = z.infer<typeof zHedProcessorConfig>;
const zLineartAnimeProcessorConfig = z.object({
id: zId,
type: z.literal('lineart_anime_image_processor'),
});
export type _LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
export type LineartAnimeProcessorConfig = z.infer<typeof zLineartAnimeProcessorConfig>;
const zLineartProcessorConfig = z.object({
@@ -93,7 +59,6 @@ const zLineartProcessorConfig = z.object({
type: z.literal('lineart_image_processor'),
coarse: z.boolean(),
});
export type _LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
export type LineartProcessorConfig = z.infer<typeof zLineartProcessorConfig>;
const zMediapipeFaceProcessorConfig = z.object({
@@ -102,9 +67,6 @@ const zMediapipeFaceProcessorConfig = z.object({
max_faces: z.number().int().gte(1),
min_confidence: z.number().gte(0).lte(1),
});
export type _MediapipeFaceProcessorConfig = Required<
Pick<MediapipeFaceProcessorInvocation, 'id' | 'type' | 'max_faces' | 'min_confidence'>
>;
export type MediapipeFaceProcessorConfig = z.infer<typeof zMediapipeFaceProcessorConfig>;
const zMidasDepthProcessorConfig = z.object({
@@ -113,9 +75,6 @@ const zMidasDepthProcessorConfig = z.object({
a_mult: z.number().gte(0),
bg_th: z.number().gte(0),
});
export type _MidasDepthProcessorConfig = Required<
Pick<MidasDepthImageProcessorInvocation, 'id' | 'type' | 'a_mult' | 'bg_th'>
>;
export type MidasDepthProcessorConfig = z.infer<typeof zMidasDepthProcessorConfig>;
const zMlsdProcessorConfig = z.object({
@@ -124,14 +83,12 @@ const zMlsdProcessorConfig = z.object({
thr_v: z.number().gte(0),
thr_d: z.number().gte(0),
});
export type _MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
export type MlsdProcessorConfig = z.infer<typeof zMlsdProcessorConfig>;
const zNormalbaeProcessorConfig = z.object({
id: zId,
type: z.literal('normalbae_image_processor'),
});
export type _NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
export type NormalbaeProcessorConfig = z.infer<typeof zNormalbaeProcessorConfig>;
const zDWOpenposeProcessorConfig = z.object({
@@ -141,9 +98,6 @@ const zDWOpenposeProcessorConfig = z.object({
draw_face: z.boolean(),
draw_hands: z.boolean(),
});
export type _DWOpenposeProcessorConfig = Required<
Pick<DWOpenposeImageProcessorInvocation, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
>;
export type DWOpenposeProcessorConfig = z.infer<typeof zDWOpenposeProcessorConfig>;
const zPidiProcessorConfig = z.object({
@@ -152,14 +106,12 @@ const zPidiProcessorConfig = z.object({
safe: z.boolean(),
scribble: z.boolean(),
});
export type _PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
export type PidiProcessorConfig = z.infer<typeof zPidiProcessorConfig>;
const zZoeDepthProcessorConfig = z.object({
id: zId,
type: z.literal('zoe_depth_image_processor'),
});
export type _ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
const zProcessorConfig = z.discriminatedUnion('type', [

View File

@@ -1,23 +1,21 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { createSelector } from '@reduxjs/toolkit';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $viewport } from 'features/nodes/store/nodesSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
const selectZoom = createSelector([selectNodesSlice, activeTabNameSelector], (nodes, activeTabName) =>
activeTabName === 'workflows' ? nodes.viewport.zoom : 1
);
/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const zoom = useAppSelector(selectZoom);
const activeTabName = useAppSelector(activeTabNameSelector);
const workflowsViewport = useStore($viewport);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const zoom = activeTabName === 'workflows' ? workflowsViewport.zoom : 1;
const activatorCoordinates = getEventCoordinates(activatorEvent);
if (!activatorCoordinates) {
@@ -42,7 +40,7 @@ export const useScaledModifer = () => {
return transform;
},
[zoom]
[activeTabName, workflowsViewport.zoom]
);
return modifier;

View File

@@ -75,8 +75,8 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CompositeNumberInput
value={lora.weight}
onChange={handleChange}
min={-5}
max={5}
min={-10}
max={10}
step={0.01}
w={20}
flexShrink={0}

View File

@@ -2,26 +2,34 @@ import 'reactflow/dist/style.css';
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
addNodePopoverClosed,
addNodePopoverOpened,
$cursorPos,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
nodeAdded,
selectNodesSlice,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useRef } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize(
(inputValue: string) =>
@@ -54,26 +62,30 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null);
const templates = useStore($templates);
const pendingConnection = useStore($pendingConnection);
const isOpen = useStore($isAddNodePopoverOpen);
const store = useAppStore();
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const filteredNodeTemplates = fieldFilter
? filter(nodes.templates, (template) => {
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
if (!pendingConnection) {
return map(templates);
}
return some(handles, (handle) => {
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(nodes.templates);
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
const options = useMemo(() => {
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
return {
label: template.title,
value: template.type,
@@ -83,15 +95,15 @@ const AddNodePopover = () => {
});
//We only want these nodes if we're not filtered
if (fieldFilter === null) {
options.push({
if (!pendingConnection) {
_options.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
options.push({
_options.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
@@ -99,18 +111,15 @@ const AddNodePopover = () => {
});
}
options.sort((a, b) => a.label.localeCompare(b.label));
_options.sort((a, b) => a.label.localeCompare(b.label));
return { options };
});
const { options } = useAppSelector(selector);
const isOpen = useAppSelector((s) => s.nodes.isAddNodePopoverOpen);
return _options;
}, [filteredTemplates, pendingConnection, t]);
const addNode = useCallback(
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
(nodeType: string): AnyNode | null => {
const node = buildInvocation(nodeType);
if (!node) {
const errorMessage = t('nodes.unknownNode', {
nodeType: nodeType,
});
@@ -118,10 +127,11 @@ const AddNodePopover = () => {
status: 'error',
title: errorMessage,
});
return;
return null;
}
dispatch(nodeAdded(invocation));
const cursorPos = $cursorPos.get();
dispatch(nodeAdded({ node, cursorPos }));
return node;
},
[dispatch, buildInvocation, toaster, t]
);
@@ -131,52 +141,50 @@ const AddNodePopover = () => {
if (!v) {
return;
}
addNode(v.value);
dispatch(addNodePopoverClosed());
const node = addNode(v.value);
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
if (connection) {
dispatch(connectionMade(connection));
}
}
closeAddNodePopover();
},
[addNode, dispatch]
[addNode, dispatch, pendingConnection, store, templates]
);
const onClose = useCallback(() => {
dispatch(addNodePopoverClosed());
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
const handleHotkeyOpen: HotkeyCallback = useCallback(
(e) => {
e.preventDefault();
onOpen();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
},
[onOpen]
);
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}, []);
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
onClose();
}, [onClose]);
closeAddNodePopover();
}, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback(
(e) => {
if (e.key === 'Escape') {
onClose();
}
},
[onClose]
);
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
if (e.key === 'Escape') {
closeAddNodePopover();
}
}, []);
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
return (
<Popover
isOpen={isOpen}
onClose={onClose}
onClose={closeAddNodePopover}
placement="bottom"
openDelay={0}
closeDelay={0}
@@ -206,7 +214,7 @@ const AddNodePopover = () => {
noOptionsMessage={noOptionsMessage}
filterOption={filterOption}
onChange={onChange}
onMenuClose={onClose}
onMenuClose={closeAddNodePopover}
onKeyDown={onKeyDown}
inputRef={inputRef}
closeMenuOnSelect={false}

View File

@@ -1,34 +1,33 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
connectionEnded,
$cursorPos,
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$viewport,
connectionMade,
connectionStarted,
edgeAdded,
edgeChangeStarted,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
nodesDeleted,
redo,
selectedAll,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionPasted,
viewportChanged,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type {
OnConnect,
OnConnectEnd,
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc,
@@ -36,12 +35,11 @@ import type {
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
OnSelectionChangeFunc,
ProOptions,
ReactFlowProps,
XYPosition,
ReactFlowState,
} from 'reactflow';
import { Background, ReactFlow } from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
@@ -68,17 +66,23 @@ const proOptions: ProOptions = { hideAttribution: true };
const snapGrid: [number, number] = [25, 25];
const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection;
export const Flow = memo(() => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((s) => s.nodes.nodes);
const edges = useAppSelector((s) => s.nodes.edges);
const viewport = useAppSelector((s) => s.nodes.viewport);
const shouldSnapToGrid = useAppSelector((s) => s.nodes.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.nodes.selectionMode);
const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useStore($viewport);
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
const selectionMode = useAppSelector((s) => s.workflowSettings.selectionMode);
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
const flowWrapper = useRef<HTMLDivElement>(null);
const cursorPosition = useRef<XYPosition | null>(null);
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
useWorkflowWatcher();
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
const flowStyles = useMemo<CSSProperties>(
@@ -102,32 +106,6 @@ export const Flow = memo(() => {
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(() => {
if (!cursorPosition.current) {
return;
}
dispatch(
connectionEnded({
cursorPosition: cursorPosition.current,
mouseOverNodeId: $mouseOverNode.get(),
})
);
}, [dispatch]);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
@@ -142,20 +120,9 @@ export const Flow = memo(() => {
[dispatch]
);
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
({ nodes, edges }) => {
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
},
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback(
(e, viewport) => {
dispatch(viewportChanged(viewport));
},
[dispatch]
);
const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => {
$viewport.set(viewport);
}, []);
const { onCloseGlobal } = useGlobalMenuClose();
const handlePaneClick = useCallback(() => {
@@ -169,11 +136,12 @@ export const Flow = memo(() => {
const onMouseMove = useCallback((event: MouseEvent<HTMLDivElement>) => {
if (flowWrapper.current?.getBoundingClientRect()) {
cursorPosition.current =
$cursorPos.set(
$flow.get()?.screenToFlowPosition({
x: event.clientX,
y: event.clientY,
}) ?? null;
}) ?? null
);
}
}, []);
@@ -195,19 +163,18 @@ export const Flow = memo(() => {
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(true);
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
dispatch(edgeChangeStarted());
},
[dispatch]
);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(_oldEdge, newConnection) => {
// instead of updating the edge (we deleted it earlier), we instead create
// a new one.
// Because we deleted the edge when the update started, we must create a new edge from the connection
dispatch(connectionMade(newConnection));
},
[dispatch]
@@ -215,8 +182,10 @@ export const Flow = memo(() => {
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
(e, edge, _handleType) => {
// Handle the case where user begins a drag but didn't move the cursor -
// bc we deleted the edge, we need to add it back
$isUpdatingEdge.set(false);
$pendingConnection.set(null);
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
// the edge update - we need to add it back
if (
// ignore touch events
!('touches' in e) &&
@@ -233,23 +202,64 @@ export const Flow = memo(() => {
// #endregion
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
});
const { copySelection, pasteSelection } = useCopyPaste();
useHotkeys(['Ctrl+a', 'Meta+a'], (e) => {
e.preventDefault();
dispatch(selectedAll());
});
const onCopyHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
copySelection();
},
[copySelection]
);
useHotkeys(['Ctrl+c', 'Meta+c'], onCopyHotkey);
useHotkeys(['Ctrl+v', 'Meta+v'], (e) => {
if (!cursorPosition.current) {
return;
const onSelectAllHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
dispatch(selectedAll());
},
[dispatch]
);
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
const onPasteHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
pasteSelection();
},
[pasteSelection]
);
useHotkeys(['Ctrl+v', 'Meta+v'], onPasteHotkey);
const onPasteWithEdgesToNodesHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
pasteSelection(true);
},
[pasteSelection]
);
useHotkeys(['Ctrl+shift+v', 'Meta+shift+v'], onPasteWithEdgesToNodesHotkey);
const onUndoHotkey = useCallback(() => {
if (mayUndo) {
dispatch(undo());
}
e.preventDefault();
dispatch(selectionPasted({ cursorPosition: cursorPosition.current }));
});
}, [dispatch, mayUndo]);
useHotkeys(['meta+z', 'ctrl+z'], onUndoHotkey);
const onRedoHotkey = useCallback(() => {
if (mayRedo) {
dispatch(redo());
}
}, [dispatch, mayRedo]);
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
const onEscapeHotkey = useCallback(() => {
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
}, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey);
return (
<ReactFlow
@@ -274,7 +284,6 @@ export const Flow = memo(() => {
onConnectEnd={onConnectEnd}
onMoveEnd={handleMoveEnd}
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
isValidConnection={isValidConnection}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
@@ -285,6 +294,7 @@ export const Flow = memo(() => {
onPaneClick={handlePaneClick}
deleteKeyCode={DELETE_KEYS}
selectionMode={selectionMode}
elevateEdgesOnSelect
>
<Background />
</ReactFlow>

View File

@@ -1,26 +1,33 @@
import { createSelector } from '@reduxjs/toolkit';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $pendingConnection } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import type { ConnectionLineComponentProps } from 'reactflow';
import { getBezierPath } from 'reactflow';
const selectStroke = createSelector(selectNodesSlice, (nodes) =>
nodes.shouldColorEdges ? getFieldColor(nodes.connectionStartFieldType) : colorTokenToCssVar('base.500')
);
const selectClassName = createSelector(selectNodesSlice, (nodes) =>
nodes.shouldAnimateEdges ? 'react-flow__custom_connection-path animated' : 'react-flow__custom_connection-path'
);
const pathStyles: CSSProperties = { opacity: 0.8 };
const CustomConnectionLine = ({ fromX, fromY, fromPosition, toX, toY, toPosition }: ConnectionLineComponentProps) => {
const stroke = useAppSelector(selectStroke);
const className = useAppSelector(selectClassName);
const pendingConnection = useStore($pendingConnection);
const shouldColorEdges = useAppSelector((state) => state.workflowSettings.shouldColorEdges);
const shouldAnimateEdges = useAppSelector((state) => state.workflowSettings.shouldAnimateEdges);
const stroke = useMemo(() => {
if (shouldColorEdges && pendingConnection) {
return getFieldColor(pendingConnection.fieldTemplate.type);
} else {
return colorTokenToCssVar('base.500');
}
}, [pendingConnection, shouldColorEdges]);
const className = useMemo(() => {
if (shouldAnimateEdges) {
return 'react-flow__custom_connection-path animated';
} else {
return 'react-flow__custom_connection-path';
}
}, [shouldAnimateEdges]);
const pathParams = {
sourceX: fromX,

View File

@@ -1,6 +1,8 @@
import { Badge, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { $templates } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
@@ -22,9 +24,10 @@ const InvocationCollapsedEdge = ({
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
[selected, source, sourceHandleId, target, targetHandleId]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, selected, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate } = useAppSelector(selector);

View File

@@ -1,5 +1,7 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
@@ -21,13 +23,14 @@ const InvocationDefaultEdge = ({
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected),
[source, sourceHandleId, target, targetHandleId, selected]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,

View File

@@ -1,7 +1,8 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@@ -14,6 +15,7 @@ const defaultReturnValue = {
};
export const makeEdgeSelector = (
templates: Templates,
source: string,
sourceHandleId: string | null | undefined,
target: string,
@@ -22,7 +24,8 @@ export const makeEdgeSelector = (
) =>
createMemoizedSelector(
selectNodesSlice,
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
selectWorkflowSettingsSlice,
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
@@ -33,19 +36,20 @@ export const makeEdgeSelector = (
return defaultReturnValue;
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceNodeTemplate = templates[sourceNode.data.type];
const targetNodeTemplate = templates[targetNode.data.type];
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const stroke =
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
stroke,
label,
};

View File

@@ -1,12 +1,10 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
@@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = {
};
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
const selectNodeExecutionState = useMemo(
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
[nodeId]
);
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
const nodeExecutionState = useExecutionState(nodeId);
if (!nodeExecutionState) {
return null;

View File

@@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import type { NodeProps } from 'reactflow';
@@ -11,13 +11,13 @@ import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data, selected } = props;
const { id: nodeId, type, isOpen, label } = data;
const templates = useStore($templates);
const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]);
const nodeExists = useAppSelector((s) => Boolean(s.nodes.present.nodes.find((n) => n.id === nodeId)));
const hasTemplateSelector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
[type]
);
const hasTemplate = useAppSelector(hasTemplateSelector);
if (!nodeExists) {
return null;
}
if (!hasTemplate) {
return (

View File

@@ -37,7 +37,8 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
const [localTitle, setLocalTitle] = useState(label || fieldTemplateTitle || t('nodes.unknownField'));
const handleSubmit = useCallback(
async (newTitle: string) => {
async (newTitleRaw: string) => {
const newTitle = newTitleRaw.trim();
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
@@ -57,22 +58,22 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
}, [label, fieldTemplateTitle, t]);
return (
<Tooltip
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
ref={ref}
position="relative"
overflow="hidden"
alignItems="center"
justifyContent="flex-start"
gap={1}
w="full"
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
ref={ref}
position="relative"
overflow="hidden"
alignItems="center"
justifyContent="flex-start"
gap={1}
w="full"
<Tooltip
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<EditablePreview
fontWeight="semibold"
@@ -80,10 +81,10 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
noOfLines={1}
color={isMissingInput ? 'error.300' : 'base.300'}
/>
<EditableInput className="nodrag" sx={editableInputStyles} />
<EditableControls />
</Editable>
</Tooltip>
</Tooltip>
<EditableInput className="nodrag" sx={editableInputStyles} />
<EditableControls />
</Editable>
);
});
@@ -127,7 +128,15 @@ const EditableControls = memo(() => {
}
return (
<Flex onClick={handleClick} position="absolute" w="full" h="full" top={0} insetInlineStart={0} cursor="text" />
<Flex
onClick={handleClick}
position="absolute"
w="min-content"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});

View File

@@ -69,7 +69,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
);
}
if (fieldTemplate.input === 'connection') {
if (fieldTemplate.input === 'connection' || isConnected) {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} px={2}>
@@ -95,7 +95,15 @@ const InputField = ({ nodeId, fieldName }: Props) => {
return (
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl isInvalid={isMissingInput} isDisabled={isConnected} orientation="vertical" px={2}>
<FormControl
isInvalid={isMissingInput}
isDisabled={isConnected}
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
pointerEvents={isConnected ? 'none' : 'auto'}
orientation="vertical"
px={2}
>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex>
<EditableFieldTitle

View File

@@ -1,14 +1,14 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { MouseEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
const selectIsInProgress = useMemo(
() =>
createSelector(
selectNodesSlice,
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
),
[nodeId]
);
const isInProgress = useAppSelector(selectIsInProgress);
const executionState = useExecutionState(nodeId);
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
'nodeInProgress',
@@ -39,7 +31,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const dispatch = useAppDispatch();
const opacity = useAppSelector((s) => s.nodes.nodeOpacity);
const opacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
const { onCloseGlobal } = useGlobalMenuClose();
const handleClick = useCallback(

View File

@@ -1,12 +1,12 @@
import { CompositeSlider, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeOpacityChanged } from 'features/nodes/store/nodesSlice';
import { nodeOpacityChanged } from 'features/nodes/store/workflowSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const NodeOpacitySlider = () => {
const dispatch = useAppDispatch();
const nodeOpacity = useAppSelector((s) => s.nodes.nodeOpacity);
const nodeOpacity = useAppSelector((s) => s.workflowSettings.nodeOpacity);
const { t } = useTranslation();
const handleChange = useCallback(

View File

@@ -1,9 +1,6 @@
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
// shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from 'features/nodes/store/nodesSlice';
import { shouldShowMinimapPanelChanged } from 'features/nodes/store/workflowSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
@@ -19,9 +16,9 @@ const ViewportControls = () => {
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.shouldShowFieldTypeLegend
// (s) => s.nodes.present.shouldShowFieldTypeLegend
// );
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => {
zoomIn();

View File

@@ -16,7 +16,7 @@ const minimapStyles: SystemStyleObject = {
};
const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector((s) => s.nodes.shouldShowMinimapPanel);
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
return (
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>

View File

@@ -1,23 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice';
import { memo, useCallback } from 'react';
import { openAddNodePopover } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
const AddNodeButton = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleOpenAddNodePopover = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
return (
<IconButton
tooltip={t('nodes.addNodeToolTip')}
aria-label={t('nodes.addNode')}
icon={<PiPlusBold />}
onClick={handleOpenAddNodePopover}
onClick={openAddNodePopover}
pointerEvents="auto"
/>
);

View File

@@ -21,13 +21,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ReloadNodeTemplatesButton from 'features/nodes/components/flow/panels/TopRightPanel/ReloadSchemaButton';
import {
selectionModeChanged,
selectNodesSlice,
selectWorkflowSettingsSlice,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from 'features/nodes/store/nodesSlice';
} from 'features/nodes/store/workflowSettingsSlice';
import type { ChangeEvent, ReactNode } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -35,7 +35,7 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 };
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const selector = createMemoizedSelector(selectWorkflowSettingsSlice, (workflowSettings) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
@@ -43,7 +43,7 @@ const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = nodes;
} = workflowSettings;
return {
shouldAnimateEdges,
shouldValidateGraph,

View File

@@ -3,27 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
return {
data: lastSelectedNode?.data,
};
});
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => selectLastSelectedNode(nodes));
const InspectorDataTab = () => {
const { t } = useTranslation();
const { data } = useAppSelector(selector);
const lastSelectedNode = useAppSelector(selector);
if (!data) {
if (!lastSelectedNode) {
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
}
return <DataViewer data={data} label="Node Data" />;
return <DataViewer data={lastSelectedNode.data} label="Node Data" />;
};
export default memo(InspectorDataTab);

View File

@@ -1,36 +1,39 @@
import { Box, Flex, FormControl, FormLabel, HStack, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import EditableNodeTitle from './details/EditableNodeTitle';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
});
const InspectorDetailsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
}),
[templates]
);
const data = useAppSelector(selector);
const { t } = useTranslation();

View File

@@ -1,46 +1,49 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
return;
}
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
});
const InspectorOutputsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.id,
outputType: lastSelectedNodeTemplate.outputType,
};
}),
[templates]
);
const data = useAppSelector(selector);
const nes = useExecutionState(data?.nodeId);
const { t } = useTranslation();
if (!data) {
if (!data || !nes) {
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
}
if (data.outputs.length === 0) {
if (nes.outputs.length === 0) {
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
}
@@ -49,11 +52,11 @@ const InspectorOutputsTab = () => {
<ScrollableContent>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{data.outputType === 'image_output' ? (
data.outputs.map((result, i) => (
nes.outputs.map((result, i) => (
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
))
) : (
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
)}
</Flex>
</ScrollableContent>

View File

@@ -1,25 +1,26 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
return {
template: lastSelectedNodeTemplate,
};
});
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
return lastSelectedNodeTemplate;
}),
[templates]
);
const template = useAppSelector(selector);
const { t } = useTranslation();
if (!template) {

View File

@@ -1,31 +1,39 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const selector = useMemo(
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
const fields = map(template.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return false;
}
return (
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
const fieldNames = useAppSelector(selector);
return fieldNames;
};

View File

@@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
@@ -8,8 +9,7 @@ import { useCallback } from 'react';
import { useReactFlow } from 'reactflow';
export const useBuildNode = () => {
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
const templates = useStore($templates);
const flow = useReactFlow();
return useCallback(
@@ -41,10 +41,10 @@ export const useBuildNode = () => {
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = nodeTemplates[type] as InvocationTemplate;
const template = templates[type] as InvocationTemplate;
return buildInvocationNode(position, template);
},
[nodeTemplates, flow]
[templates, flow]
);
};

View File

@@ -0,0 +1,93 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import {
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$templates,
connectionMade,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const onConnectStart = useCallback<OnConnectStart>(
(event, params) => {
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
},
[store, templates]
);
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
dispatch(connectionMade(connection));
$pendingConnection.set(null);
},
[store]
);
const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store;
const pendingConnection = $pendingConnection.get();
const isUpdatingEdge = $isUpdatingEdge.get();
const mouseOverNodeId = $mouseOverNode.get();
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
// update logic can finish up
if (isUpdatingEdge && !mouseOverNodeId) {
$pendingConnection.set(null);
return;
}
if (!pendingConnection) {
return;
}
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const connection = getFirstValidConnection(
templates,
nodes,
edges,
pendingConnection,
candidateNode,
candidateTemplate
);
if (connection) {
dispatch(connectionMade(connection));
}
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
}
}, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
return api;
};

View File

@@ -1,34 +1,39 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
const selector = useMemo(
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
// get the visible fields
const fields = map(template.inputs).filter(
(field) =>
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useAppSelector(selector);
const fieldNames = useMemo(() => {
// get the visible fields
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return true;
}
return (
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
return fieldNames;
};

View File

@@ -1,16 +1,12 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { useMemo } from 'react';
import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
selectNodesSlice,
(nodes) => nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null
);
type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
@@ -18,6 +14,8 @@ type UseConnectionStateProps = {
};
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const fieldType = useFieldType(nodeId, fieldName, kind);
const selectIsConnected = useMemo(
@@ -36,25 +34,30 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
);
const selectConnectionError = useMemo(
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
[nodeId, fieldName, kind, fieldType]
);
const selectIsConnectionStartField = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) =>
Boolean(
nodes.connectionStartParams?.nodeId === nodeId &&
nodes.connectionStartParams?.handleId === fieldName &&
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
)
makeConnectionErrorSelector(
templates,
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
),
[fieldName, kind, nodeId]
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
);
const isConnected = useAppSelector(selectIsConnected);
const isConnectionInProgress = useAppSelector(selectIsConnectionInProgress);
const isConnectionStartField = useAppSelector(selectIsConnectionStartField);
const isConnectionInProgress = useMemo(() => Boolean(pendingConnection), [pendingConnection]);
const isConnectionStartField = useMemo(() => {
if (!pendingConnection) {
return false;
}
return (
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector(selectConnectionError);
const shouldDim = useMemo(

View File

@@ -0,0 +1,78 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import {
$copiedEdges,
$copiedNodes,
$cursorPos,
$edgesToCopiedNodes,
selectionPasted,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { isEqual, uniqWith } from 'lodash-es';
import { v4 as uuidv4 } from 'uuid';
const copySelection = () => {
// Use the imperative API here so we don't have to pass the whole slice around
const { getState } = getStore();
const { nodes, edges } = selectNodesSlice(getState());
const selectedNodes = nodes.filter((node) => node.selected);
const selectedEdges = edges.filter((edge) => edge.selected);
const edgesToSelectedNodes = edges.filter((edge) => selectedNodes.some((node) => node.id === edge.target));
$copiedNodes.set(selectedNodes);
$copiedEdges.set(selectedEdges);
$edgesToCopiedNodes.set(edgesToSelectedNodes);
};
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const { getState, dispatch } = getStore();
const currentNodes = selectNodesSlice(getState()).nodes;
const cursorPos = $cursorPos.get();
const copiedNodes = deepClone($copiedNodes.get());
let copiedEdges = deepClone($copiedEdges.get());
if (withEdgesToCopiedNodes) {
const edgesToCopiedNodes = deepClone($edgesToCopiedNodes.get());
copiedEdges = uniqWith([...copiedEdges, ...edgesToCopiedNodes], isEqual);
}
// Calculate an offset to reposition nodes to surround the cursor position, maintaining relative positioning
const xCoords = copiedNodes.map((node) => node.position.x);
const yCoords = copiedNodes.map((node) => node.position.y);
const minX = Math.min(...xCoords);
const minY = Math.min(...yCoords);
const offsetX = cursorPos ? cursorPos.x - minX : 50;
const offsetY = cursorPos ? cursorPos.y - minY : 50;
copiedNodes.forEach((node) => {
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
node.position.x = x;
node.position.y = y;
// Pasted nodes are selected
node.selected = true;
// Also give em a fresh id
const id = uuidv4();
// Update the edges to point to the new node id
for (const edge of copiedEdges) {
if (edge.source === node.id) {
edge.source = id;
edge.id = edge.id.replace(node.data.id, id);
}
if (edge.target === node.id) {
edge.target = id;
edge.id = edge.id.replace(node.data.id, id);
}
}
node.id = id;
node.data.id = id;
});
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
};
const api = { copySelection, pasteSelection };
export const useCopyPaste = () => {
return api;
};

View File

@@ -0,0 +1,56 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodeExecutionStates } from 'features/nodes/store/types';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { map } from 'nanostores';
import { useEffect, useMemo } from 'react';
export const $nodeExecutionStates = map<NodeExecutionStates>({});
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
export const useExecutionState = (nodeId?: string) => {
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
return executionState;
};
const removeNodeExecutionState = (nodeId: string) => {
$nodeExecutionStates.setKey(nodeId, undefined);
};
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
const state = $nodeExecutionStates.get()[nodeId];
if (!state) {
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
} else {
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
}
};
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
export const useSyncExecutionState = () => {
const nodeIds = useAppSelector(selectNodeIds);
useEffect(() => {
const nodeExecutionStates = $nodeExecutionStates.get();
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
for (const id of nodeIdsToAdd) {
upsertExecutionState(id);
}
for (const id of nodeIdsToRemove) {
removeNodeExecutionState(id);
}
}, [nodeIds]);
};

View File

@@ -1,20 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
return fieldTemplate;
};

View File

@@ -1,20 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.outputs[fieldName] ?? null, [fieldName, template.outputs]);
return fieldTemplate;
};

View File

@@ -1,27 +1,36 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useFieldTemplate = (
nodeId: string,
fieldName: string,
kind: 'inputs' | 'outputs'
): FieldInputTemplate | FieldOutputTemplate | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, kind, nodeId]
): FieldInputTemplate | FieldOutputTemplate => {
const templates = useStore($templates);
const selectNodeType = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
[nodeId]
);
const fieldTemplate = useAppSelector(selector);
const nodeType = useAppSelector(selectNodeType);
const fieldTemplate = useMemo(() => {
const template = templates[nodeType];
assert(template, `Template for node type ${nodeType} not found`);
if (kind === 'inputs') {
const fieldTemplate = template.inputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
} else {
const fieldTemplate = template.outputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
}
}, [fieldName, kind, nodeType, templates]);
return fieldTemplate;
};

View File

@@ -1,22 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useMemo } from 'react';
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldTemplateTitle = useAppSelector(selector);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldTemplateTitle = useMemo(() => fieldTemplate.title, [fieldTemplate]);
return fieldTemplateTitle;
};

View File

@@ -1,23 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
return fieldType;
};

View File

@@ -1,20 +1,26 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
const selector = createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = nodes.templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
})
);
import { useMemo } from 'react';
export const useGetNodesNeedUpdate = () => {
const getNeedsUpdate = useAppSelector(selector);
return getNeedsUpdate;
const templates = useStore($templates);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node.data, template);
})
),
[templates]
);
const needsUpdate = useAppSelector(selector);
return needsUpdate;
};

View File

@@ -1,26 +1,20 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { some } from 'lodash-es';
import { useMemo } from 'react';
export const useHasImageOutput = (nodeId: string): boolean => {
const selector = useMemo(
const template = useNodeTemplate(nodeId);
const hasImageOutput = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
return some(
template?.outputs,
(output) =>
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
template?.type !== 'image'
);
}),
[nodeId]
some(
template?.outputs,
(output) =>
output.type.name === 'ImageField' &&
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
template?.type !== 'image'
),
[template]
);
const hasImageOutput = useAppSelector(selector);
return hasImageOutput;
};

View File

@@ -1,8 +1,12 @@
// TODO: enable this at some point
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
@@ -13,7 +17,8 @@ import type { Connection, Node } from 'reactflow';
export const useIsValidConnection = () => {
const store = useAppStore();
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
const templates = useStore($templates);
const shouldValidateGraph = useAppSelector((s) => s.workflowSettings.shouldValidateGraph);
const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
// Connection must have valid targets
@@ -27,7 +32,7 @@ export const useIsValidConnection = () => {
}
const state = store.getState();
const { nodes, edges, templates } = state.nodes;
const { nodes, edges } = state.nodes.present;
// Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
@@ -40,6 +45,10 @@ export const useIsValidConnection = () => {
return false;
}
if (targetFieldTemplate.input === 'direct') {
return false;
}
if (!shouldValidateGraph) {
// manual override!
return true;
@@ -57,6 +66,14 @@ export const useIsValidConnection = () => {
return false;
}
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
}
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
@@ -76,7 +93,7 @@ export const useIsValidConnection = () => {
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},
[shouldValidateGraph, store]
[shouldValidateGraph, templates, store]
);
return isValidConnection;

View File

@@ -1,19 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { Classification } from 'features/nodes/types/common';
import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string): Classification | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
export const useNodeClassification = (nodeId: string): Classification => {
const template = useNodeTemplate(nodeId);
const classification = useMemo(() => template.classification, [template]);
return classification;
};

View File

@@ -5,7 +5,7 @@ import { selectNodeData } from 'features/nodes/store/selectors';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useNodeData = (nodeId: string): InvocationNodeData | null => {
export const useNodeData = (nodeId: string): InvocationNodeData => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {

View File

@@ -1,25 +1,11 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import { useMemo } from 'react';
export const useNodeNeedsUpdate = (nodeId: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const node = selectInvocationNode(nodes, nodeId);
const template = selectNodeTemplate(nodes, nodeId);
if (!node || !template) {
return false;
}
return getNeedsUpdate(node, template);
}),
[nodeId]
);
const needsUpdate = useAppSelector(selector);
const data = useNodeData(nodeId);
const template = useNodeTemplate(nodeId);
const needsUpdate = useMemo(() => getNeedsUpdate(data, template), [data, template]);
return needsUpdate;
};

View File

@@ -1,20 +1,23 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId);
}),
export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
const templates = useStore($templates);
const selectNodeType = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
const nodeType = useAppSelector(selectNodeType);
const template = useMemo(() => {
const t = templates[nodeType];
assert(t, `Template for node type ${nodeType} not found`);
return t;
}, [nodeType, templates]);
return template;
};

View File

@@ -1,18 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useMemo } from 'react';
export const useNodeTemplateTitle = (nodeId: string): string | null => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
const template = useNodeTemplate(nodeId);
const title = useMemo(() => template.title, [template.title]);
return title;
};

View File

@@ -1,26 +1,10 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { map } from 'lodash-es';
import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
return getSortedFilteredFieldNames(map(template.outputs));
}),
[nodeId]
);
const fieldNames = useAppSelector(selector);
export const useOutputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const fieldNames = useMemo(() => getSortedFilteredFieldNames(map(template.outputs)), [template.outputs]);
return fieldNames;
};

View File

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@@ -43,76 +42,21 @@ import {
zT2IAdapterModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import type {
Connection,
Edge,
EdgeChange,
EdgeRemoveChange,
Node,
NodeChange,
OnConnectStartParams,
Viewport,
XYPosition,
} from 'reactflow';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
getConnectedEdges,
getIncomers,
getOutgoers,
SelectionMode,
} from 'reactflow';
import {
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import type { z } from 'zod';
import type { NodesState } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import type { NodesState, PendingConnection, Templates } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
const initialNodesState: NodesState = {
_version: 1,
nodes: [],
edges: [],
templates: {},
connectionStartParams: null,
connectionStartFieldType: null,
connectionMade: false,
modifyingEdge: false,
addNewNodePosition: null,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
selectedEdges: [],
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
nodesToCopy: [],
edgesToCopy: [],
selectionMode: SelectionMode.Partial,
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
@@ -154,12 +98,12 @@ export const nodesSlice = createSlice({
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (state, action: PayloadAction<AnyNode>) => {
const node = action.payload;
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
const { node, cursorPos } = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
state.addNewNodePosition?.x ?? node.position.x,
state.addNewNodePosition?.y ?? node.position.y
cursorPos?.x ?? node.position.x,
cursorPos?.y ?? node.position.y
);
node.position = position;
node.selected = true;
@@ -175,40 +119,6 @@ export const nodesSlice = createSlice({
);
state.nodes.push(node);
if (!isInvocationNode(node)) {
return;
}
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
if (state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
node,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
edgeChangeStarted: (state) => {
state.modifyingEdge = true;
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
@@ -216,71 +126,8 @@ export const nodesSlice = createSlice({
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
state.connectionMade = state.modifyingEdge;
const { nodeId, handleId, handleType } = action.payload;
if (!nodeId || !handleId) {
return;
}
const node = state.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const template = state.templates[node.data.type];
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
state.connectionStartFieldType = field?.type ?? null;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
const fieldType = state.connectionStartFieldType;
if (!fieldType) {
return;
}
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
state.connectionMade = true;
},
connectionEnded: (
state,
action: PayloadAction<{
cursorPosition: XYPosition;
mouseOverNodeId: string | null;
}>
) => {
const { cursorPosition, mouseOverNodeId } = action.payload;
if (!state.connectionMade) {
if (mouseOverNodeId) {
const nodeIndex = state.nodes.findIndex((n) => n.id === mouseOverNodeId);
const mouseOverNode = state.nodes?.[nodeIndex];
if (mouseOverNode && state.connectionStartParams) {
const { nodeId, handleId, handleType } = state.connectionStartParams;
if (nodeId && handleId && handleType && state.connectionStartFieldType) {
const newConnection = findConnectionToValidHandle(
mouseOverNode,
state.nodes,
state.edges,
state.templates,
nodeId,
handleId,
handleType,
state.connectionStartFieldType
);
if (newConnection) {
state.edges = addEdge({ ...newConnection, type: 'default' }, state.edges);
}
}
}
state.connectionStartParams = null;
state.connectionStartFieldType = null;
} else {
state.addNewNodePosition = cursorPosition;
state.isAddNodePopoverOpen = true;
}
} else {
state.connectionStartParams = null;
state.connectionStartFieldType = null;
}
state.modifyingEdge = false;
},
fieldLabelChanged: (
state,
@@ -442,7 +289,6 @@ export const nodesSlice = createSlice({
if (!isInvocationNode(node)) {
return;
}
delete state.nodeExecutionStates[node.id];
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
@@ -474,12 +320,6 @@ export const nodesSlice = createSlice({
state.nodes
);
},
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedNodes = action.payload;
},
selectedEdgesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedEdges = action.payload;
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
@@ -537,34 +377,10 @@ export const nodesSlice = createSlice({
}
node.data.notes = value;
},
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
},
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
shouldColorEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldColorEdges = action.payload;
},
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
state.nodeOpacity = action.payload;
},
viewportChanged: (state, action: PayloadAction<Viewport>) => {
state.viewport = action.payload;
},
selectedAll: (state) => {
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
@@ -575,136 +391,49 @@ export const nodesSlice = createSlice({
state.edges
);
},
selectionCopied: (state) => {
const nodesToCopy: AnyNode[] = [];
const edgesToCopy: Edge[] = [];
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
const { nodes, edges } = action.payload;
for (const node of state.nodes) {
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}
const nodeChanges: NodeChange[] = [];
for (const edge of state.edges) {
if (edge.selected) {
edgesToCopy.push(deepClone(edge));
}
}
state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;
if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
state.nodesToCopy.forEach((e) => {
const xOffset = 0.15 * (e.width ?? 0);
const yOffset = 0.5 * (e.height ?? 0);
averagePosition.x += e.position.x + xOffset;
averagePosition.y += e.position.y + yOffset;
// Deselect existing nodes
state.nodes.forEach((n) => {
nodeChanges.push({
id: n.data.id,
type: 'select',
selected: false,
});
averagePosition.x /= state.nodesToCopy.length;
averagePosition.y /= state.nodesToCopy.length;
state.nodesToCopy.forEach((e) => {
e.position.x -= averagePosition.x;
e.position.y -= averagePosition.y;
});
// Add new nodes
nodes.forEach((n) => {
nodeChanges.push({
item: n,
type: 'add',
});
}
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes: AnyNode[] = [];
for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}
const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges: Edge[] = [];
for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}
newEdges.forEach((e) => (e.selected = true));
newNodes.forEach((node) => {
const newNodeId = uuidv4();
newEdges.forEach((edge) => {
if (edge.source === node.data.id) {
edge.source = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
if (edge.target === node.data.id) {
edge.target = newNodeId;
edge.id = edge.id.replace(node.data.id, newNodeId);
}
});
node.selected = true;
node.id = newNodeId;
node.data.id = newNodeId;
const position = findUnoccupiedPosition(
state.nodes,
node.position.x + (cursorPosition?.x ?? 0),
node.position.y + (cursorPosition?.y ?? 0)
);
node.position = position;
});
const nodeAdditions: NodeChange[] = newNodes.map((n) => ({
item: n,
type: 'add',
}));
const nodeSelectionChanges: NodeChange[] = state.nodes.map((n) => ({
id: n.data.id,
type: 'select',
selected: false,
}));
const edgeAdditions: EdgeChange[] = newEdges.map((e) => ({
item: e,
type: 'add',
}));
const edgeSelectionChanges: EdgeChange[] = state.edges.map((e) => ({
id: e.id,
type: 'select',
selected: false,
}));
state.nodes = applyNodeChanges(nodeAdditions.concat(nodeSelectionChanges), state.nodes);
state.edges = applyEdgeChanges(edgeAdditions.concat(edgeSelectionChanges), state.edges);
newNodes.forEach((node) => {
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
const edgeChanges: EdgeChange[] = [];
// Deselect existing edges
state.edges.forEach((e) => {
edgeChanges.push({
id: e.id,
type: 'select',
selected: false,
});
});
// Add new edges
edges.forEach((e) => {
edgeChanges.push({
item: e,
type: 'add',
});
});
},
addNodePopoverOpened: (state) => {
state.addNewNodePosition = null; //Create the node in viewport center by default
state.isAddNodePopoverOpen = true;
},
addNodePopoverClosed: (state) => {
state.isAddNodePopoverOpen = false;
//Make sure these get reset if we close the popover and haven't selected a node
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
},
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
state.templates = action.payload;
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
undo: (state) => state,
redo: (state) => state,
},
extraReducers: (builder) => {
builder.addCase(workflowLoaded, (state, action) => {
@@ -720,75 +449,13 @@ export const nodesSlice = createSlice({
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
acc[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
return acc;
}, {});
});
builder.addCase(socketInvocationStarted, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(socketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
}
});
builder.addCase(socketInvocationError, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.FAILED;
node.error = action.payload.data.error;
node.progress = null;
node.progressImage = null;
}
});
builder.addCase(socketGeneratorProgress, (state, action) => {
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
node.progress = (step + 1) / total_steps;
node.progressImage = progress_image ?? null;
}
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = zNodeStatus.enum.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
nes.outputs = [];
});
}
});
},
});
export const {
addNodePopoverClosed,
addNodePopoverOpened,
connectionEnded,
connectionMade,
connectionStarted,
edgeDeleted,
edgeChangeStarted,
edgesChanged,
edgesDeleted,
fieldValueReset,
@@ -816,31 +483,97 @@ export const {
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodeOpacityChanged,
nodesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionModeChanged,
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
undo,
redo,
} = nodesSlice.actions;
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Templates>({});
export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isUpdatingEdge = atom(false);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
$isAddNodePopoverOpen.set(false);
$pendingConnection.set(null);
};
export const openAddNodePopover = () => {
$isAddNodePopoverOpen.set(true);
};
export const selectNodesSlice = (state: RootState) => state.nodes.present;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialNodesState,
migrate: migrateNodesState,
persistDenylist: [],
};
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
const isSelectionAction = (action: UnknownAction) => {
if (selectionMatcher(action)) {
return true;
}
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
}
return false;
};
const individualGroupByMatcher = isAnyOf(nodesChanged);
export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
limit: 64,
undoType: nodesSlice.actions.undo.type,
redoType: nodesSlice.actions.redo.type,
groupBy: (action, state, history) => {
if (isSelectionAction(action)) {
// Changes to selection should never be recorded on their own
return history.group;
}
if (individualGroupByMatcher(action)) {
return action.type;
}
return null;
},
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(nodesSlice.name)) {
return false;
}
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'dimensions')) {
return false;
}
}
return true;
},
};
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionEnded,
connectionMade,
edgeDeleted,
edgesChanged,
@@ -873,30 +606,3 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
selectionPasted,
edgeAdded
);
export const selectNodesSlice = (state: RootState) => state.nodes;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const nodesPersistConfig: PersistConfig<NodesState> = {
name: nodesSlice.name,
initialState: initialNodesState,
migrate: migrateNodesState,
persistDenylist: [
'connectionStartParams',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'nodesToCopy',
'edgesToCopy',
'connectionMade',
'modifyingEdge',
'addNewNodePosition',
],
};

View File

@@ -1,26 +1,23 @@
import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { assert } from 'tsafe';
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return null;
}
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
return node;
};
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
export const selectInvocationNodeType = (nodesSlice: NodesState, nodeId: string): string => {
const node = selectInvocationNode(nodesSlice, nodeId);
return node.data.type;
};
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData => {
const node = selectInvocationNode(nodesSlice, nodeId);
if (!node) {
return null;
}
return nodesSlice.templates[node.data.type] ?? null;
return node.data;
};
export const selectFieldInputInstance = (
@@ -32,20 +29,10 @@ export const selectFieldInputInstance = (
return data?.inputs[fieldName] ?? null;
};
export const selectFieldInputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldInputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.inputs[fieldName] ?? null;
};
export const selectFieldOutputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldOutputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.outputs[fieldName] ?? null;
export const selectLastSelectedNode = (nodesSlice: NodesState) => {
const selectedNodes = nodesSlice.nodes.filter((n) => n.selected);
if (selectedNodes.length === 1) {
return selectedNodes[0];
}
return null;
};

View File

@@ -1,38 +1,31 @@
import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/nodes/types/field';
import type {
FieldIdentifier,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
InvocationNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
node: InvocationNode;
template: InvocationTemplate;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
templates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectedNodes: string[];
selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
nodesToCopy: AnyNode[];
edgesToCopy: InvocationNodeEdge[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
};
export type WorkflowMode = 'edit' | 'view';

View File

@@ -1,112 +1,105 @@
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Connection, Edge, HandleType, Node } from 'reactflow';
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, isEqual, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
node: Node,
handle: FieldInputTemplate | FieldOutputTemplate
) => {
let isValidConnection = true;
if (handleCurrentType === 'source') {
if (
edges.find((edge) => {
return edge.target === node.id && edge.targetHandle === handle.name;
})
) {
isValidConnection = false;
}
} else {
if (
edges.find((edge) => {
return edge.source === node.id && edge.sourceHandle === handle.name;
})
) {
isValidConnection = false;
}
}
if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
isValidConnection = false;
}
return isValidConnection;
};
export const findConnectionToValidHandle = (
node: AnyNode,
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
templates: Record<string, InvocationTemplate>,
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const template = templates[node.data.type];
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (!template) {
return null;
}
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
//Prioritize handles whos name matches the node we're coming from
const handle = handles[handleCurrentName];
if (handle) {
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
if (isGraphAcyclic && valid) {
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
for (const handleName in handles) {
const handle = handles[handleName];
if (!handle) {
continue;
}
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName;
const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges);
const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle);
if (isGraphAcyclic && valid) {
return {
source: sourceID,
sourceHandle: sourceHandle,
target: targetID,
targetHandle: targetHandle,
};
}
}
return null;
};

View File

@@ -4,8 +4,8 @@ export const findUnoccupiedPosition = (nodes: Node[], x: number, y: number) => {
let newX = x;
let newY = y;
while (nodes.find((n) => n.position.x === newX && n.position.y === newY)) {
newX = newX + 50;
newY = newY + 50;
newX = Math.floor(newX + 50);
newY = Math.floor(newY + 50);
}
return { x: newX, y: newY };
};

View File

@@ -1,39 +1,66 @@
import { createSelector } from '@reduxjs/toolkit';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import i18n from 'i18next';
import { isEqual } from 'lodash-es';
import type { HandleType } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null;
return fieldType;
};
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = (
templates: Templates,
pendingConnection: PendingConnection | null,
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType | null
) => {
return createSelector(selectNodesSlice, (nodesSlice) => {
const { nodes, edges } = nodesSlice;
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}
const { connectionStartFieldType, connectionStartParams, nodes, edges } = nodesSlice;
if (!connectionStartParams || !connectionStartFieldType) {
if (!pendingConnection) {
return i18n.t('nodes.noConnectionInProgress');
}
const {
handleType: connectionHandleType,
nodeId: connectionNodeId,
handleId: connectionFieldName,
} = connectionStartParams;
const connectionNodeId = pendingConnection.node.id;
const connectionFieldName = pendingConnection.fieldTemplate.name;
const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
const connectionStartFieldType = pendingConnection.fieldTemplate.type;
if (!connectionHandleType || !connectionNodeId || !connectionFieldName) {
return i18n.t('nodes.noConnectionData');
@@ -54,26 +81,45 @@ export const makeConnectionErrorSelector = (
}
// we have to figure out which is the target and which is the source
const target = handleType === 'target' ? nodeId : connectionNodeId;
const targetHandle = handleType === 'target' ? fieldName : connectionFieldName;
const source = handleType === 'source' ? nodeId : connectionNodeId;
const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName;
const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId;
const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName;
const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId;
const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName;
if (
edges.find((edge) => {
edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
edge.target === targetNodeId &&
edge.targetHandle === targetFieldName &&
edge.source === sourceNodeId &&
edge.sourceHandle === sourceFieldName;
})
) {
// We already have a connection from this source to this target
return i18n.t('nodes.cannotDuplicateConnection');
}
const targetNode = nodes.find((node) => node.id === targetNodeId);
assert(targetNode, `Target node not found: ${targetNodeId}`);
const targetTemplate = templates[targetNode.data.type];
assert(targetTemplate, `Target template not found: ${targetNode.data.type}`);
if (targetTemplate.inputs[targetFieldName]?.input === 'direct') {
return i18n.t('nodes.cannotConnectToDirectInput');
}
if (targetNode.data.type === 'collect' && targetFieldName === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
if (!isEqual(sourceType, collectItemType)) {
return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes');
}
}
}
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
return edge.target === targetNodeId && edge.targetHandle === targetFieldName;
}) &&
// except CollectionItem inputs can have multiples
targetType.name !== 'CollectionItemField'

View File

@@ -0,0 +1,87 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { SelectionMode } from 'reactflow';
type WorkflowSettingsState = {
_version: 1;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectionMode: SelectionMode;
};
const initialState: WorkflowSettingsState = {
_version: 1,
shouldShowMinimapPanel: true,
shouldValidateGraph: true,
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
nodeOpacity: 1,
selectionMode: SelectionMode.Partial,
};
export const workflowSettingsSlice = createSlice({
name: 'workflowSettings',
initialState,
reducers: {
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload;
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
},
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
shouldColorEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldColorEdges = action.payload;
},
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
state.nodeOpacity = action.payload;
},
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
},
},
});
export const {
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowMinimapPanelChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
nodeOpacityChanged,
selectionModeChanged,
} = workflowSettingsSlice.actions;
export const selectWorkflowSettingsSlice = (state: RootState) => state.workflowSettings;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateWorkflowSettingsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const workflowSettingsPersistConfig: PersistConfig<WorkflowSettingsState> = {
name: workflowSettingsSlice.name,
initialState,
migrate: migrateWorkflowSettingsState,
persistDenylist: [],
};

View File

@@ -11,7 +11,7 @@ import type {
SchedulerField,
T2IAdapterField,
} from 'features/nodes/types/common';
import type { S } from 'services/api/types';
import type { Invocation, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@@ -26,7 +26,7 @@ describe('Common types', () => {
test('ImageField', () => assert<Equals<ImageField, S['ImageField']>>());
test('BoardField', () => assert<Equals<BoardField, S['BoardField']>>());
test('ColorField', () => assert<Equals<ColorField, S['ColorField']>>());
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<S['SchedulerInvocation']['scheduler']>>>());
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<Invocation<'scheduler'>['scheduler']>>>());
test('ControlField', () => assert<Equals<ControlField, S['ControlField']>>());
// @ts-expect-error TODO(psyche): fix types
test('IPAdapterField', () => assert<Extends<IPAdapterField, S['IPAdapterField']>>());

View File

@@ -1,707 +0,0 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
isControlAdapterLayer,
isInitialImageLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice';
import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import type {
ControlNetConfigV2,
ImageWithDims,
IPAdapterConfigV2,
ProcessorConfig,
T2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters';
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
import type { ImageField } from 'features/nodes/types/common';
import {
CONTROL_NET_COLLECT,
IMAGE_TO_LATENTS,
IP_ADAPTER_COLLECT,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
RESIZE,
T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants';
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type {
BaseModelType,
CollectInvocation,
ControlNetInvocation,
Edge,
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';
export const addControlLayersToGraph = async (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string
): Promise<Layer[]> => {
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
// Filter out layers with incompatible base model, missing control image
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
for (const ca of validControlAdapters) {
addGlobalControlAdapterToGraph(ca, graph, denoiseNodeId);
}
const validIPAdapters = validLayers.filter(isIPAdapterLayer).map((l) => l.ipAdapter);
for (const ipAdapter of validIPAdapters) {
addGlobalIPAdapterToGraph(ipAdapter, graph, denoiseNodeId);
}
const initialImageLayers = validLayers.filter(isInitialImageLayer);
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
if (initialImageLayers[0]) {
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayers[0]);
}
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
const posCondCollectNode: CollectInvocation = {
id: POSITIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
const negCondCollectNode: CollectInvocation = {
id: NEGATIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
// Re-route the denoise node's OG conditioning inputs to the collect nodes
const newEdges: Edge[] = [];
for (const edge of graph.edges) {
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else {
newEdges.push(edge);
}
}
graph.edges = newEdges;
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
graph.edges.push({
source: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'negative_conditioning',
},
});
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
const layerIds = validRGLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of validRGLayers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(layer, blob);
// The main mask-to-tensor node
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
};
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
});
}
}
}
if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
};
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" negative conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
});
}
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
type: 'invert_tensor_mask',
};
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
// Connect the OG mask image to the inverted mask-to-tensor node
graph.edges.push({
source: {
node_id: maskToTensorNode.id,
field: 'mask',
},
destination: {
node_id: invertTensorMaskNode.id,
field: 'mask',
},
});
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
// Connect the inverted mask to the conditioning
graph.edges.push({
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
});
// Connect the conditioning to the negative collector
graph.edges.push({
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to our regional node
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
});
}
}
}
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
isValidIPAdapter(ipa, mainModel.base)
);
for (const ipAdapter of validRegionalIPAdapters) {
addIPAdapterCollectorSafe(graph, denoiseNodeId);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.name,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: ipAdapterNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: IP_ADAPTER_COLLECT,
field: 'item',
},
});
}
}
upsertMetadata(graph, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
return validLayers;
};
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const addGlobalControlAdapterToGraph = (
controlAdapter: ControlNetConfigV2 | T2IAdapterConfigV2,
graph: NonNullableGraph,
denoiseNodeId: string
) => {
if (controlAdapter.type === 'controlnet') {
addGlobalControlNetToGraph(controlAdapter, graph, denoiseNodeId);
}
if (controlAdapter.type === 't2i_adapter') {
addGlobalT2IAdapterToGraph(controlAdapter, graph, denoiseNodeId);
}
};
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[CONTROL_NET_COLLECT]) {
// You see, we've already got one!
return;
}
// Add the ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 'control',
},
});
};
const addGlobalControlNetToGraph = (controlNet: ControlNetConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
assert(model, 'ControlNet model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig);
addControlNetCollectorSafe(graph, denoiseNodeId);
const controlNetNode: ControlNetInvocation = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
image: controlImage,
};
graph.nodes[controlNetNode.id] = controlNetNode;
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
};
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[T2I_ADAPTER_COLLECT]) {
// You see, we've already got one!
return;
}
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
graph.edges.push({
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 't2i_adapter',
},
});
};
const addGlobalT2IAdapterToGraph = (t2iAdapter: T2IAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
assert(model, 'T2I Adapter model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig);
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
const t2iAdapterNode: T2IAdapterInvocation = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
t2i_adapter_model: model,
weight: weight,
image: controlImage,
};
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: T2I_ADAPTER_COLLECT,
field: 'item',
},
});
};
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[IP_ADAPTER_COLLECT]) {
// You see, we've already got one!
return;
}
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 'ip_adapter',
},
});
};
const addGlobalIPAdapterToGraph = (ipAdapter: IPAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
addIPAdapterCollectorSafe(graph, denoiseNodeId);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.name,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: IP_ADAPTER_COLLECT,
field: 'item',
},
});
};
const addInitialImageLayerToGraph = (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string,
layer: InitialImageLayer
) => {
const { vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
assert(layer.isEnabled, 'Initial image layer is not enabled');
assert(layer.image, 'Initial image layer has no image');
const isSDXL = model?.base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const denoiseNode = graph.nodes[denoiseNodeId];
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
const { denoisingStrength } = layer;
denoiseNode.denoising_start = useRefinerStartEnd
? Math.min(refinerStart, 1 - denoisingStrength)
: 1 - denoisingStrength;
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
const i2lNode: ImageToLatentsInvocation = {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
use_cache: true,
fp32: vaePrecision === 'fp32',
};
graph.nodes[i2lNode.id] = i2lNode;
graph.edges.push({
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: denoiseNode.id,
field: 'latents',
},
});
if (layer.image.width !== width || layer.image.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: layer.image.name,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
i2lNode.image = {
image_name: layer.image.name,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
};
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === base;
const hasControlImage = Boolean(ca.image || (ca.processedImage && ca.processorConfig));
return hasModel && modelMatchesBase && hasControlImage;
};
const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base;
const hasImage = Boolean(ipa.image);
return hasModel && modelMatchesBase && hasImage;
};
const isValidLayer = (layer: Layer, base: BaseModelType) => {
if (!layer.isEnabled) {
return false;
}
if (isControlAdapterLayer(layer)) {
return isValidControlAdapter(layer.controlAdapter, base);
}
if (isIPAdapterLayer(layer)) {
return isValidIPAdapter(layer.ipAdapter, base);
}
if (isInitialImageLayer(layer)) {
if (!layer.image) {
return false;
}
return true;
}
if (isRegionalGuidanceLayer(layer)) {
if (layer.maskObjects.length === 0) {
// Layer has no mask, meaning any guidance would be applied to an empty region.
return false;
}
const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt);
const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter;
}
return false;
};

View File

@@ -1,356 +0,0 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import type {
DenoiseLatentsInvocation,
Edge,
ESRGANInvocation,
LatentsToImageInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import {
DENOISE_LATENTS,
DENOISE_LATENTS_HRF,
ESRGAN_HRF,
IMAGE_TO_LATENTS_HRF,
LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF_HR,
LATENTS_TO_IMAGE_HRF_LR,
MAIN_MODEL_LOADER,
NOISE,
NOISE_HRF,
RESIZE_HRF,
SEAMLESS,
VAE_LOADER,
} from './constants';
import { setMetadataReceivingNode, upsertMetadata } from './metadata';
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
const destinationFields = [
'control',
'ip_adapter',
'metadata',
'unet',
'positive_conditioning',
'negative_conditioning',
];
const newEdges: Edge[] = [];
// Loop through the existing edges connected to DENOISE_LATENTS
graph.edges.forEach((edge: Edge) => {
if (edge.destination.node_id === DENOISE_LATENTS && destinationFields.includes(edge.destination.field)) {
// Add a similar connection to DENOISE_LATENTS_HRF
newEdges.push({
source: {
node_id: edge.source.node_id,
field: edge.source.field,
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: edge.destination.field,
},
});
}
});
graph.edges = graph.edges.concat(newEdges);
}
/**
* Calculates the new resolution for high-resolution features (HRF) based on base model type.
* Adjusts the width and height to maintain the aspect ratio and constrains them by the model's dimension limits,
* rounding down to the nearest multiple of 8.
*
* @param {number} optimalDimension The optimal dimension for the base model.
* @param {number} width The current width to be adjusted for HRF.
* @param {number} height The current height to be adjusted for HRF.
* @return {{newWidth: number, newHeight: number}} The new width and height, adjusted and rounded as needed.
*/
function calculateHrfRes(
optimalDimension: number,
width: number,
height: number
): { newWidth: number; newHeight: number } {
const aspect = width / height;
const minDimension = Math.floor(optimalDimension * 0.5);
const modelArea = optimalDimension * optimalDimension; // Assuming square images for model_area
let initWidth;
let initHeight;
if (aspect > 1.0) {
initHeight = Math.max(minDimension, Math.sqrt(modelArea / aspect));
initWidth = initHeight * aspect;
} else {
initWidth = Math.max(minDimension, Math.sqrt(modelArea * aspect));
initHeight = initWidth / aspect;
}
// Cap initial height and width to final height and width.
initWidth = Math.min(width, initWidth);
initHeight = Math.min(height, initHeight);
const newWidth = roundToMultiple(Math.floor(initWidth), 8);
const newHeight = roundToMultiple(Math.floor(initHeight), 8);
return { newWidth, newHeight };
}
// Adds the high-res fix feature to the given graph.
export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void => {
// Double check hrf is enabled.
if (!state.hrf.hrfEnabled || state.config.disabledSDFeatures.includes('hrf')) {
return;
}
const log = logger('generation');
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { hrfStrength, hrfEnabled, hrfMethod } = state.hrf;
const { width, height } = state.controlLayers.present.size;
const isAutoVae = !vae;
const isSeamlessEnabled = seamlessXAxis || seamlessYAxis;
const optimalDimension = selectOptimalDimension(state);
const { newWidth: hrfWidth, newHeight: hrfHeight } = calculateHrfRes(optimalDimension, width, height);
// Pre-existing (original) graph nodes.
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as DenoiseLatentsInvocation | undefined;
const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined;
const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as LatentsToImageInvocation | undefined;
if (!originalDenoiseLatentsNode) {
log.error('originalDenoiseLatentsNode is undefined');
return;
}
if (!originalNoiseNode) {
log.error('originalNoiseNode is undefined');
return;
}
if (!originalLatentsToImageNode) {
log.error('originalLatentsToImageNode is undefined');
return;
}
// Change height and width of original noise node to initial resolution.
if (originalNoiseNode) {
originalNoiseNode.width = hrfWidth;
originalNoiseNode.height = hrfHeight;
}
// Define new nodes and their connections, roughly in order of operations.
graph.nodes[LATENTS_TO_IMAGE_HRF_LR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_LR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'latents',
},
},
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'vae',
},
}
);
graph.nodes[RESIZE_HRF] = {
id: RESIZE_HRF,
type: 'img_resize',
is_intermediate: true,
width: width,
height: height,
};
if (hrfMethod === 'ESRGAN') {
let model_name: ESRGANInvocation['model_name'] = 'RealESRGAN_x2plus.pth';
if ((width * height) / (hrfWidth * hrfHeight) > 2) {
model_name = 'RealESRGAN_x4plus.pth';
}
graph.nodes[ESRGAN_HRF] = {
id: ESRGAN_HRF,
type: 'esrgan',
model_name,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: ESRGAN_HRF,
field: 'image',
},
},
{
source: {
node_id: ESRGAN_HRF,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
}
);
} else {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
});
}
graph.nodes[NOISE_HRF] = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: RESIZE_HRF,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
}
);
graph.nodes[IMAGE_TO_LATENTS_HRF] = {
type: 'i2l',
id: IMAGE_TO_LATENTS_HRF,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'vae',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'image',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'image',
},
}
);
graph.nodes[DENOISE_LATENTS_HRF] = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: true,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - hrfStrength,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
}
);
copyConnectionsToDenoiseLatentsHrf(graph);
// The original l2i node is unnecessary now, remove it
graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE);
delete graph.nodes[LATENTS_TO_IMAGE];
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_HR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.edges.push(
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'vae',
},
},
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'latents',
},
}
);
upsertMetadata(graph, {
hrf_strength: hrfStrength,
hrf_enabled: hrfEnabled,
hrf_method: hrfMethod,
});
setMetadataReceivingNode(graph, LATENTS_TO_IMAGE_HRF_HR);
};

View File

@@ -1,9 +1,9 @@
import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants';
import { addCoreMetadataNode, upsertMetadata } from './metadata';
type Arg = {
image_name: string;
@@ -13,7 +13,7 @@ type Arg = {
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
const realesrganNode: ESRGANInvocation = {
const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },

View File

@@ -1,267 +0,0 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildGenerationTabGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
steps,
clipSkip,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
seed,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const use_cpu = shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: CONTROL_LAYERS_GRAPH,
nodes: {
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers: clipSkip,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
seed,
width,
height,
use_cpu,
is_intermediate,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
edges: [
// Connect Model Loader to UNet and CLIP Skip
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: CLIP_SKIP,
field: 'clip',
},
},
// Connect CLIP Skip to Conditioning
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
// Connect everything to Denoise Latents
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'noise',
},
},
// Decode Denoised Latents To Image
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
clip_skip: clipSkip,
},
LATENTS_TO_IMAGE
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
// High resolution fix.
if (state.hrf.hrfEnabled && shouldUseHRF) {
addHrfToGraph(state, graph);
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@@ -1,278 +0,0 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
SDXL_CONTROL_LAYERS_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const { refinerModel, refinerStart } = state.sdxl;
const use_cpu = shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_CONTROL_LAYERS_GRAPH,
nodes: {
[modelLoaderNodeId]: {
type: 'sdxl_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
seed,
width,
height,
use_cpu,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: refinerModel ? refinerStart : 1,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
edges: [
// Connect Model Loader to UNet, VAE & CLIP
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
// Connect everything to Denoise Latents
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'noise',
},
},
// Decode Denoised Latents To Image
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
},
LATENTS_TO_IMAGE
);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@@ -5,8 +5,8 @@ import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import { getHasMetadata, removeMetadata } from './canvas/metadata';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
import { getHasMetadata, removeMetadata } from './metadata';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
const { iterations, model, shouldRandomizeSeed, seed } = state.generation;

View File

@@ -2,25 +2,18 @@ import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { CONTROL_NET_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
const controlNetMetadata: S['CoreMetadataInvocation']['controlnets'] = [];
const controlNets = selectValidControlNets(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
@@ -37,7 +30,7 @@ export const addControlNetToLinearGraph = async (
if (controlNets.length) {
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = {
const controlNetIterateNode: Invocation<'collect'> = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
@@ -68,7 +61,7 @@ export const addControlNetToLinearGraph = async (
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
const controlNetNode: Invocation<'controlnet'> = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,

View File

@@ -2,19 +2,12 @@ import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
@@ -33,7 +26,7 @@ export const addIPAdapterToLinearGraph = async (
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
const ipAdapterCollectNode: Invocation<'collect'> = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
@@ -47,7 +40,7 @@ export const addIPAdapterToLinearGraph = async (
},
});
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
const ipAdapterMetdata: S['CoreMetadataInvocation']['ipAdapters'] = [];
for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) {
@@ -57,7 +50,7 @@ export const addIPAdapterToLinearGraph = async (
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
const ipAdapterNode: Invocation<'ip_adapter'> = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,

View File

@@ -1,10 +1,15 @@
import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
} from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { upsertMetadata } from './metadata';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addLoRAsToGraph = async (
state: RootState,
@@ -38,7 +43,7 @@ export const addLoRAsToGraph = async (
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
@@ -46,7 +51,7 @@ export const addLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: LoRALoaderInvocation = {
const loraLoaderNode: Invocation<'lora_loader'> = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,

View File

@@ -1,15 +1,14 @@
import type { RootState } from 'app/store/store';
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
@@ -19,14 +18,14 @@ export const addNSFWCheckerToGraph = (
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
const nsfwCheckerNode: Invocation<'img_nsfw'> = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,

View File

@@ -1,8 +1,6 @@
import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import {
LORA_LOADER,
NEGATIVE_CONDITIONING,
@@ -10,8 +8,9 @@ import {
SDXL_MODEL_LOADER,
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { upsertMetadata } from './metadata';
} from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addSDXLLoRAsToGraph = async (
state: RootState,
@@ -35,7 +34,7 @@ export const addSDXLLoRAsToGraph = async (
return;
}
const loraMetadata: CoreMetadataInvocation['loras'] = [];
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
// Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId;
@@ -61,7 +60,7 @@ export const addSDXLLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: SDXLLoRALoaderInvocation = {
const loraLoaderNode: Invocation<'sdxl_lora_loader'> = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,

View File

@@ -1,8 +1,6 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { getModelMetadataField, upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
@@ -17,9 +15,10 @@ import {
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { getModelMetadataField, upsertMetadata } from './metadata';
} from 'features/nodes/util/graph/constants';
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
export const addSDXLRefinerToGraph = async (
state: RootState,
@@ -101,7 +100,7 @@ export const addSDXLRefinerToGraph = async (
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
};
graph.edges.push(
{

View File

@@ -1,6 +1,5 @@
import type { RootState } from 'app/store/store';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import {
DENOISE_LATENTS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
@@ -11,8 +10,8 @@ import {
SDXL_DENOISE_LATENTS,
SEAMLESS,
VAE_LOADER,
} from './constants';
import { upsertMetadata } from './metadata';
} from 'features/nodes/util/graph/constants';
import type { NonNullableGraph } from 'services/api/types';
export const addSeamlessToLinearGraph = (
state: RootState,
@@ -28,7 +27,7 @@ export const addSeamlessToLinearGraph = (
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
};
if (!isAutoVae) {
graph.nodes[VAE_LOADER] = {

View File

@@ -2,19 +2,12 @@ import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
@@ -36,7 +29,7 @@ export const addT2IAdaptersToLinearGraph = async (
if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
const t2iAdapterCollectNode: Invocation<'collect'> = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
@@ -50,7 +43,7 @@ export const addT2IAdaptersToLinearGraph = async (
},
});
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
const t2iAdapterMetadata: S['CoreMetadataInvocation']['t2iAdapters'] = [];
for (const t2iAdapter of t2iAdapters) {
if (!t2iAdapter.model) {
@@ -68,7 +61,7 @@ export const addT2IAdaptersToLinearGraph = async (
weight,
} = t2iAdapter;
const t2iAdapterNode: T2IAdapterInvocation = {
const t2iAdapterNode: Invocation<'t2i_adapter'> = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,

View File

@@ -1,6 +1,5 @@
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
CANVAS_INPAINT_GRAPH,
@@ -21,8 +20,8 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
VAE_LOADER,
} from './constants';
import { upsertMetadata } from './metadata';
} from 'features/nodes/util/graph/constants';
import type { NonNullableGraph } from 'services/api/types';
export const addVAEToGraph = async (
state: RootState,

View File

@@ -1,29 +1,23 @@
import type { RootState } from 'app/store/store';
import type {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as Invocation<'img_nsfw'> | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
const watermarkerNode: ImageWatermarkInvocation = {
const watermarkerNode: Invocation<'img_watermark'> = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate: getIsIntermediate(state),

Some files were not shown because too many files have changed in this diff Show More