mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 15:37:55 -05:00
Compare commits
72 Commits
5.10.0dev1
...
v5.10.0dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e200a2ba2 | ||
|
|
4610b55a5d | ||
|
|
b3b3dbd92d | ||
|
|
6c36b0508b | ||
|
|
2756c539e0 | ||
|
|
a34383d460 | ||
|
|
77f22497d2 | ||
|
|
5967d4e1da | ||
|
|
1253ad5053 | ||
|
|
5aa08ab09b | ||
|
|
6ce527768b | ||
|
|
fe88012236 | ||
|
|
8609b98217 | ||
|
|
19f0bf828c | ||
|
|
26cbeccfdf | ||
|
|
b5be81b97b | ||
|
|
f14d07968b | ||
|
|
525a89900a | ||
|
|
d8df31a8ac | ||
|
|
380a41be34 | ||
|
|
e990afbccb | ||
|
|
c591478d24 | ||
|
|
30def6a9bd | ||
|
|
6cf88a601d | ||
|
|
5e14545c32 | ||
|
|
eefbcd2485 | ||
|
|
13cc44a22c | ||
|
|
2cca339a5c | ||
|
|
0a7cf6c0ec | ||
|
|
06abc1d40a | ||
|
|
2cde86b7b8 | ||
|
|
0a49463c79 | ||
|
|
f3402b6ce7 | ||
|
|
5d3fb822c5 | ||
|
|
9e70d8eb6e | ||
|
|
402758d502 | ||
|
|
b97cc51f23 | ||
|
|
f6f33b5999 | ||
|
|
cd873f1fe5 | ||
|
|
5f3d398074 | ||
|
|
e6b366ff61 | ||
|
|
bcd50ed688 | ||
|
|
a5966c3197 | ||
|
|
f28b054872 | ||
|
|
31681f4ad7 | ||
|
|
aaf042de48 | ||
|
|
c28e685409 | ||
|
|
d6ac822a1f | ||
|
|
f0a4d7ac7f | ||
|
|
04b0e658df | ||
|
|
68845f4d85 | ||
|
|
6df5614b54 | ||
|
|
0bd6f0245b | ||
|
|
6c9165046e | ||
|
|
2b5da91beb | ||
|
|
74bede14be | ||
|
|
04ea3c491a | ||
|
|
38e7b23d18 | ||
|
|
c052846e05 | ||
|
|
af3a31dfec | ||
|
|
571710fab6 | ||
|
|
a175a5c252 | ||
|
|
8b3c36c6fa | ||
|
|
b9ffacd4bf | ||
|
|
ae45fc8a74 | ||
|
|
85db9c65e5 | ||
|
|
ddddaef7ca | ||
|
|
e4678201cb | ||
|
|
d66fdfde71 | ||
|
|
08ee08557b | ||
|
|
496f1262c6 | ||
|
|
188d52e4a5 |
8
.github/CODEOWNERS
vendored
8
.github/CODEOWNERS
vendored
@@ -2,11 +2,11 @@
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
@@ -22,7 +22,7 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
|
||||
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
|
||||
2
.github/workflows/build-installer.yml
vendored
2
.github/workflows/build-installer.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
|
||||
5. Choose a version to install. Review the [GitHub releases page](https://github.com/invoke-ai/InvokeAI/releases).
|
||||
|
||||
6. Determine the package package specifier to use when installing. This is a performance optimization.
|
||||
6. Determine the package specifier to use when installing. This is a performance optimization.
|
||||
|
||||
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
|
||||
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
@@ -34,6 +35,12 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
class ValidationRunData(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow being published.")
|
||||
input_fields: list[FieldIdentifier] = Body(description="The input fields for the published workflow")
|
||||
output_fields: list[FieldIdentifier] = Body(description="The output fields for the published workflow")
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -45,6 +52,10 @@ async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
validation_run_data: Optional[ValidationRunData] = Body(
|
||||
default=None,
|
||||
description="The validation run data to use for this batch. This is only used if this is a validation run.",
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
|
||||
@@ -106,6 +106,7 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
@@ -118,6 +119,7 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
is_published=is_published,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
|
||||
128
invokeai/app/invocations/controlnet.py
Normal file
128
invokeai/app/invocations/controlnet.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,716 +0,0 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (
|
||||
ContentShuffleDetector,
|
||||
LeresDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v):
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width=image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height=image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"canny_image_processor",
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"hed_image_processor",
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_image_processor",
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"lineart_anime_image_processor",
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"normalbae_image_processor",
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor",
|
||||
title="MLSD Processor",
|
||||
tags=["controlnet", "mlsd"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor",
|
||||
title="PIDI Processor",
|
||||
tags=["controlnet", "pidi"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"content_shuffle_image_processor",
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@invocation(
|
||||
"zoe_depth_image_processor",
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mediapipe_face_processor",
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
max_faces=self.max_faces,
|
||||
min_confidence=self.min_confidence,
|
||||
image_resolution=self.image_resolution,
|
||||
detect_resolution=self.detect_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"leres_image_processor",
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"tile_image_processor",
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
H, W, C = np_img.shape
|
||||
H = int(float(H) / float(down_sampling_rate))
|
||||
W = int(float(W) / float(down_sampling_rate))
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything_processor",
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.4",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(
|
||||
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
# so using ADE20k color palette instead
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map_image_processor",
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"depth_anything_image_processor",
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.3",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
"dw_openpose_image_processor",
|
||||
title="DW Openpose Image Processor",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
||||
draw_body: bool = InputField(default=True)
|
||||
draw_face: bool = InputField(default=False)
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
draw_hands=self.draw_hands,
|
||||
draw_body=self.draw_body,
|
||||
resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
"""Resize an image using a heuristic method. Preserves edge maps."""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
|
||||
@@ -4,7 +4,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector2
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -25,20 +25,20 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
|
||||
onnx_det_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_det())
|
||||
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector.get_model_url_pose())
|
||||
|
||||
loaded_session_det = context.models.load_local_model(
|
||||
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_det_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
)
|
||||
loaded_session_pose = context.models.load_local_model(
|
||||
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
|
||||
onnx_pose_path, DWOpenposeDetector.create_onnx_inference_session
|
||||
)
|
||||
|
||||
with loaded_session_det as session_det, loaded_session_pose as session_pose:
|
||||
assert isinstance(session_det, ort.InferenceSession)
|
||||
assert isinstance(session_pose, ort.InferenceSession)
|
||||
detector = DWOpenposeDetector2(session_det=session_det, session_pose=session_pose)
|
||||
detector = DWOpenposeDetector(session_det=session_det, session_pose=session_pose)
|
||||
detected_image = detector.run(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
|
||||
@@ -302,7 +302,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||
# newer version of the app that added a new model type.
|
||||
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
|
||||
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
|
||||
self._logger.warning(
|
||||
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
|
||||
)
|
||||
else:
|
||||
results.append(model_config)
|
||||
|
||||
|
||||
@@ -201,6 +201,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
return None
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
@@ -237,6 +243,20 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
retried_from_item_id: Optional[int] = Field(
|
||||
default=None, description="The item_id of the queue item that this item was retried from"
|
||||
)
|
||||
is_api_validation_run: bool = Field(
|
||||
default=False,
|
||||
description="Whether this queue item is an API validation run.",
|
||||
)
|
||||
published_workflow_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The ID of the published workflow associated with this queue item",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
|
||||
@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
query: Optional[str],
|
||||
tags: Optional[list[str]],
|
||||
has_been_opened: Optional[bool],
|
||||
is_published: Optional[bool],
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided categories."""
|
||||
pass
|
||||
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided tags."""
|
||||
pass
|
||||
|
||||
@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
|
||||
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
|
||||
# it is None.
|
||||
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
|
||||
opened_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||
default=None, description="The opened timestamp of the workflow."
|
||||
)
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
|
||||
class WorkflowRecordDTO(WorkflowRecordDTOBase):
|
||||
|
||||
@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
query: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
if not tags:
|
||||
return {}
|
||||
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
|
||||
@@ -65,9 +65,6 @@ def apply_monkeypatches() -> None:
|
||||
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
def register_mime_types() -> None:
|
||||
"""Register additional mime types for windows."""
|
||||
|
||||
@@ -5,62 +5,14 @@ import huggingface_hub
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
|
||||
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
from invokeai.backend.image_util.util import np_to_pil
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image: Image.Image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
dwpose_image = Image.fromarray(dwpose_image)
|
||||
|
||||
return dwpose_image
|
||||
|
||||
|
||||
class DWOpenposeDetector:
|
||||
"""
|
||||
@@ -68,62 +20,6 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
|
||||
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.pose_estimation(np_image)
|
||||
nums, keys, locs = candidate.shape
|
||||
candidate[..., 0] /= float(W)
|
||||
candidate[..., 1] /= float(H)
|
||||
body = candidate[:, :18].copy()
|
||||
body = body.reshape(nums * 18, locs)
|
||||
score = subset[:, :18]
|
||||
for i in range(len(score)):
|
||||
for j in range(len(score[i])):
|
||||
if score[i][j] > 0.3:
|
||||
score[i][j] = int(18 * i + j)
|
||||
else:
|
||||
score[i][j] = -1
|
||||
|
||||
un_visible = subset < 0.3
|
||||
candidate[un_visible] = -1
|
||||
|
||||
# foot = candidate[:, 18:24]
|
||||
|
||||
faces = candidate[:, 24:92]
|
||||
|
||||
hands = candidate[:, 92:113]
|
||||
hands = np.vstack([hands, candidate[:, 113:]])
|
||||
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
|
||||
|
||||
class DWOpenposeDetector2:
|
||||
"""
|
||||
Code from the original implementation of the DW Openpose Detector.
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
This implementation is similar to DWOpenposeDetector, with some alterations to allow the onnx models to be loaded
|
||||
and managed by the model manager.
|
||||
"""
|
||||
|
||||
hf_repo_id = "yzd-v/DWPose"
|
||||
hf_filename_onnx_det = "yolox_l.onnx"
|
||||
hf_filename_onnx_pose = "dw-ll_ucoco_384.onnx"
|
||||
@@ -213,7 +109,7 @@ class DWOpenposeDetector2:
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return DWOpenposeDetector2.draw_pose(
|
||||
return DWOpenposeDetector.draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
@@ -127,11 +126,13 @@ def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
|
||||
x2 = int(x2 * W)
|
||||
y2 = int(y2 * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
hsv_color = np.array([[[ie / float(len(edges)) * 180, 255, 255]]], dtype=np.uint8)
|
||||
rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0, 0]
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
|
||||
rgb_color.tolist(),
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.dw_openpose.onnxdet import inference_detector
|
||||
from invokeai.backend.image_util.dw_openpose.onnxpose import inference_pose
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path):
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
def __call__(self, oriImg):
|
||||
det_result = inference_detector(self.session_det, oriImg)
|
||||
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
||||
|
||||
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
|
||||
# compute neck joint
|
||||
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
||||
# neck score when visualizing pred
|
||||
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
||||
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
|
||||
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
|
||||
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
|
||||
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
|
||||
keypoints_info = new_keypoints_info
|
||||
|
||||
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
|
||||
|
||||
return keypoints, scores
|
||||
@@ -1,245 +0,0 @@
|
||||
import math
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
|
||||
|
||||
_torch_layer_norm = torch.nn.functional.layer_norm
|
||||
|
||||
|
||||
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
if bias is not None:
|
||||
bias = bias.float()
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
|
||||
else:
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
torch.nn.functional.layer_norm = new_layer_norm
|
||||
|
||||
|
||||
_torch_tensor_permute = torch.Tensor.permute
|
||||
|
||||
|
||||
def new_torch_tensor_permute(input, *dims):
|
||||
result = _torch_tensor_permute(input, *dims)
|
||||
if input.device == "mps" and input.dtype == torch.float16:
|
||||
result = result.contiguous()
|
||||
return result
|
||||
|
||||
|
||||
torch.Tensor.permute = new_torch_tensor_permute
|
||||
|
||||
|
||||
_torch_lerp = torch.lerp
|
||||
|
||||
|
||||
def new_torch_lerp(input, end, weight, *, out=None):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
end = end.float()
|
||||
if isinstance(weight, torch.Tensor):
|
||||
weight = weight.float()
|
||||
if out is not None:
|
||||
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
|
||||
else:
|
||||
out_fp32 = None
|
||||
result = _torch_lerp(input, end, weight, out=out_fp32)
|
||||
if out is not None:
|
||||
out.copy_(out_fp32.half())
|
||||
del out_fp32
|
||||
return result.half()
|
||||
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
|
||||
torch.lerp = new_torch_lerp
|
||||
|
||||
|
||||
_torch_interpolate = torch.nn.functional.interpolate
|
||||
|
||||
|
||||
def new_torch_interpolate(
|
||||
input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode="nearest",
|
||||
align_corners=None,
|
||||
recompute_scale_factor=None,
|
||||
antialias=False,
|
||||
):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
return _torch_interpolate(
|
||||
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
|
||||
).half()
|
||||
else:
|
||||
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||
|
||||
|
||||
torch.nn.functional.interpolate = new_torch_interpolate
|
||||
|
||||
# TODO: refactor it
|
||||
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
|
||||
|
||||
|
||||
class ChunkedSlicedAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
|
||||
Args:
|
||||
slice_size (`int`, *optional*):
|
||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, slice_size):
|
||||
assert isinstance(slice_size, int)
|
||||
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
|
||||
self.slice_size = slice_size
|
||||
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
if self.slice_size != 1 or attn.upcast_attention:
|
||||
return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, _ = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
chunk_tmp_tensor = torch.empty(
|
||||
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
for i in range(batch_size_attention // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
self.get_attention_scores_chunked(
|
||||
attn,
|
||||
query_slice,
|
||||
key_slice,
|
||||
attn_mask_slice,
|
||||
hidden_states[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
chunk_tmp_tensor,
|
||||
)
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
|
||||
# batch size = 1
|
||||
assert query.shape[0] == 1
|
||||
assert key.shape[0] == 1
|
||||
assert value.shape[0] == 1
|
||||
assert hidden_states.shape[0] == 1
|
||||
|
||||
# dtype = query.dtype
|
||||
if attn.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
# out_item_size = query.dtype.itemsize
|
||||
# if attn.upcast_attention:
|
||||
# out_item_size = torch.float32.itemsize
|
||||
out_item_size = query.element_size()
|
||||
if attn.upcast_attention:
|
||||
out_item_size = 4
|
||||
|
||||
chunk_size = 2**29
|
||||
|
||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||
chunk_step = max(1, int(query.shape[1] / chunks_count))
|
||||
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
def _get_chunk_view(tensor, start, length):
|
||||
if start + length > tensor.shape[1]:
|
||||
length = tensor.shape[1] - start
|
||||
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
|
||||
return tensor[:, start : start + length]
|
||||
|
||||
for chunk_pos in range(0, query.shape[1], chunk_step):
|
||||
if attention_mask is not None:
|
||||
torch.baddbmm(
|
||||
_get_chunk_view(attention_mask, chunk_pos, chunk_step),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=1,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
else:
|
||||
torch.baddbmm(
|
||||
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
|
||||
_get_chunk_view(query, chunk_pos, chunk_step),
|
||||
key,
|
||||
beta=0,
|
||||
alpha=attn.scale,
|
||||
out=chunk,
|
||||
)
|
||||
chunk = chunk.softmax(dim=-1)
|
||||
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
|
||||
|
||||
# del chunk
|
||||
|
||||
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
|
||||
@@ -150,7 +150,7 @@
|
||||
"prettier": "^3.3.3",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"storybook": "^8.3.4",
|
||||
"tsafe": "^1.7.5",
|
||||
"tsafe": "^1.8.5",
|
||||
"type-fest": "^4.26.1",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^6.1.0",
|
||||
|
||||
8
invokeai/frontend/web/pnpm-lock.yaml
generated
8
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -284,8 +284,8 @@ devDependencies:
|
||||
specifier: ^8.3.4
|
||||
version: 8.3.4
|
||||
tsafe:
|
||||
specifier: ^1.7.5
|
||||
version: 1.7.5
|
||||
specifier: ^1.8.5
|
||||
version: 1.8.5
|
||||
type-fest:
|
||||
specifier: ^4.26.1
|
||||
version: 4.26.1
|
||||
@@ -8791,8 +8791,8 @@ packages:
|
||||
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
|
||||
dev: false
|
||||
|
||||
/tsafe@1.7.5:
|
||||
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
|
||||
/tsafe@1.8.5:
|
||||
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
|
||||
dev: true
|
||||
|
||||
/tsconfck@3.1.5(typescript@5.6.2):
|
||||
|
||||
@@ -1706,6 +1706,7 @@
|
||||
"noRecentWorkflows": "No Recent Workflows",
|
||||
"private": "Private",
|
||||
"shared": "Shared",
|
||||
"published": "Published",
|
||||
"browseWorkflows": "Browse Workflows",
|
||||
"deselectAll": "Deselect All",
|
||||
"recommended": "Recommended For You",
|
||||
@@ -1783,7 +1784,39 @@
|
||||
"textPlaceholder": "Empty Text",
|
||||
"workflowBuilderAlphaWarning": "The workflow builder is currently in alpha. There may be breaking changes before the stable release.",
|
||||
"minimum": "Minimum",
|
||||
"maximum": "Maximum"
|
||||
"maximum": "Maximum",
|
||||
"publish": "Publish",
|
||||
"published": "Published",
|
||||
"unpublish": "Unpublish",
|
||||
"workflowLocked": "Workflow Locked",
|
||||
"workflowLockedPublished": "Published workflows are locked for editing.\nYou can unpublish the workflow to edit it, or make a copy of it.",
|
||||
"workflowLockedDuringPublishing": "Workflow is locked while configuring for publishing.",
|
||||
"selectOutputNode": "Select Output Node",
|
||||
"changeOutputNode": "Change Output Node",
|
||||
"publishedWorkflowOutputs": "Outputs",
|
||||
"publishedWorkflowInputs": "Inputs",
|
||||
"unpublishableInputs": "These unpublishable inputs will be omitted",
|
||||
"noPublishableInputs": "No publishable inputs",
|
||||
"noOutputNodeSelected": "No output node selected",
|
||||
"cannotPublish": "Cannot publish workflow",
|
||||
"publishWarnings": "Warnings",
|
||||
"errorWorkflowHasUnsavedChanges": "Workflow has unsaved changes",
|
||||
"errorWorkflowHasBatchOrGeneratorNodes": "Workflow has batch and/or generator nodes",
|
||||
"errorWorkflowHasInvalidGraph": "Workflow graph invalid (hover Invoke button for details)",
|
||||
"errorWorkflowHasNoOutputNode": "No output node selected",
|
||||
"warningWorkflowHasNoPublishableInputFields": "No publishable input fields selected - published workflow will run with only default values",
|
||||
"warningWorkflowHasUnpublishableInputFields": "Workflow has some unpublishable inputs - these will be omitted from the published workflow",
|
||||
"publishFailed": "Publish failed",
|
||||
"publishFailedDesc": "There was a problem publishing the workflow. Please try again.",
|
||||
"publishSuccess": "Your workflow is being published",
|
||||
"publishSuccessDesc": "Check your <LinkComponent>Project Dashboard</LinkComponent> to see its progress.",
|
||||
"publishInProgress": "Publishing in progress",
|
||||
"publishedWorkflowIsLocked": "Published workflow is locked",
|
||||
"publishingValidationRun": "Publishing Validation Run",
|
||||
"publishingValidationRunInProgress": "Publishing validation run in progress.",
|
||||
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow.",
|
||||
"selectingOutputNode": "Selecting output node",
|
||||
"selectingOutputNodeDesc": "Click a node to select it as the workflow's output node."
|
||||
}
|
||||
},
|
||||
"controlLayers": {
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
|
||||
export const enqueueRequested = createAction<{
|
||||
tabName: TabName;
|
||||
prepend: boolean;
|
||||
}>('app/enqueueRequested');
|
||||
@@ -10,7 +10,6 @@ import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/l
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
@@ -63,7 +62,6 @@ addGalleryImageClickedListener(startAppListening);
|
||||
addGalleryOffsetChangedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedNodes(startAppListening);
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
|
||||
@@ -5,7 +5,7 @@ import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAd
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import type { EnqueueBatchArg, ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('queue');
|
||||
@@ -19,7 +19,7 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
|
||||
const { imageDTO } = action.payload;
|
||||
const state = getState();
|
||||
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
const enqueueBatchArg: EnqueueBatchArg = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: await buildAdHocPostProcessingGraph({
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
@@ -17,10 +17,11 @@ import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
const state = getState();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
@@ -9,10 +9,11 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
actionCreator: enqueueRequestedUpscaling,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
.concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
|
||||
@@ -74,6 +74,7 @@ export type AppConfig = {
|
||||
allowPrivateBoards: boolean;
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
allowPublishWorkflows: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -14,7 +14,7 @@ export const useGlobalHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'invoke',
|
||||
category: 'app',
|
||||
callback: queue.queueBack,
|
||||
callback: queue.enqueueBack,
|
||||
options: {
|
||||
enabled: !queue.isDisabled && !queue.isLoading,
|
||||
preventDefault: true,
|
||||
@@ -26,7 +26,7 @@ export const useGlobalHotkeys = () => {
|
||||
useRegisteredHotkeys({
|
||||
id: 'invokeFront',
|
||||
category: 'app',
|
||||
callback: queue.queueFront,
|
||||
callback: queue.enqueueFront,
|
||||
options: {
|
||||
enabled: !queue.isDisabled && !queue.isLoading,
|
||||
preventDefault: true,
|
||||
|
||||
@@ -54,7 +54,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
|
||||
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -291,7 +291,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
*/
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: BatchConfig = {
|
||||
const batch: EnqueueBatchArg = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: graph.getGraph(),
|
||||
|
||||
@@ -2,7 +2,9 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import { TopCenterPanel } from 'features/nodes/components/flow/panels/TopPanel/TopCenterPanel';
|
||||
import { TopLeftPanel } from 'features/nodes/components/flow/panels/TopPanel/TopLeftPanel';
|
||||
import { TopRightPanel } from 'features/nodes/components/flow/panels/TopPanel/TopRightPanel';
|
||||
import WorkflowEditorSettings from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -32,7 +34,9 @@ const NodeEditor = () => {
|
||||
<>
|
||||
<Flow />
|
||||
<AddNodeCmdk />
|
||||
<TopPanel />
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<BottomLeftPanel />
|
||||
<MinimapPanel />
|
||||
</>
|
||||
|
||||
@@ -18,6 +18,7 @@ import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$cursorPos,
|
||||
@@ -146,6 +147,7 @@ export const AddNodeCmdk = memo(() => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const addNode = useAddNode();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
// Filtering the list is expensive - debounce the search term to avoid stutters
|
||||
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
|
||||
const isOpen = useStore($addNodeCmdk);
|
||||
@@ -160,8 +162,8 @@ export const AddNodeCmdk = memo(() => {
|
||||
id: 'addNode',
|
||||
category: 'workflows',
|
||||
callback: open,
|
||||
options: { enabled: tab === 'workflows', preventDefault: true },
|
||||
dependencies: [open, tab],
|
||||
options: { enabled: tab === 'workflows' && !isLocked, preventDefault: true },
|
||||
dependencies: [open, tab, isLocked],
|
||||
});
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -4,6 +4,7 @@ import type {
|
||||
EdgeChange,
|
||||
HandleType,
|
||||
NodeChange,
|
||||
NodeMouseHandler,
|
||||
OnEdgesChange,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
@@ -16,8 +17,10 @@ import type {
|
||||
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useNodeCopyPaste } from 'features/nodes/hooks/useNodeCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import {
|
||||
@@ -44,7 +47,7 @@ import {
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { type AnyEdge, type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -92,6 +95,8 @@ export const Flow = memo(() => {
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const store = useAppStore();
|
||||
const isWorkflowsFocused = useIsRegionFocused('workflows');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
useFocusRegion('workflows', flowWrapper);
|
||||
|
||||
useSyncExecutionState();
|
||||
@@ -215,7 +220,7 @@ export const Flow = memo(() => {
|
||||
id: 'copySelection',
|
||||
category: 'workflows',
|
||||
callback: copySelection,
|
||||
options: { preventDefault: true },
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [copySelection],
|
||||
});
|
||||
|
||||
@@ -244,24 +249,24 @@ export const Flow = memo(() => {
|
||||
id: 'selectAll',
|
||||
category: 'workflows',
|
||||
callback: selectAll,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [selectAll, isWorkflowsFocused, isLocked],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelection',
|
||||
category: 'workflows',
|
||||
callback: pasteSelection,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelection],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelection, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'pasteSelectionWithEdges',
|
||||
category: 'workflows',
|
||||
callback: pasteSelectionWithEdges,
|
||||
options: { enabled: isWorkflowsFocused, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
|
||||
dependencies: [pasteSelectionWithEdges, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -270,8 +275,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(undo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayUndo, preventDefault: true },
|
||||
dependencies: [mayUndo, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -280,8 +285,8 @@ export const Flow = memo(() => {
|
||||
callback: () => {
|
||||
dispatch(redo());
|
||||
},
|
||||
options: { enabled: isWorkflowsFocused && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo],
|
||||
options: { enabled: isWorkflowsFocused && !isLocked && mayRedo, preventDefault: true },
|
||||
dependencies: [mayRedo, isLocked, isWorkflowsFocused],
|
||||
});
|
||||
|
||||
const onEscapeHotkey = useCallback(() => {
|
||||
@@ -318,10 +323,22 @@ export const Flow = memo(() => {
|
||||
id: 'deleteSelection',
|
||||
category: 'workflows',
|
||||
callback: deleteSelection,
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused],
|
||||
options: { preventDefault: true, enabled: isWorkflowsFocused && !isLocked },
|
||||
dependencies: [deleteSelection, isWorkflowsFocused, isLocked],
|
||||
});
|
||||
|
||||
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
|
||||
if (!$isSelectingOutputNode.get()) {
|
||||
return;
|
||||
}
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const { id } = node.data;
|
||||
$outputNodeId.set(id);
|
||||
$isSelectingOutputNode.set(false);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ReactFlow<AnyNode, AnyEdge>
|
||||
id="workflow-editor"
|
||||
@@ -332,6 +349,7 @@ export const Flow = memo(() => {
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodeClick={onNodeClick}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
@@ -344,6 +362,12 @@ export const Flow = memo(() => {
|
||||
onMoveEnd={handleMoveEnd}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
isValidConnection={isValidConnection}
|
||||
edgesFocusable={!isLocked}
|
||||
edgesReconnectable={!isLocked}
|
||||
nodesDraggable={!isLocked}
|
||||
nodesConnectable={!isLocked}
|
||||
nodesFocusable={!isLocked}
|
||||
elementsSelectable={!isLocked}
|
||||
minZoom={0.1}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
snapGrid={snapGrid}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { map } from 'lodash-es';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
@@ -19,7 +19,7 @@ const collapsedHandleStyles: CSSProperties = {
|
||||
};
|
||||
|
||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useInvocationNodeNotes } from 'features/nodes/hooks/useNodeNotes';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -27,9 +27,9 @@ InvocationNodeInfoIcon.displayName = 'InvocationNodeInfoIcon';
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const notes = useInvocationNodeNotes(nodeId);
|
||||
const label = useNodeLabel(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const version = useNodeVersion(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const nodeTemplate = useNodeTemplateOrThrow(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ChangeEvent } from 'react';
|
||||
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
|
||||
const Content = memo(({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(nodeId, fieldName);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));
|
||||
|
||||
@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
|
||||
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useRef } from 'react';
|
||||
@@ -100,7 +100,7 @@ const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemp
|
||||
const draggableRef = useRef<HTMLDivElement>(null);
|
||||
const dragHandleRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
|
||||
const isDragging = useNodeFieldDnd(nodeId, fieldName, fieldTemplate, draggableRef, dragHandleRef);
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
@@ -105,9 +106,16 @@ type HandleCommonProps = {
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Handle
|
||||
type="target"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Left}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
@@ -130,6 +138,7 @@ const ConnectionInProgressHandle = memo(
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionError !== null) {
|
||||
@@ -140,7 +149,13 @@ const ConnectionInProgressHandle = memo(
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Handle
|
||||
type="target"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Left}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
|
||||
@@ -17,7 +17,7 @@ import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invoca
|
||||
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
|
||||
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
|
||||
@@ -9,8 +9,8 @@ import {
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import type { MouseEvent } from 'react';
|
||||
@@ -43,8 +43,8 @@ interface Props {
|
||||
export const InputFieldTitle = memo((props: Props) => {
|
||||
const { nodeId, fieldName, isInvalid, isDragging } = props;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const label = useInputFieldLabelSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
|
||||
const label = useInputFieldUserTitleSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
@@ -105,9 +106,17 @@ type HandleCommonProps = {
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
@@ -130,6 +139,7 @@ const ConnectionInProgressHandle = memo(
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionErrorTKey !== null) {
|
||||
@@ -140,7 +150,13 @@ const ConnectionInProgressHandle = memo(
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Handle
|
||||
type="source"
|
||||
id={fieldTemplate.name}
|
||||
position={Position.Right}
|
||||
style={handleStyles}
|
||||
isConnectable={!isLocked}
|
||||
>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
|
||||
@@ -3,8 +3,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
|
||||
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -17,10 +17,10 @@ type Props = {
|
||||
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const batchGroupId = useBatchGroupId(nodeId);
|
||||
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { ChakraProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
@@ -62,6 +63,12 @@ const containerSx: SystemStyleObject = {
|
||||
display: 'block',
|
||||
shadow: '0 0 0 3px var(--invoke-colors-blue-300)',
|
||||
},
|
||||
'&[data-is-editor-locked="true"]': {
|
||||
'& *': {
|
||||
cursor: 'not-allowed',
|
||||
pointerEvents: 'none',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const shadowsSx: SystemStyleObject = {
|
||||
@@ -98,7 +105,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const zoomToNode = useZoomToNode();
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const executionState = useNodeExecutionState(nodeId);
|
||||
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||
@@ -126,9 +134,9 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
// This target is marked as not fitting the view on double click
|
||||
return;
|
||||
}
|
||||
zoomToNode(nodeId);
|
||||
zoomToNode();
|
||||
},
|
||||
[nodeId, zoomToNode]
|
||||
[zoomToNode]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -141,6 +149,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
sx={containerSx}
|
||||
width={width || NODE_WIDTH}
|
||||
opacity={opacity}
|
||||
data-is-editor-locked={isLocked}
|
||||
data-is-selected={selected}
|
||||
data-is-mouse-over-form-field={mouseOverFormField.isMouseOverFormField}
|
||||
>
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const TopCenterPanel = memo(() => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
return (
|
||||
<Flex gap={2} top={2} left="50%" transform="translateX(-50%)" position="absolute" pointerEvents="none">
|
||||
{!!name.length && <WorkflowName />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TopCenterPanel.displayName = 'TopCenterPanel';
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Alert, AlertDescription, AlertIcon, AlertTitle, Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isSelectingOutputNode,
|
||||
useIsValidationRunInProgress,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const TopLeftPanel = memo(() => {
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isValidationRunInProgress = useIsValidationRunInProgress();
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
{!isLocked && (
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
)}
|
||||
{isLocked && (
|
||||
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<Box>
|
||||
<AlertTitle>{t('workflows.builder.workflowLocked')}</AlertTitle>
|
||||
{isValidationRunInProgress && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.publishingValidationRunInProgress')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && !isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedDuringPublishing')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isInPublishFlow && !isValidationRunInProgress && isSelectingOutputNode && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.selectingOutputNodeDesc')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
{isPublished && (
|
||||
<AlertDescription whiteSpace="pre-wrap">
|
||||
{t('workflows.builder.workflowLockedPublished')}
|
||||
</AlertDescription>
|
||||
)}
|
||||
</Box>
|
||||
</Alert>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopLeftPanel.displayName = 'TopLeftPanel';
|
||||
@@ -1,40 +0,0 @@
|
||||
import { Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
</Flex>
|
||||
<Spacer />
|
||||
{!!name.length && <WorkflowName />}
|
||||
<Spacer />
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(TopCenterPanel);
|
||||
@@ -0,0 +1,34 @@
|
||||
import { Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
export const TopRightPanel = memo(() => {
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (isLocked) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2} top={2} right={2} position="absolute" alignItems="flex-end" pointerEvents="none">
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TopRightPanel.displayName = 'TopRightPanel';
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { HorizontalResizeHandle } from 'features/ui/components/tabs/ResizeHandle';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -23,23 +22,21 @@ export const EditModeLeftPanelContent = memo(() => {
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</ScrollableContent>
|
||||
<PanelGroup
|
||||
ref={panelGroupRef}
|
||||
id="workflow-panel-group"
|
||||
autoSaveId="workflow-panel-group"
|
||||
direction="vertical"
|
||||
style={panelGroupStyles}
|
||||
>
|
||||
<Panel id="workflow" collapsible minSize={25}>
|
||||
<WorkflowFieldsLinearViewPanel />
|
||||
</Panel>
|
||||
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
|
||||
<Panel id="inspector" collapsible minSize={25}>
|
||||
<WorkflowNodeInspectorPanel />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold, PiLockOpenBold } from 'react-icons/pi';
|
||||
|
||||
export const PublishedWorkflowPanelContent = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const saveAs = useSaveOrSaveAsWorkflow();
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" gap={2} alignItems="center">
|
||||
<Heading size="md" pt={32}>
|
||||
{t('workflows.builder.workflowLocked')}
|
||||
</Heading>
|
||||
<Text fontSize="md">{t('workflows.builder.publishedWorkflowsLocked')}</Text>
|
||||
<Button size="md" onClick={saveAs} variant="ghost" leftIcon={<PiCopyBold />}>
|
||||
{t('common.saveAs')}
|
||||
</Button>
|
||||
<Button size="md" onClick={undefined} variant="ghost" leftIcon={<PiLockOpenBold />}>
|
||||
{t('workflows.builder.unpublish')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishedWorkflowPanelContent.displayName = 'PublishedWorkflowPanelContent';
|
||||
@@ -2,7 +2,7 @@ import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowListMenuTrigger } from 'features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger';
|
||||
import { WorkflowViewEditToggleButton } from 'features/nodes/components/sidePanel/WorkflowViewEditToggleButton';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { WorkflowLibraryMenu } from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
@@ -10,12 +10,13 @@ import SaveWorkflowButton from './SaveWorkflowButton';
|
||||
|
||||
export const ActiveWorkflowNameAndActions = memo(() => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={1} minW={0}>
|
||||
<WorkflowListMenuTrigger />
|
||||
<Spacer />
|
||||
{mode === 'edit' && <SaveWorkflowButton />}
|
||||
{mode === 'edit' && !isPublished && <SaveWorkflowButton />}
|
||||
<WorkflowViewEditToggleButton />
|
||||
<WorkflowLibraryMenu />
|
||||
</Flex>
|
||||
|
||||
@@ -1,22 +1,30 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EditModeLeftPanelContent } from 'features/nodes/components/sidePanel/EditModeLeftPanelContent';
|
||||
import { PublishedWorkflowPanelContent } from 'features/nodes/components/sidePanel/PublishedWorkflowPanelContent';
|
||||
import { $isInPublishFlow } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { PublishWorkflowPanelContent } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
|
||||
import { ActiveWorkflowDescription } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowDescription';
|
||||
import { ActiveWorkflowNameAndActions } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowNameAndActions';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { ViewModeLeftPanelContent } from './viewMode/ViewModeLeftPanelContent';
|
||||
|
||||
const WorkflowsTabLeftPanel = () => {
|
||||
const mode = useAppSelector(selectWorkflowMode);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" gap={2} flexDir="column">
|
||||
<ActiveWorkflowNameAndActions />
|
||||
{mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
{isInPublishFlow && <PublishWorkflowPanelContent />}
|
||||
{!isInPublishFlow && <ActiveWorkflowNameAndActions />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ActiveWorkflowDescription />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'view' && <ViewModeLeftPanelContent />}
|
||||
{!isInPublishFlow && !isPublished && mode === 'edit' && <EditModeLeftPanelContent />}
|
||||
{isPublished && <PublishedWorkflowPanelContent />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -67,11 +67,8 @@ FormElementEditModeHeader.displayName = 'FormElementEditModeHeader';
|
||||
const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const { t } = useTranslation();
|
||||
const { nodeId } = element.data.fieldIdentifier;
|
||||
const zoomToNode = useZoomToNode();
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const onClick = useCallback(() => {
|
||||
zoomToNode(nodeId);
|
||||
}, [nodeId, zoomToNode]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -79,7 +76,7 @@ const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
tooltip={t('workflows.builder.zoomToNode')}
|
||||
aria-label={t('workflows.builder.zoomToNode')}
|
||||
onClick={onClick}
|
||||
onClick={zoomToNode}
|
||||
icon={<PiGpsFixBold />}
|
||||
variant="link"
|
||||
size="sm"
|
||||
|
||||
@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -13,7 +13,7 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
|
||||
return (
|
||||
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
|
||||
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
|
||||
<NodeFieldElementOverlay element={el} />
|
||||
<NodeFieldElementOverlay nodeId={el.data.fieldIdentifier.nodeId} />
|
||||
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
|
||||
</Flex>
|
||||
);
|
||||
@@ -105,9 +105,9 @@ const nodeFieldOverlaySx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const NodeFieldElementOverlay = memo(({ element }: { element: NodeFieldElement }) => {
|
||||
const mouseOverNode = useMouseOverNode(element.data.fieldIdentifier.nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(element.data.fieldIdentifier.nodeId);
|
||||
export const NodeFieldElementOverlay = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
|
||||
return (
|
||||
<Box
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -12,7 +12,7 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
|
||||
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
|
||||
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import {
|
||||
isFloatFieldInputTemplate,
|
||||
|
||||
@@ -5,8 +5,9 @@ import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplateSafe';
|
||||
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -36,7 +37,7 @@ const useFormatFallbackLabel = () => {
|
||||
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const containerCtx = useContainerContext();
|
||||
const formatFallbackLabel = useFormatFallbackLabel();
|
||||
@@ -69,7 +70,7 @@ NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
|
||||
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _description = useMemo(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import type { DropTargetRecord } from '@atlaskit/pragmatic-drag-and-drop/dist/types/internal-types';
|
||||
import type { ElementDragPayload } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import {
|
||||
draggable,
|
||||
dropTargetForElements,
|
||||
@@ -33,7 +35,7 @@ import {
|
||||
selectFormRootElementId,
|
||||
selectWorkflowSlice,
|
||||
} from 'features/nodes/store/workflowSlice';
|
||||
import type { FieldIdentifier, FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type { FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
|
||||
import type { ElementId, FormElement } from 'features/nodes/types/workflow';
|
||||
import { buildNodeFieldElement, isContainerElement } from 'features/nodes/types/workflow';
|
||||
import type { RefObject } from 'react';
|
||||
@@ -58,6 +60,27 @@ const isFormElementDndData = (data: Record<string | symbol, unknown>): data is F
|
||||
return uniqueFormElementDndKey in data;
|
||||
};
|
||||
|
||||
const uniqueNodeFieldDndKey = Symbol('node-field');
|
||||
type NodeFieldDndData = {
|
||||
[uniqueNodeFieldDndKey]: true;
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
};
|
||||
const buildNodeFieldDndData = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
fieldTemplate: FieldInputTemplate
|
||||
): NodeFieldDndData => ({
|
||||
[uniqueNodeFieldDndKey]: true,
|
||||
nodeId,
|
||||
fieldName,
|
||||
fieldTemplate,
|
||||
});
|
||||
const isNodeFieldDndData = (data: Record<string | symbol, unknown>): data is NodeFieldDndData => {
|
||||
return uniqueNodeFieldDndKey in data;
|
||||
};
|
||||
|
||||
/**
|
||||
* Flashes an element by changing its background color. Used to indicate that an element has been moved.
|
||||
* @param elementId The id of the element to flash
|
||||
@@ -133,6 +156,27 @@ const useGetInitialValue = () => {
|
||||
return _getInitialValue;
|
||||
};
|
||||
|
||||
const getSourceElement = (source: ElementDragPayload) => {
|
||||
if (isNodeFieldDndData(source.data)) {
|
||||
const { nodeId, fieldName, fieldTemplate } = source.data;
|
||||
return buildNodeFieldElement(nodeId, fieldName, fieldTemplate.type);
|
||||
}
|
||||
|
||||
if (isFormElementDndData(source.data)) {
|
||||
return source.data.element;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const getTargetElement = (target: DropTargetRecord) => {
|
||||
if (isFormElementDndData(target.data)) {
|
||||
return target.data.element;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Singleton hook that monitors for builder dnd events and dispatches actions accordingly.
|
||||
*/
|
||||
@@ -156,20 +200,20 @@ export const useBuilderDndMonitor = () => {
|
||||
|
||||
useEffect(() => {
|
||||
return monitorForElements({
|
||||
canMonitor: ({ source }) => isFormElementDndData(source.data),
|
||||
canMonitor: ({ source }) => isFormElementDndData(source.data) || isNodeFieldDndData(source.data),
|
||||
onDrop: ({ location, source }) => {
|
||||
const target = location.current.dropTargets[0];
|
||||
if (!target) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isFormElementDndData(source.data) || !isFormElementDndData(target.data)) {
|
||||
const sourceElement = getSourceElement(source);
|
||||
const targetElement = getTargetElement(target);
|
||||
|
||||
if (!sourceElement || !targetElement) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceElement = source.data.element;
|
||||
const targetElement = target.data.element;
|
||||
|
||||
if (sourceElement.id === targetElement.id) {
|
||||
// Dropping on self is a no-op
|
||||
return;
|
||||
@@ -359,8 +403,15 @@ export const useFormElementDnd = (
|
||||
element: draggableElement,
|
||||
// TODO(psyche): This causes a kinda jittery behaviour - need a better heuristic to determine stickiness
|
||||
getIsSticky: () => false,
|
||||
canDrop: ({ source }) =>
|
||||
isFormElementDndData(source.data) && source.data.element.id !== getElement(elementId).parentId,
|
||||
canDrop: ({ source }) => {
|
||||
if (isNodeFieldDndData(source.data)) {
|
||||
return true;
|
||||
}
|
||||
if (isFormElementDndData(source.data)) {
|
||||
return source.data.element.id !== getElement(elementId).parentId;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
getData: ({ input }) => {
|
||||
const element = getElement(elementId);
|
||||
|
||||
@@ -423,8 +474,16 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
dropTargetForElements({
|
||||
element: droppableElement,
|
||||
getIsSticky: () => false,
|
||||
canDrop: ({ source }) =>
|
||||
getElement(rootElementId, isContainerElement).data.children.length === 0 && isFormElementDndData(source.data),
|
||||
canDrop: ({ source }) => {
|
||||
const rootElement = getElement(rootElementId, isContainerElement);
|
||||
if (rootElement.data.children.length !== 0) {
|
||||
return false;
|
||||
}
|
||||
if (isNodeFieldDndData(source.data) || isFormElementDndData(source.data)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
getData: ({ input }) => {
|
||||
const element = getElement(rootElementId, isContainerElement);
|
||||
|
||||
@@ -455,7 +514,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
/**
|
||||
* Hook that provides dnd functionality for node fields.
|
||||
*
|
||||
* @param fieldIdentifier The identifier of the node field
|
||||
* @param nodeId: The id of the node
|
||||
* @param fieldName: The name of the field
|
||||
* @param fieldTemplate The template of the node field, required to build the form element
|
||||
* @param draggableRef The ref of the draggable HTML element
|
||||
* @param dragHandleRef The ref of the drag handle HTML element
|
||||
@@ -463,7 +523,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
|
||||
* @returns Whether the node field is currently being dragged
|
||||
*/
|
||||
export const useNodeFieldDnd = (
|
||||
fieldIdentifier: FieldIdentifier,
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
fieldTemplate: FieldInputTemplate,
|
||||
draggableRef: RefObject<HTMLElement>,
|
||||
dragHandleRef: RefObject<HTMLElement>
|
||||
@@ -481,12 +542,7 @@ export const useNodeFieldDnd = (
|
||||
draggable({
|
||||
element: draggableElement,
|
||||
dragHandle: dragHandleElement,
|
||||
getInitialData: () => {
|
||||
const { nodeId, fieldName } = fieldIdentifier;
|
||||
const { type } = fieldTemplate;
|
||||
const element = buildNodeFieldElement(nodeId, fieldName, type);
|
||||
return buildFormElementDndData(element);
|
||||
},
|
||||
getInitialData: () => buildNodeFieldDndData(nodeId, fieldName, fieldTemplate),
|
||||
onDragStart: () => {
|
||||
setIsDragging(true);
|
||||
},
|
||||
@@ -495,7 +551,7 @@ export const useNodeFieldDnd = (
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [dragHandleRef, draggableRef, fieldIdentifier, fieldTemplate]);
|
||||
}, [dragHandleRef, draggableRef, fieldName, fieldTemplate, nodeId]);
|
||||
|
||||
return isDragging;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
|
||||
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import { InvocationNodeNotesTextarea } from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeNotesTextarea';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
@@ -36,7 +36,7 @@ export default memo(InspectorDetailsTab);
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const version = useNodeVersion(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const needsUpdate = useNodeNeedsUpdate(nodeId);
|
||||
|
||||
return (
|
||||
|
||||
@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -37,7 +37,7 @@ const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const nes = useNodeExecutionState(nodeId);
|
||||
|
||||
if (!nes || nes.outputs.length === 0) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { Flex, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
|
||||
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -14,8 +14,8 @@ type Props = {
|
||||
|
||||
const InspectorTabEditableNodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const label = useNodeUserTitleSafe(nodeId);
|
||||
const templateTitle = useNodeTemplateTitleSafe(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const onChange = useCallback(
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
|
||||
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -29,7 +29,7 @@ export default memo(NodeTemplateInspector);
|
||||
|
||||
const Content = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
|
||||
return <DataViewer data={template} label={t('nodes.nodeTemplate')} bg="base.850" color="base.200" />;
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -0,0 +1,445 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Divider,
|
||||
Flex,
|
||||
ListItem,
|
||||
Spacer,
|
||||
Text,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { $projectUrl } from 'app/store/nanostores/projectId';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { ExternalLink } from 'features/gallery/components/ImageViewer/NoContentForViewer';
|
||||
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$isReadyToDoValidationRun,
|
||||
$isSelectingOutputNode,
|
||||
$outputNodeId,
|
||||
$validationRunBatchId,
|
||||
usePublishInputs,
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
|
||||
import { useInputFieldUserTitleOrThrow } from 'features/nodes/hooks/useInputFieldUserTitleOrThrow';
|
||||
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
|
||||
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
|
||||
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
|
||||
import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
|
||||
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
|
||||
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
|
||||
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowLineRightBold, PiLightningFill, PiXBold } from 'react-icons/pi';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const PublishWorkflowPanelContent = memo(() => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} h="full">
|
||||
<ButtonGroup isAttached={false} size="sm" variant="ghost">
|
||||
<Spacer />
|
||||
<CancelPublishButton />
|
||||
<PublishWorkflowButton />
|
||||
</ButtonGroup>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
<OutputFields />
|
||||
<PublishableInputFields />
|
||||
<UnpublishableInputFields />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishWorkflowPanelContent.displayName = 'PublishWorkflowPanelContent';
|
||||
|
||||
const OutputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Flex alignItems="center">
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
|
||||
<Spacer />
|
||||
<SelectOutputNodeButton variant="link" size="sm" />
|
||||
</Flex>
|
||||
|
||||
<Divider />
|
||||
{!outputNodeId && (
|
||||
<Text fontWeight="semibold" color="error.300">
|
||||
{t('workflows.builder.noOutputNodeSelected')}
|
||||
</Text>
|
||||
)}
|
||||
{outputNodeId && <OutputFieldsContent outputNodeId={outputNodeId} />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
OutputFields.displayName = 'OutputFields';
|
||||
|
||||
const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
|
||||
const outputFieldNames = useOutputFieldNames(outputNodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
{outputFieldNames.map((fieldName) => (
|
||||
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
});
|
||||
OutputFieldsContent.displayName = 'OutputFieldsContent';
|
||||
|
||||
const PublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.publishable.length === 0) {
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.noPublishableInputs')}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowInputs')}</Text>
|
||||
<Divider />
|
||||
{inputs.publishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
PublishableInputFields.displayName = 'PublishableInputFields';
|
||||
|
||||
const UnpublishableInputFields = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
if (inputs.unpublishable.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
{t('workflows.builder.unpublishableInputs')}
|
||||
</Text>
|
||||
<Divider />
|
||||
{inputs.unpublishable.map(({ nodeId, fieldName }) => {
|
||||
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
|
||||
})}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
UnpublishableInputFields.displayName = 'UnpublishableInputFields';
|
||||
|
||||
const SelectOutputNodeButton = memo((props: ButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const onClick = useCallback(() => {
|
||||
$outputNodeId.set(null);
|
||||
$isSelectingOutputNode.set(true);
|
||||
}, []);
|
||||
return (
|
||||
<Button
|
||||
leftIcon={<PiArrowLineRightBold />}
|
||||
isDisabled={isSelectingOutputNode}
|
||||
tooltip={isSelectingOutputNode ? t('workflows.builder.selectingOutputNodeDesc') : undefined}
|
||||
onClick={onClick}
|
||||
{...props}
|
||||
>
|
||||
{isSelectingOutputNode
|
||||
? t('workflows.builder.selectingOutputNode')
|
||||
: outputNodeId
|
||||
? t('workflows.builder.changeOutputNode')
|
||||
: t('workflows.builder.selectOutputNode')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
SelectOutputNodeButton.displayName = 'SelectOutputNodeButton';
|
||||
|
||||
const CancelPublishButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const onClick = useCallback(() => {
|
||||
$isInPublishFlow.set(false);
|
||||
$isSelectingOutputNode.set(false);
|
||||
$outputNodeId.set(null);
|
||||
}, []);
|
||||
return (
|
||||
<Button leftIcon={<PiXBold />} onClick={onClick}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
CancelPublishButton.displayName = 'CancelDeployButton';
|
||||
|
||||
const PublishWorkflowButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isReadyToDoValidationRun = useStore($isReadyToDoValidationRun);
|
||||
const isReadyToEnqueue = useStore($isReadyToEnqueue);
|
||||
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
|
||||
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
|
||||
const outputNodeId = useStore($outputNodeId);
|
||||
const isSelectingOutputNode = useStore($isSelectingOutputNode);
|
||||
const inputs = usePublishInputs();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
|
||||
const projectUrl = useStore($projectUrl);
|
||||
|
||||
const enqueue = useEnqueueWorkflows();
|
||||
const onClick = useCallback(async () => {
|
||||
const result = await withResultAsync(() => enqueue(true, true));
|
||||
if (result.isErr()) {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_FAILED',
|
||||
status: 'error',
|
||||
title: t('workflows.builder.publishFailed'),
|
||||
description: t('workflows.builder.publishFailedDesc'),
|
||||
duration: null,
|
||||
});
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
} else {
|
||||
toast({
|
||||
id: 'TOAST_PUBLISH_SUCCESSFUL',
|
||||
status: 'success',
|
||||
title: t('workflows.builder.publishSuccess'),
|
||||
description: (
|
||||
<Trans
|
||||
i18nKey="workflows.builder.publishSuccessDesc"
|
||||
components={{
|
||||
LinkComponent: <ExternalLink href={projectUrl ?? ''} />,
|
||||
}}
|
||||
/>
|
||||
),
|
||||
duration: null,
|
||||
});
|
||||
assert(result.value.enqueueResult.batch.batch_id);
|
||||
$validationRunBatchId.set(result.value.enqueueResult.batch.batch_id);
|
||||
log.debug(parseify(result.value), 'Enqueued batch');
|
||||
}
|
||||
}, [enqueue, projectUrl, t]);
|
||||
|
||||
return (
|
||||
<PublishTooltip
|
||||
isWorkflowSaved={isWorkflowSaved}
|
||||
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
|
||||
isReadyToEnqueue={isReadyToEnqueue}
|
||||
hasOutputNode={outputNodeId !== null && !isSelectingOutputNode}
|
||||
hasPublishableInputs={inputs.publishable.length > 0}
|
||||
hasUnpublishableInputs={inputs.unpublishable.length > 0}
|
||||
>
|
||||
<Button
|
||||
leftIcon={<PiLightningFill />}
|
||||
isDisabled={
|
||||
!allowPublishWorkflows ||
|
||||
!isReadyToEnqueue ||
|
||||
!isWorkflowSaved ||
|
||||
hasBatchOrGeneratorNodes ||
|
||||
!isReadyToDoValidationRun ||
|
||||
!(outputNodeId !== null && !isSelectingOutputNode)
|
||||
}
|
||||
onClick={onClick}
|
||||
>
|
||||
{t('workflows.builder.publish')}
|
||||
</Button>
|
||||
</PublishTooltip>
|
||||
);
|
||||
});
|
||||
PublishWorkflowButton.displayName = 'DoValidationRunButton';
|
||||
|
||||
const NodeInputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
|
||||
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
|
||||
const fieldUserTitle = useInputFieldUserTitleOrThrow(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
position="relative"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
onMouseOver={mouseOverFormField.handleMouseOver}
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
onClick={zoomToNode}
|
||||
>
|
||||
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldUserTitle || fieldTemplateTitle}`}</Text>
|
||||
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
|
||||
<NodeFieldElementOverlay nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
NodeInputFieldPreview.displayName = 'NodeInputFieldPreview';
|
||||
|
||||
const NodeOutputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
|
||||
const mouseOverFormField = useMouseOverFormField(nodeId);
|
||||
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
|
||||
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
|
||||
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
position="relative"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
onMouseOver={mouseOverFormField.handleMouseOver}
|
||||
onMouseOut={mouseOverFormField.handleMouseOut}
|
||||
onClick={zoomToNode}
|
||||
>
|
||||
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate.title}`}</Text>
|
||||
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
|
||||
<NodeFieldElementOverlay nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';
|
||||
|
||||
export const StartPublishFlowButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
const isReadyToEnqueue = useStore($isReadyToEnqueue);
|
||||
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
|
||||
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
|
||||
const inputs = usePublishInputs();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
$isInPublishFlow.set(true);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<PublishTooltip
|
||||
isWorkflowSaved={isWorkflowSaved}
|
||||
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
|
||||
isReadyToEnqueue={isReadyToEnqueue}
|
||||
hasOutputNode={true}
|
||||
hasPublishableInputs={inputs.publishable.length > 0}
|
||||
hasUnpublishableInputs={inputs.unpublishable.length > 0}
|
||||
>
|
||||
<Button
|
||||
onClick={onClick}
|
||||
leftIcon={<PiLightningFill />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
isDisabled={!allowPublishWorkflows || !isReadyToEnqueue || !isWorkflowSaved || hasBatchOrGeneratorNodes}
|
||||
>
|
||||
{t('workflows.builder.publish')}
|
||||
</Button>
|
||||
</PublishTooltip>
|
||||
);
|
||||
});
|
||||
|
||||
StartPublishFlowButton.displayName = 'StartPublishFlowButton';
|
||||
|
||||
const PublishTooltip = memo(
|
||||
({
|
||||
isWorkflowSaved,
|
||||
hasBatchOrGeneratorNodes,
|
||||
isReadyToEnqueue,
|
||||
hasOutputNode,
|
||||
hasPublishableInputs,
|
||||
hasUnpublishableInputs,
|
||||
children,
|
||||
}: PropsWithChildren<{
|
||||
isWorkflowSaved: boolean;
|
||||
hasBatchOrGeneratorNodes: boolean;
|
||||
isReadyToEnqueue: boolean;
|
||||
hasOutputNode: boolean;
|
||||
hasPublishableInputs: boolean;
|
||||
hasUnpublishableInputs: boolean;
|
||||
}>) => {
|
||||
const { t } = useTranslation();
|
||||
const warnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
if (!hasPublishableInputs) {
|
||||
_warnings.push(t('workflows.builder.warningWorkflowHasNoPublishableInputFields'));
|
||||
}
|
||||
if (hasUnpublishableInputs) {
|
||||
_warnings.push(t('workflows.builder.warningWorkflowHasUnpublishableInputFields'));
|
||||
}
|
||||
return _warnings;
|
||||
}, [hasPublishableInputs, hasUnpublishableInputs, t]);
|
||||
const errors = useMemo(() => {
|
||||
const _errors: string[] = [];
|
||||
if (!isWorkflowSaved) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasUnsavedChanges'));
|
||||
}
|
||||
if (hasBatchOrGeneratorNodes) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasBatchOrGeneratorNodes'));
|
||||
}
|
||||
if (!isReadyToEnqueue) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasInvalidGraph'));
|
||||
}
|
||||
if (!hasOutputNode) {
|
||||
_errors.push(t('workflows.builder.errorWorkflowHasNoOutputNode'));
|
||||
}
|
||||
return _errors;
|
||||
}, [hasBatchOrGeneratorNodes, hasOutputNode, isReadyToEnqueue, isWorkflowSaved, t]);
|
||||
|
||||
if (errors.length === 0 && warnings.length === 0) {
|
||||
return children;
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={
|
||||
<Flex flexDir="column">
|
||||
{errors.length > 0 && (
|
||||
<>
|
||||
<Text color="error.700" fontWeight="semibold">
|
||||
{t('workflows.builder.cannotPublish')}:
|
||||
</Text>
|
||||
<UnorderedList>
|
||||
{errors.map((problem, index) => (
|
||||
<ListItem key={index}>{problem}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</>
|
||||
)}
|
||||
{warnings.length > 0 && (
|
||||
<>
|
||||
<Text color="warning.700" fontWeight="semibold">
|
||||
{t('workflows.builder.publishWarnings')}:
|
||||
</Text>
|
||||
<UnorderedList>
|
||||
{warnings.map((problem, index) => (
|
||||
<ListItem key={index}>{problem}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
);
|
||||
PublishTooltip.displayName = 'PublishTooltip';
|
||||
@@ -0,0 +1,23 @@
|
||||
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiLockBold } from 'react-icons/pi';
|
||||
|
||||
export const LockedWorkflowIcon = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
|
||||
<IconButton
|
||||
size="sm"
|
||||
cursor="not-allowed"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label={t('workflows.builder.publishedWorkflowsLocked')}
|
||||
icon={<PiLockBold />}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
LockedWorkflowIcon.displayName = 'LockedWorkflowIcon';
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
workflowLibraryTagToggled,
|
||||
workflowLibraryViewChanged,
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
@@ -39,13 +40,12 @@ export const WorkflowLibrarySideNav = () => {
|
||||
const { t } = useTranslation();
|
||||
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
|
||||
const view = useAppSelector(selectWorkflowLibraryView);
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
|
||||
return (
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
|
||||
<Flex flexDir="column" w="full" pb={2}>
|
||||
<Flex flexDir="column" w="full" pb={2} gap={2}>
|
||||
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
|
||||
</Flex>
|
||||
<Flex flexDir="column" w="full" pb={2}>
|
||||
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
|
||||
{categoryOptions.includes('project') && (
|
||||
<Collapse in={view === 'yours' || view === 'shared' || view === 'private'}>
|
||||
@@ -60,6 +60,9 @@ export const WorkflowLibrarySideNav = () => {
|
||||
</Flex>
|
||||
</Collapse>
|
||||
)}
|
||||
{allowPublishWorkflows && (
|
||||
<WorkflowLibraryViewButton view="published">{t('workflows.published')}</WorkflowLibraryViewButton>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
|
||||
<BrowseWorkflowsButton />
|
||||
|
||||
@@ -36,6 +36,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
|
||||
return ['user'];
|
||||
case 'shared':
|
||||
return ['project'];
|
||||
case 'published':
|
||||
return ['user', 'project', 'default'];
|
||||
default:
|
||||
assert<Equals<typeof view, never>>(false);
|
||||
}
|
||||
@@ -66,6 +68,7 @@ const useInfiniteQueryAry = () => {
|
||||
query: debouncedSearchTerm,
|
||||
tags: view === 'defaults' ? selectedTags : [],
|
||||
has_been_opened: getHasBeenOpened(view),
|
||||
is_published: view === 'published' ? true : undefined,
|
||||
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
|
||||
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
|
||||
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
|
||||
import { selectWorkflowId, workflowModeChanged } from 'features/nodes/store/workflowSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
@@ -54,7 +55,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
position="relative"
|
||||
role="button"
|
||||
onClick={handleClickLoad}
|
||||
cursor="pointer"
|
||||
bg="base.750"
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
@@ -81,7 +81,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
<Flex gap={2} alignItems="flex-start" justifyContent="space-between" w="full">
|
||||
<Text noOfLines={2}>{workflow.name}</Text>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{isActive && (
|
||||
{isActive && !workflow.is_published && (
|
||||
<Badge
|
||||
color="invokeBlue.400"
|
||||
borderColor="invokeBlue.700"
|
||||
@@ -93,6 +93,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
{t('workflows.opened')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.is_published && (
|
||||
<Badge
|
||||
color="invokeGreen.400"
|
||||
borderColor="invokeGreen.700"
|
||||
borderWidth={1}
|
||||
bg="transparent"
|
||||
flexShrink={0}
|
||||
variant="subtle"
|
||||
>
|
||||
{t('workflows.builder.published')}
|
||||
</Badge>
|
||||
)}
|
||||
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
|
||||
{workflow.category === 'default' && (
|
||||
<Image
|
||||
@@ -119,8 +131,10 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
</Text>
|
||||
)}
|
||||
<Spacer />
|
||||
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
|
||||
{workflow.category !== 'default' && (
|
||||
{workflow.category === 'default' && !workflow.is_published && (
|
||||
<ViewWorkflow workflowId={workflow.workflow_id} />
|
||||
)}
|
||||
{workflow.category !== 'default' && !workflow.is_published && (
|
||||
<>
|
||||
<EditWorkflow workflowId={workflow.workflow_id} />
|
||||
<DownloadWorkflow workflowId={workflow.workflow_id} />
|
||||
@@ -128,6 +142,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
|
||||
</>
|
||||
)}
|
||||
{workflow.category === 'project' && <ShareWorkflowButton workflow={workflow} />}
|
||||
{workflow.is_published && <LockedWorkflowIcon />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
|
||||
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
|
||||
import { selectAllowPublishWorkflows } from 'features/system/store/configSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -8,12 +11,15 @@ import WorkflowJSONTab from './WorkflowJSONTab';
|
||||
|
||||
const WorkflowFieldsLinearViewPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
const allowPublishWorkflows = useAppSelector(selectAllowPublishWorkflows);
|
||||
return (
|
||||
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
|
||||
<TabList>
|
||||
<Tab>{t('workflows.builder.builder')}</Tab>
|
||||
<Tab>{t('common.details')}</Tab>
|
||||
<Tab>JSON</Tab>
|
||||
<Spacer />
|
||||
{allowPublishWorkflows && <StartPublishFlowButton />}
|
||||
</TabList>
|
||||
|
||||
<TabPanels h="full" pt={2}>
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowFormNodeFieldFieldIdentifiersDeduped } from 'features/nodes/store/workflowSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { isBoardFieldType } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetBatchStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const $isInPublishFlow = atom(false);
|
||||
export const $outputNodeId = atom<string | null>(null);
|
||||
export const $isSelectingOutputNode = atom(false);
|
||||
export const $isReadyToDoValidationRun = computed(
|
||||
[$isInPublishFlow, $outputNodeId, $isSelectingOutputNode],
|
||||
(isInPublishFlow, outputNodeId, isSelectingOutputNode) => {
|
||||
return isInPublishFlow && outputNodeId !== null && !isSelectingOutputNode;
|
||||
}
|
||||
);
|
||||
export const $validationRunBatchId = atom<string | null>(null);
|
||||
|
||||
export const useIsValidationRunInProgress = () => {
|
||||
const validationRunBatchId = useStore($validationRunBatchId);
|
||||
const { isValidationRunInProgress } = useGetBatchStatusQuery(
|
||||
validationRunBatchId ? { batch_id: validationRunBatchId } : skipToken,
|
||||
{
|
||||
selectFromResult: ({ currentData }) => {
|
||||
if (!currentData) {
|
||||
return { isValidationRunInProgress: false };
|
||||
}
|
||||
if (currentData && currentData.in_progress > 0) {
|
||||
return { isValidationRunInProgress: true };
|
||||
}
|
||||
return { isValidationRunInProgress: false };
|
||||
},
|
||||
}
|
||||
);
|
||||
return validationRunBatchId !== null || isValidationRunInProgress;
|
||||
};
|
||||
|
||||
export const selectFieldIdentifiersWithInvocationTypes = createSelector(
|
||||
selectWorkflowFormNodeFieldFieldIdentifiersDeduped,
|
||||
selectNodesSlice,
|
||||
(fieldIdentifiers, nodes) => {
|
||||
const result: { nodeId: string; fieldName: string; type: string }[] = [];
|
||||
for (const fieldIdentifier of fieldIdentifiers) {
|
||||
const node = nodes.nodes.find((node) => node.id === fieldIdentifier.nodeId);
|
||||
assert(isInvocationNode(node), `Node ${fieldIdentifier.nodeId} not found`);
|
||||
result.push({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName, type: node.data.type });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
);
|
||||
|
||||
export const getPublishInputs = (fieldIdentifiers: (FieldIdentifier & { type: string })[], templates: Templates) => {
|
||||
// Certain field types are not allowed to be input fields on a published workflow
|
||||
const publishable: FieldIdentifier[] = [];
|
||||
const unpublishable: FieldIdentifier[] = [];
|
||||
for (const fieldIdentifier of fieldIdentifiers) {
|
||||
const fieldTemplate = templates[fieldIdentifier.type]?.inputs[fieldIdentifier.fieldName];
|
||||
if (!fieldTemplate) {
|
||||
unpublishable.push(fieldIdentifier);
|
||||
continue;
|
||||
}
|
||||
if (isBoardFieldType(fieldTemplate.type)) {
|
||||
unpublishable.push(fieldIdentifier);
|
||||
continue;
|
||||
}
|
||||
publishable.push(fieldIdentifier);
|
||||
}
|
||||
return { publishable, unpublishable };
|
||||
};
|
||||
|
||||
export const usePublishInputs = () => {
|
||||
const templates = useStore($templates);
|
||||
const fieldIdentifiersWithInvocationTypes = useAppSelector(selectFieldIdentifiersWithInvocationTypes);
|
||||
const fieldIdentifiers = useMemo(
|
||||
() => getPublishInputs(fieldIdentifiersWithInvocationTypes, templates),
|
||||
[fieldIdentifiersWithInvocationTypes, templates]
|
||||
);
|
||||
|
||||
return fieldIdentifiers;
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
|
||||
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isSingleOrCollection } from 'features/nodes/types/field';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
const isConnectionInputField = (field: FieldInputTemplate) => {
|
||||
return (
|
||||
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
|
||||
@@ -19,7 +20,7 @@ const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesMissing = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const node = useNodeData(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const instanceFields = new Set(Object.keys(node.inputs));
|
||||
@@ -30,7 +31,7 @@ export const useInputFieldNamesMissing = (nodeId: string) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const anyOrDirectFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
@@ -44,7 +45,7 @@ export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
|
||||
};
|
||||
|
||||
export const useInputFieldNamesConnection = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const fieldNames = useMemo(() => {
|
||||
const connectionFields: string[] = [];
|
||||
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
@@ -13,7 +14,7 @@ import { assert } from 'tsafe';
|
||||
* @throws Will throw an error if the template for the input field is not found.
|
||||
*/
|
||||
export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string): FieldInputTemplate => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const fieldTemplate = useMemo(() => {
|
||||
const _fieldTemplate = template.inputs[fieldName];
|
||||
assert(_fieldTemplate, `Template for input field ${fieldName} not found.`);
|
||||
@@ -21,17 +22,3 @@ export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string):
|
||||
}, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
*/
|
||||
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
@@ -0,0 +1,17 @@
|
||||
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
*/
|
||||
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const template = useNodeTemplateSafe(nodeId);
|
||||
const fieldTemplate = useMemo(() => template?.inputs[fieldName] ?? null, [fieldName, template?.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
@@ -1,9 +1,10 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useInputFieldTemplateTitle = (nodeId: string, fieldName: string): string => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useInputFieldTemplateTitleOrThrow = (nodeId: string, fieldName: string): string => {
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
|
||||
const title = useMemo(() => {
|
||||
const fieldTemplate = template.inputs[fieldName];
|
||||
@@ -11,7 +11,7 @@ import { useMemo } from 'react';
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldDescriptionSafe = (nodeId: string, fieldName: string) => {
|
||||
export const useInputFieldUserDescriptionSafe = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
@@ -0,0 +1,23 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined title of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an error is thrown.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldUserTitleOrThrow = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).label),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
|
||||
return title;
|
||||
};
|
||||
@@ -4,21 +4,21 @@ import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/s
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined label of an input field for a given node.
|
||||
* Gets the user-defined title of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an empty string is returned.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldLabelSafe = (nodeId: string, fieldName: string): string => {
|
||||
export const useInputFieldUserTitleSafe = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? ''),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const label = useAppSelector(selector);
|
||||
const title = useAppSelector(selector);
|
||||
|
||||
return label;
|
||||
return title;
|
||||
};
|
||||
@@ -1,9 +1,10 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isBatchNodeType, isGeneratorNodeType } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useIsExecutableNode = (nodeId: string) => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const isExecutableNode = useMemo(
|
||||
() => !isBatchNodeType(template.type) && !isGeneratorNodeType(template.type),
|
||||
[template]
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $isInPublishFlow, useIsValidationRunInProgress } from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
|
||||
|
||||
export const useIsWorkflowEditorLocked = () => {
|
||||
const isInPublishFlow = useStore($isInPublishFlow);
|
||||
const isPublished = useAppSelector(selectWorkflowIsPublished);
|
||||
const isValidationRunInProgress = useIsValidationRunInProgress();
|
||||
|
||||
const isLocked = isInPublishFlow || isPublished || isValidationRunInProgress;
|
||||
return isLocked;
|
||||
};
|
||||
@@ -1,9 +1,10 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { Classification } from 'features/nodes/types/common';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useNodeClassification = (nodeId: string): Classification => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const classification = useMemo(() => template.classification, [template]);
|
||||
return classification;
|
||||
};
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useNodeHasImageOutput = (nodeId: string): boolean => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const hasImageOutput = useMemo(
|
||||
() =>
|
||||
some(
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeType } from 'features/nodes/hooks/useNodeType';
|
||||
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const type = useNodeType(nodeId);
|
||||
const version = useNodeVersion(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const needsUpdate = useMemo(() => {
|
||||
if (type !== template.type) {
|
||||
return true;
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
|
||||
export const useNodeTemplateOrThrow = (nodeId: string): InvocationTemplate => {
|
||||
const templates = useStore($templates);
|
||||
const type = useNodeType(nodeId);
|
||||
const template = useMemo(() => {
|
||||
@@ -15,10 +15,3 @@ export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
|
||||
}, [templates, type]);
|
||||
return template;
|
||||
};
|
||||
|
||||
export const useNodeTemplateSafe = (nodeId: string): InvocationTemplate | null => {
|
||||
const templates = useStore($templates);
|
||||
const type = useNodeType(nodeId);
|
||||
const template = useMemo(() => templates[type] ?? null, [templates, type]);
|
||||
return template;
|
||||
};
|
||||
@@ -0,0 +1,12 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useNodeType } from 'features/nodes/hooks/useNodeType';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateSafe = (nodeId: string): InvocationTemplate | null => {
|
||||
const templates = useStore($templates);
|
||||
const type = useNodeType(nodeId);
|
||||
const template = useMemo(() => templates[type] ?? null, [templates, type]);
|
||||
return template;
|
||||
};
|
||||
@@ -0,0 +1,25 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useNodeTemplateTitleOrThrow = (nodeId: string): string => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), 'Node not found');
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
return template.title;
|
||||
}),
|
||||
[nodeId, templates]
|
||||
);
|
||||
const title = useAppSelector(selector);
|
||||
return title;
|
||||
};
|
||||
@@ -6,7 +6,7 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string): string | null => {
|
||||
export const useNodeTemplateTitleSafe = (nodeId: string): string | null => {
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
@@ -0,0 +1,21 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useNodeUserTitleOrThrow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), 'Node not found');
|
||||
return node.data.label;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const title = useAppSelector(selector);
|
||||
return title;
|
||||
};
|
||||
@@ -3,16 +3,16 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
export const useNodeUserTitleSafe = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
return node?.data.label;
|
||||
return node?.data.label ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const label = useAppSelector(selector);
|
||||
return label;
|
||||
const title = useAppSelector(selector);
|
||||
return title;
|
||||
};
|
||||
@@ -1,10 +1,11 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useOutputFieldNames = (nodeId: string): string[] => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const fieldNames = useMemo(() => getSortedFilteredFieldNames(map(template.outputs)), [template.outputs]);
|
||||
return fieldNames;
|
||||
};
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
|
||||
|
||||
export const useOutputFieldTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const template = useNodeTemplateOrThrow(nodeId);
|
||||
const fieldTemplate = useMemo(() => {
|
||||
const _fieldTemplate = template.outputs[fieldName];
|
||||
assert(_fieldTemplate, `Template for output field ${fieldName} not found`);
|
||||
|
||||
@@ -4,14 +4,14 @@ import { useCallback } from 'react';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
export const useZoomToNode = () => {
|
||||
const zoomToNode = useCallback((nodeId: string) => {
|
||||
export const useZoomToNode = (nodeId: string) => {
|
||||
const zoomToNode = useCallback(() => {
|
||||
const flow = $flow.get();
|
||||
if (!flow) {
|
||||
log.warn('No flow instance found, cannot zoom to node');
|
||||
return;
|
||||
}
|
||||
flow.fitView({ duration: 300, maxZoom: 1.5, nodes: [{ id: nodeId }] });
|
||||
}, []);
|
||||
}, [nodeId]);
|
||||
return zoomToNode;
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { RootState } from 'app/store/store';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isGeneratorNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectNode = (nodesSlice: NodesState, nodeId: string): AnyNode => {
|
||||
@@ -81,3 +81,7 @@ export const selectMayRedo = createSelector(
|
||||
(state: RootState) => state.nodes,
|
||||
(nodes) => nodes.future.length > 0
|
||||
);
|
||||
|
||||
export const selectHasBatchOrGeneratorNodes = createSelector(selectNodes, (nodes) =>
|
||||
nodes.filter(isInvocationNode).some((node) => isBatchNode(node) || isGeneratorNode(node))
|
||||
);
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { WorkflowCategory } from 'features/nodes/types/workflow';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { SQLiteDirection, WorkflowRecordOrderBy } from 'services/api/types';
|
||||
|
||||
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults';
|
||||
export type WorkflowLibraryView = 'recent' | 'yours' | 'private' | 'shared' | 'defaults' | 'published';
|
||||
|
||||
type WorkflowLibraryState = {
|
||||
view: WorkflowLibraryView;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user