mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 17:48:13 -05:00
Compare commits
37 Commits
psychedeli
...
feat/contr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51e44d6640 | ||
|
|
44a80a4929 | ||
|
|
e625c44c73 | ||
|
|
c764e34883 | ||
|
|
5d31df0cb7 | ||
|
|
bd63454e51 | ||
|
|
062df07de2 | ||
|
|
0fc14afcf0 | ||
|
|
4a0a1c30db | ||
|
|
3432fd72f8 | ||
|
|
05a43c41f9 | ||
|
|
bb48617101 | ||
|
|
aa2f68f608 | ||
|
|
fbccce7573 | ||
|
|
a35087ee6e | ||
|
|
03e463dc89 | ||
|
|
d467e138a4 | ||
|
|
ba4aaea45b | ||
|
|
53eb23b8b6 | ||
|
|
8b969053e7 | ||
|
|
98a076260b | ||
|
|
b3f4f28d76 | ||
|
|
acee4bd282 | ||
|
|
50d254fdb7 | ||
|
|
0cfc1c5f86 | ||
|
|
1419977e89 | ||
|
|
a953944894 | ||
|
|
a4cdaa245e | ||
|
|
105a4234b0 | ||
|
|
34c563060f | ||
|
|
d45c47db81 | ||
|
|
c771a4027f | ||
|
|
3fd27b1aa9 | ||
|
|
d59e534cad | ||
|
|
0c97a1e7e7 | ||
|
|
c8b109f52e | ||
|
|
399ebe443e |
@@ -332,6 +332,7 @@ class InvokeAiInstance:
|
||||
Configure the InvokeAI runtime directory
|
||||
"""
|
||||
|
||||
auto_install = False
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1, len(sys.argv)):
|
||||
@@ -340,13 +341,17 @@ class InvokeAiInstance:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
new_argv.append(el)
|
||||
auto_install = True
|
||||
sys.argv = new_argv
|
||||
|
||||
import messages
|
||||
import requests # to catch download exceptions
|
||||
from messages import introduction
|
||||
|
||||
introduction()
|
||||
auto_install = auto_install or messages.user_wants_auto_configuration()
|
||||
if auto_install:
|
||||
sys.argv.append("--yes")
|
||||
else:
|
||||
messages.introduction()
|
||||
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit import HTML, prompt
|
||||
from prompt_toolkit.completion import PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
@@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
dest_confirmed = Confirm.ask(
|
||||
":stop_sign: Are you sure you want to (re)install in this location?",
|
||||
":stop_sign: (re)install in this location?",
|
||||
default=False,
|
||||
)
|
||||
else:
|
||||
print(f"InvokeAI will be installed in {dest}")
|
||||
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
||||
dest_confirmed = Confirm.ask("Use this location?", default=True)
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def user_wants_auto_configuration() -> bool:
|
||||
"""Prompt the user to choose between manual and auto configuration."""
|
||||
console.rule("InvokeAI Configuration Section")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
|
||||
"",
|
||||
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
|
||||
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
|
||||
"",
|
||||
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
|
||||
]
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = (
|
||||
prompt(
|
||||
HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
|
||||
),
|
||||
)
|
||||
or "a"
|
||||
)
|
||||
return choice.lower().startswith("a")
|
||||
|
||||
|
||||
def dest_path(dest=None) -> Path:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
||||
@@ -91,6 +91,9 @@ class FieldDescriptions:
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
|
||||
|
||||
class Input(str, Enum):
|
||||
|
||||
@@ -6,9 +6,11 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
ContentShuffleDetector,
|
||||
DWposeDetector,
|
||||
HEDdetector,
|
||||
LeresDetector,
|
||||
LineartAnimeDetector,
|
||||
@@ -125,7 +127,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||
"image_processor", title="Base Image Processorwp", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
@@ -589,3 +591,29 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
|
||||
@invocation(
|
||||
"dwpose_image_processor",
|
||||
title="DWPose Processor",
|
||||
tags=["controlnet", "dwpose", "pose"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DWPoseImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies DW-Pose processing to image"""
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# for now, executing DWPose processing on CPU only
|
||||
device = "cpu"
|
||||
dwpose_processor = DWposeDetector(device=device)
|
||||
processed_image = dwpose_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
@@ -65,13 +65,27 @@ class DivideInvocation(BaseInvocation):
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
||||
|
||||
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
||||
class RandomFloatInvocation(BaseInvocation):
|
||||
"""Outputs a single random float"""
|
||||
|
||||
low: float = InputField(default=0.0, description=FieldDescriptions.inclusive_low)
|
||||
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
|
||||
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
random_float = np.random.uniform(self.low, self.high)
|
||||
rounded_float = round(random_float, self.decimals)
|
||||
return FloatOutput(value=rounded_float)
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_to_int",
|
||||
title="Float To Integer",
|
||||
|
||||
@@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from time import time
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
@@ -59,7 +58,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
# If the cache is full, we need to remove the least used
|
||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json())
|
||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
number_to_delete = min(number_to_delete, len(self._cache))
|
||||
|
||||
@@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
@@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
@@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
|
||||
# So we adjust everthing down a bit.
|
||||
return (
|
||||
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ from diffusers.models import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
||||
|
||||
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||
from .resampler import Resampler
|
||||
|
||||
@@ -87,6 +89,20 @@ class IPAdapter:
|
||||
if self._attn_processors is not None:
|
||||
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def calc_size(self):
|
||||
if self._state_dict is not None:
|
||||
image_proj_size = sum(
|
||||
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
|
||||
)
|
||||
ip_adapter_size = sum(
|
||||
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
|
||||
)
|
||||
return image_proj_size + ip_adapter_size
|
||||
else:
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
|
||||
torch.nn.ModuleList(self._attn_processors.values())
|
||||
)
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from invokeai.backend.model_management.models.base import (
|
||||
ModelConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
calc_model_size_by_fs,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ class IPAdapterModel(ModelBase):
|
||||
assert model_type == ModelType.IPAdapter
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
@@ -63,10 +64,13 @@ class IPAdapterModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
return build_ip_adapter(
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
|
||||
@@ -58,6 +58,7 @@
|
||||
"githubLabel": "Github",
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"imagePrompt": "Image Prompt",
|
||||
"imageFailedToLoad": "Unable to Load Image",
|
||||
"img2img": "Image To Image",
|
||||
"langArabic": "العربية",
|
||||
"langBrPortuguese": "Português do Brasil",
|
||||
@@ -79,7 +80,7 @@
|
||||
"lightMode": "Light Mode",
|
||||
"linear": "Linear",
|
||||
"load": "Load",
|
||||
"loading": "Loading $t({{noun}})...",
|
||||
"loading": "Loading",
|
||||
"loadingInvokeAI": "Loading Invoke AI",
|
||||
"learnMore": "Learn More",
|
||||
"modelManager": "Model Manager",
|
||||
@@ -716,6 +717,7 @@
|
||||
"cannotConnectInputToInput": "Cannot connect input to input",
|
||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"clipField": "Clip",
|
||||
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
||||
"collection": "Collection",
|
||||
@@ -1442,6 +1444,8 @@
|
||||
"showCanvasDebugInfo": "Show Additional Canvas Info",
|
||||
"showGrid": "Show Grid",
|
||||
"showHide": "Show/Hide",
|
||||
"showResultsOn": "Show Results (On)",
|
||||
"showResultsOff": "Show Results (Off)",
|
||||
"showIntermediates": "Show Intermediates",
|
||||
"snapToGrid": "Snap to Grid",
|
||||
"undo": "Undo"
|
||||
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { theme as invokeAITheme } from 'theme/theme';
|
||||
import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
|
||||
|
||||
import '@fontsource-variable/inter';
|
||||
import { MantineProvider } from '@mantine/core';
|
||||
@@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
|
||||
return (
|
||||
<MantineProvider theme={mantineTheme}>
|
||||
<ChakraProvider theme={theme} colorModeManager={manager}>
|
||||
<ChakraProvider
|
||||
theme={theme}
|
||||
colorModeManager={manager}
|
||||
toastOptions={TOAST_OPTIONS}
|
||||
>
|
||||
{children}
|
||||
</ChakraProvider>
|
||||
</MantineProvider>
|
||||
|
||||
@@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import {
|
||||
addSessionCanceledFulfilledListener,
|
||||
addSessionCanceledPendingListener,
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addSessionCreatedFulfilledListener,
|
||||
addSessionCreatedPendingListener,
|
||||
addSessionCreatedRejectedListener,
|
||||
} from './listeners/sessionCreated';
|
||||
import {
|
||||
addSessionInvokedFulfilledListener,
|
||||
addSessionInvokedPendingListener,
|
||||
addSessionInvokedRejectedListener,
|
||||
} from './listeners/sessionInvoked';
|
||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||
@@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener();
|
||||
addEnqueueRequestedNodes();
|
||||
addEnqueueRequestedLinear();
|
||||
addAnyEnqueuedListener();
|
||||
addBatchEnqueuedListener();
|
||||
|
||||
// Canvas actions
|
||||
addCanvasSavedToGalleryListener();
|
||||
@@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener();
|
||||
addInvocationRetrievalErrorEventListener();
|
||||
addSocketQueueItemStatusChangedEventListener();
|
||||
|
||||
// Session Created
|
||||
addSessionCreatedPendingListener();
|
||||
addSessionCreatedFulfilledListener();
|
||||
addSessionCreatedRejectedListener();
|
||||
|
||||
// Session Invoked
|
||||
addSessionInvokedPendingListener();
|
||||
addSessionInvokedFulfilledListener();
|
||||
addSessionInvokedRejectedListener();
|
||||
|
||||
// Session Canceled
|
||||
addSessionCanceledPendingListener();
|
||||
addSessionCanceledFulfilledListener();
|
||||
addSessionCanceledRejectedListener();
|
||||
|
||||
// ControlNet
|
||||
addControlNetImageProcessedListener();
|
||||
addControlNetAutoProcessListener();
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
import { createStandaloneToast } from '@chakra-ui/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { t } from 'i18next';
|
||||
import { get, truncate, upperFirst } from 'lodash-es';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { TOAST_OPTIONS, theme } from 'theme/theme';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const { toast } = createStandaloneToast({
|
||||
theme: theme,
|
||||
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||
});
|
||||
|
||||
export const addBatchEnqueuedListener = () => {
|
||||
// success
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: async (action) => {
|
||||
const response = action.payload;
|
||||
const arg = action.meta.arg.originalArgs;
|
||||
logger('queue').debug(
|
||||
{ enqueueResult: parseify(response) },
|
||||
'Batch enqueued'
|
||||
);
|
||||
|
||||
if (!toast.isActive('batch-queued')) {
|
||||
toast({
|
||||
id: 'batch-queued',
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: response.enqueued,
|
||||
direction: arg.prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
duration: 1000,
|
||||
status: 'success',
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// error
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
|
||||
effect: async (action) => {
|
||||
const response = action.payload;
|
||||
const arg = action.meta.arg.originalArgs;
|
||||
|
||||
if (!response) {
|
||||
toast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
description: 'Unknown Error',
|
||||
});
|
||||
logger('queue').error(
|
||||
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zPydanticValidationError.safeParse(response);
|
||||
if (result.success) {
|
||||
result.data.data.detail.map((e) => {
|
||||
toast({
|
||||
id: 'batch-failed-to-queue',
|
||||
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
`Path:
|
||||
${e.loc.join('.')}`,
|
||||
{ length: 128 }
|
||||
),
|
||||
});
|
||||
});
|
||||
} else {
|
||||
let detail = 'Unknown Error';
|
||||
if (response.status === 403 && 'body' in response) {
|
||||
detail = get(response, 'body.detail', 'Unknown Error');
|
||||
} else if (response.status === 403 && 'error' in response) {
|
||||
detail = get(response, 'error.detail', 'Unknown Error');
|
||||
}
|
||||
toast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
description: detail,
|
||||
});
|
||||
}
|
||||
logger('queue').error(
|
||||
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -25,7 +25,7 @@ export const addBoardIdSelectedListener = () => {
|
||||
const state = getState();
|
||||
|
||||
const board_id = boardIdSelected.match(action)
|
||||
? action.payload
|
||||
? action.payload.boardId
|
||||
: state.gallery.selectedBoardId;
|
||||
|
||||
const galleryView = galleryViewChanged.match(action)
|
||||
@@ -55,7 +55,12 @@ export const addBoardIdSelectedListener = () => {
|
||||
|
||||
if (boardImagesData) {
|
||||
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||
dispatch(imageSelected(firstImage ?? null));
|
||||
const selectedImage = imagesSelectors.selectById(
|
||||
boardImagesData,
|
||||
action.payload.selectedImageName
|
||||
);
|
||||
|
||||
dispatch(imageSelected(selectedImage || firstImage || null));
|
||||
} else {
|
||||
// board has no images - deselect
|
||||
dispatch(imageSelected(null));
|
||||
|
||||
@@ -3,9 +3,9 @@ import { canvasImageToControlNet } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { startAppListening } from '..';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addCanvasImageToControlNetListener = () => {
|
||||
startAppListening({
|
||||
@@ -16,7 +16,7 @@ export const addCanvasImageToControlNetListener = () => {
|
||||
|
||||
let blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state);
|
||||
blob = await getBaseLayerBlob(state, true);
|
||||
} catch (err) {
|
||||
log.error(String(err));
|
||||
dispatch(
|
||||
@@ -36,10 +36,10 @@ export const addCanvasImageToControlNetListener = () => {
|
||||
file: new File([blob], 'savedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'mask',
|
||||
image_category: 'control',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: true,
|
||||
crop_visible: false,
|
||||
postUploadAction: {
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: t('toast.canvasSentControlnetAssets') },
|
||||
|
||||
@@ -3,9 +3,9 @@ import { canvasMaskToControlNet } from 'features/canvas/store/actions';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { startAppListening } from '..';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addCanvasMaskToControlNetListener = () => {
|
||||
startAppListening({
|
||||
@@ -50,7 +50,7 @@ export const addCanvasMaskToControlNetListener = () => {
|
||||
image_category: 'mask',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: true,
|
||||
crop_visible: false,
|
||||
postUploadAction: {
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: t('toast.maskSentControlnetAssets') },
|
||||
|
||||
@@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
@@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
|
||||
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
||||
|
||||
// Prep the canvas staging area if it is not yet initialized
|
||||
@@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
|
||||
// Associate the session with the canvas session ID
|
||||
dispatch(canvasBatchIdAdded(batchId));
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
@@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => {
|
||||
(action.payload.tabName === 'txt2img' ||
|
||||
action.payload.tabName === 'img2img'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const log = logger('queue');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
const { prepend } = action.payload;
|
||||
@@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => {
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { BatchConfig } from 'services/api/types';
|
||||
import { startAppListening } from '..';
|
||||
@@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => {
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const log = logger('queue');
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
const graph = buildNodesGraph(state.nodes);
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
@@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => {
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
'Failed to enqueue batch'
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionCanceledPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCanceledFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
log.debug({ session_id }, `Session canceled (${session_id})`);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCanceledRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem canceling session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,45 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionCreatedPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCreatedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const session = action.payload;
|
||||
log.debug(
|
||||
{ session: parseify(session) },
|
||||
`Session created (${session.id})`
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCreatedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
if (action.payload) {
|
||||
const { error, status } = action.payload;
|
||||
const graph = parseify(action.meta.arg);
|
||||
log.error(
|
||||
{ graph, status, error: serializeError(error) },
|
||||
`Problem creating session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,44 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionInvoked } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionInvokedPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionInvokedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
log.debug({ session_id }, `Session invoked (${session_id})`);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionInvokedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem invoking session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -81,9 +81,32 @@ export const addInvocationCompleteEventListener = () => {
|
||||
|
||||
// If auto-switch is enabled, select the new image
|
||||
if (shouldAutoSwitch) {
|
||||
// if auto-add is enabled, switch the board as the image comes in
|
||||
dispatch(galleryViewChanged('images'));
|
||||
dispatch(boardIdSelected(imageDTO.board_id ?? 'none'));
|
||||
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
|
||||
if (gallery.galleryView !== 'images') {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
|
||||
if (
|
||||
imageDTO.board_id &&
|
||||
imageDTO.board_id !== gallery.selectedBoardId
|
||||
) {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: imageDTO.board_id,
|
||||
selectedImageName: imageDTO.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: 'none',
|
||||
selectedImageName: imageDTO.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(imageSelected(imageDTO));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
queueApi.util.invalidateTags([
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
{ type: 'SessionQueueItemDTO', id: item_id },
|
||||
{ type: 'BatchStatus', id: queue_batch_id },
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { AppThunkDispatch } from 'app/store/store';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { BatchConfig } from 'services/api/types';
|
||||
|
||||
export const enqueueBatch = async (
|
||||
batchConfig: BatchConfig,
|
||||
dispatch: AppThunkDispatch
|
||||
) => {
|
||||
const log = logger('session');
|
||||
const { prepend } = batchConfig;
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'resumeProcessor',
|
||||
})
|
||||
);
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
@@ -1,18 +1,9 @@
|
||||
import { chakra, ChakraProps } from '@chakra-ui/react';
|
||||
import { Box, ChakraProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { RgbaColorPicker } from 'react-colorful';
|
||||
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
|
||||
|
||||
type IAIColorPickerProps = Omit<ColorPickerBaseProps<RgbaColor>, 'color'> &
|
||||
ChakraProps & {
|
||||
pickerColor: RgbaColor;
|
||||
styleClass?: string;
|
||||
};
|
||||
|
||||
const ChakraRgbaColorPicker = chakra(RgbaColorPicker, {
|
||||
baseStyle: { paddingInline: 4 },
|
||||
shouldForwardProp: (prop) => !['pickerColor'].includes(prop),
|
||||
});
|
||||
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
|
||||
|
||||
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
width: 6,
|
||||
@@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
borderColor: 'base.100',
|
||||
};
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
const { styleClass = '', ...rest } = props;
|
||||
const sx = {
|
||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||
};
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
return (
|
||||
<ChakraRgbaColorPicker
|
||||
sx={{
|
||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||
}}
|
||||
className={styleClass}
|
||||
{...rest}
|
||||
/>
|
||||
<Box sx={sx}>
|
||||
<RgbaColorPicker {...props} />
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -139,6 +139,11 @@ const IAICanvas = () => {
|
||||
const { handleDragStart, handleDragMove, handleDragEnd } =
|
||||
useCanvasDragMove();
|
||||
|
||||
const handleContextMenu = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => e.evt.preventDefault(),
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) {
|
||||
return;
|
||||
@@ -205,9 +210,7 @@ const IAICanvas = () => {
|
||||
onDragStart={handleDragStart}
|
||||
onDragMove={handleDragMove}
|
||||
onDragEnd={handleDragEnd}
|
||||
onContextMenu={(e: KonvaEventObject<MouseEvent>) =>
|
||||
e.evt.preventDefault()
|
||||
}
|
||||
onContextMenu={handleContextMenu}
|
||||
onWheel={handleWheel}
|
||||
draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox}
|
||||
>
|
||||
@@ -223,7 +226,11 @@ const IAICanvas = () => {
|
||||
>
|
||||
<IAICanvasObjectRenderer />
|
||||
</Layer>
|
||||
<Layer id="mask" visible={isMaskEnabled} listening={false}>
|
||||
<Layer
|
||||
id="mask"
|
||||
visible={isMaskEnabled && !isStaging}
|
||||
listening={false}
|
||||
>
|
||||
<IAICanvasMaskLines visible={true} listening={false} />
|
||||
<IAICanvasMaskCompositer listening={false} />
|
||||
</Layer>
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { Image, Rect } from 'react-konva';
|
||||
import { memo } from 'react';
|
||||
import { Image } from 'react-konva';
|
||||
import { $authToken } from 'services/api/client';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import useImage from 'use-image';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
import { $authToken } from 'services/api/client';
|
||||
import { memo } from 'react';
|
||||
import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback';
|
||||
|
||||
type IAICanvasImageProps = {
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||
const { width, height, x, y, imageName } = props.canvasImage;
|
||||
const { x, y, imageName } = props.canvasImage;
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
||||
imageName ?? skipToken
|
||||
);
|
||||
const [image] = useImage(
|
||||
const [image, status] = useImage(
|
||||
imageDTO?.image_url ?? '',
|
||||
$authToken.get() ? 'use-credentials' : 'anonymous'
|
||||
);
|
||||
|
||||
if (isError) {
|
||||
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
||||
if (isError || status === 'failed') {
|
||||
return <IAICanvasImageErrorFallback canvasImage={props.canvasImage} />;
|
||||
}
|
||||
|
||||
return <Image x={x} y={y} image={image} listening={false} />;
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useColorModeValue, useToken } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Group, Rect, Text } from 'react-konva';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
|
||||
type IAICanvasImageErrorFallbackProps = {
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImageErrorFallback = ({
|
||||
canvasImage,
|
||||
}: IAICanvasImageErrorFallbackProps) => {
|
||||
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
|
||||
useToken('colors', ['base.400', 'base.500', 'base.700', 'base.900']);
|
||||
const errorColor = useColorModeValue(errorColorLight, errorColorDark);
|
||||
const fontColor = useColorModeValue(fontColorLight, fontColorDark);
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Group>
|
||||
<Rect
|
||||
x={canvasImage.x}
|
||||
y={canvasImage.y}
|
||||
width={canvasImage.width}
|
||||
height={canvasImage.height}
|
||||
fill={errorColor}
|
||||
/>
|
||||
<Text
|
||||
x={canvasImage.x}
|
||||
y={canvasImage.y}
|
||||
width={canvasImage.width}
|
||||
height={canvasImage.height}
|
||||
align="center"
|
||||
verticalAlign="middle"
|
||||
fontFamily='"Inter Variable", sans-serif'
|
||||
fontSize={canvasImage.width / 16}
|
||||
fontStyle="600"
|
||||
text={t('common.imageFailedToLoad')}
|
||||
fill={fontColor}
|
||||
/>
|
||||
</Group>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICanvasImageErrorFallback);
|
||||
@@ -3,10 +3,9 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { GroupConfig } from 'konva/lib/Group';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { memo } from 'react';
|
||||
import { Group, Rect } from 'react-konva';
|
||||
import IAICanvasImage from './IAICanvasImage';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector],
|
||||
@@ -15,11 +14,11 @@ const selector = createSelector(
|
||||
layerState,
|
||||
shouldShowStagingImage,
|
||||
shouldShowStagingOutline,
|
||||
boundingBoxCoordinates: { x, y },
|
||||
boundingBoxDimensions: { width, height },
|
||||
boundingBoxCoordinates: stageBoundingBoxCoordinates,
|
||||
boundingBoxDimensions: stageBoundingBoxDimensions,
|
||||
} = canvas;
|
||||
|
||||
const { selectedImageIndex, images } = layerState.stagingArea;
|
||||
const { selectedImageIndex, images, boundingBox } = layerState.stagingArea;
|
||||
|
||||
return {
|
||||
currentStagingAreaImage:
|
||||
@@ -30,10 +29,10 @@ const selector = createSelector(
|
||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||
shouldShowStagingImage,
|
||||
shouldShowStagingOutline,
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
x: boundingBox?.x ?? stageBoundingBoxCoordinates.x,
|
||||
y: boundingBox?.y ?? stageBoundingBoxCoordinates.y,
|
||||
width: boundingBox?.width ?? stageBoundingBoxDimensions.width,
|
||||
height: boundingBox?.height ?? stageBoundingBoxDimensions.height,
|
||||
};
|
||||
},
|
||||
{
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -23,8 +24,8 @@ import {
|
||||
FaCheck,
|
||||
FaEye,
|
||||
FaEyeSlash,
|
||||
FaPlus,
|
||||
FaSave,
|
||||
FaTimes,
|
||||
} from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { stagingAreaImageSaved } from '../store/actions';
|
||||
@@ -41,10 +42,10 @@ const selector = createSelector(
|
||||
} = canvas;
|
||||
|
||||
return {
|
||||
currentIndex: selectedImageIndex,
|
||||
total: images.length,
|
||||
currentStagingAreaImage:
|
||||
images.length > 0 ? images[selectedImageIndex] : undefined,
|
||||
isOnFirstImage: selectedImageIndex === 0,
|
||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||
shouldShowStagingImage,
|
||||
shouldShowStagingOutline,
|
||||
};
|
||||
@@ -55,10 +56,10 @@ const selector = createSelector(
|
||||
const IAICanvasStagingAreaToolbar = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
currentStagingAreaImage,
|
||||
shouldShowStagingImage,
|
||||
currentIndex,
|
||||
total,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
@@ -71,39 +72,6 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
dispatch(setShouldShowStagingOutline(false));
|
||||
}, [dispatch]);
|
||||
|
||||
useHotkeys(
|
||||
['left'],
|
||||
() => {
|
||||
handlePrevImage();
|
||||
},
|
||||
{
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
}
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['right'],
|
||||
() => {
|
||||
handleNextImage();
|
||||
},
|
||||
{
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
}
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['enter'],
|
||||
() => {
|
||||
handleAccept();
|
||||
},
|
||||
{
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
}
|
||||
);
|
||||
|
||||
const handlePrevImage = useCallback(
|
||||
() => dispatch(prevStagingAreaImage()),
|
||||
[dispatch]
|
||||
@@ -119,10 +87,45 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useHotkeys(['left'], handlePrevImage, {
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
});
|
||||
|
||||
useHotkeys(['right'], handleNextImage, {
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
});
|
||||
|
||||
useHotkeys(['enter'], () => handleAccept, {
|
||||
enabled: () => true,
|
||||
preventDefault: true,
|
||||
});
|
||||
|
||||
const { data: imageDTO } = useGetImageDTOQuery(
|
||||
currentStagingAreaImage?.imageName ?? skipToken
|
||||
);
|
||||
|
||||
const handleToggleShouldShowStagingImage = useCallback(() => {
|
||||
dispatch(setShouldShowStagingImage(!shouldShowStagingImage));
|
||||
}, [dispatch, shouldShowStagingImage]);
|
||||
|
||||
const handleSaveToGallery = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
stagingAreaImageSaved({
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
const handleDiscardStagingArea = useCallback(() => {
|
||||
dispatch(discardStagedImages());
|
||||
}, [dispatch]);
|
||||
|
||||
if (!currentStagingAreaImage) {
|
||||
return null;
|
||||
}
|
||||
@@ -131,11 +134,12 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
<Flex
|
||||
pos="absolute"
|
||||
bottom={4}
|
||||
gap={2}
|
||||
w="100%"
|
||||
align="center"
|
||||
justify="center"
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
onMouseEnter={handleMouseOver}
|
||||
onMouseLeave={handleMouseOut}
|
||||
>
|
||||
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||
<IAIIconButton
|
||||
@@ -144,16 +148,29 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
icon={<FaArrowLeft />}
|
||||
onClick={handlePrevImage}
|
||||
colorScheme="accent"
|
||||
isDisabled={isOnFirstImage}
|
||||
isDisabled={!shouldShowStagingImage}
|
||||
/>
|
||||
<IAIButton
|
||||
colorScheme="accent"
|
||||
pointerEvents="none"
|
||||
isDisabled={!shouldShowStagingImage}
|
||||
sx={{
|
||||
background: 'base.600',
|
||||
_dark: {
|
||||
background: 'base.800',
|
||||
},
|
||||
}}
|
||||
>{`${currentIndex + 1}/${total}`}</IAIButton>
|
||||
<IAIIconButton
|
||||
tooltip={`${t('unifiedCanvas.next')} (Right)`}
|
||||
aria-label={`${t('unifiedCanvas.next')} (Right)`}
|
||||
icon={<FaArrowRight />}
|
||||
onClick={handleNextImage}
|
||||
colorScheme="accent"
|
||||
isDisabled={isOnLastImage}
|
||||
isDisabled={!shouldShowStagingImage}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||
<IAIIconButton
|
||||
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
|
||||
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
|
||||
@@ -162,13 +179,19 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
colorScheme="accent"
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.showHide')}
|
||||
aria-label={t('unifiedCanvas.showHide')}
|
||||
tooltip={
|
||||
shouldShowStagingImage
|
||||
? t('unifiedCanvas.showResultsOn')
|
||||
: t('unifiedCanvas.showResultsOff')
|
||||
}
|
||||
aria-label={
|
||||
shouldShowStagingImage
|
||||
? t('unifiedCanvas.showResultsOn')
|
||||
: t('unifiedCanvas.showResultsOff')
|
||||
}
|
||||
data-alert={!shouldShowStagingImage}
|
||||
icon={shouldShowStagingImage ? <FaEye /> : <FaEyeSlash />}
|
||||
onClick={() =>
|
||||
dispatch(setShouldShowStagingImage(!shouldShowStagingImage))
|
||||
}
|
||||
onClick={handleToggleShouldShowStagingImage}
|
||||
colorScheme="accent"
|
||||
/>
|
||||
<IAIIconButton
|
||||
@@ -176,24 +199,14 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
aria-label={t('unifiedCanvas.saveToGallery')}
|
||||
isDisabled={!imageDTO || !imageDTO.is_intermediate}
|
||||
icon={<FaSave />}
|
||||
onClick={() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
stagingAreaImageSaved({
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
}}
|
||||
onClick={handleSaveToGallery}
|
||||
colorScheme="accent"
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.discardAll')}
|
||||
aria-label={t('unifiedCanvas.discardAll')}
|
||||
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
|
||||
onClick={() => dispatch(discardStagedImages())}
|
||||
icon={<FaTimes />}
|
||||
onClick={handleDiscardStagingArea}
|
||||
colorScheme="error"
|
||||
fontSize={20}
|
||||
/>
|
||||
|
||||
@@ -213,45 +213,45 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
[scaledStep]
|
||||
);
|
||||
|
||||
const handleStartedTransforming = () => {
|
||||
const handleStartedTransforming = useCallback(() => {
|
||||
dispatch(setIsTransformingBoundingBox(true));
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleEndedTransforming = () => {
|
||||
const handleEndedTransforming = useCallback(() => {
|
||||
dispatch(setIsTransformingBoundingBox(false));
|
||||
dispatch(setIsMovingBoundingBox(false));
|
||||
dispatch(setIsMouseOverBoundingBox(false));
|
||||
setIsMouseOverBoundingBoxOutline(false);
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleStartedMoving = () => {
|
||||
const handleStartedMoving = useCallback(() => {
|
||||
dispatch(setIsMovingBoundingBox(true));
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleEndedModifying = () => {
|
||||
const handleEndedModifying = useCallback(() => {
|
||||
dispatch(setIsTransformingBoundingBox(false));
|
||||
dispatch(setIsMovingBoundingBox(false));
|
||||
dispatch(setIsMouseOverBoundingBox(false));
|
||||
setIsMouseOverBoundingBoxOutline(false);
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleMouseOver = () => {
|
||||
const handleMouseOver = useCallback(() => {
|
||||
setIsMouseOverBoundingBoxOutline(true);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleMouseOut = () => {
|
||||
const handleMouseOut = useCallback(() => {
|
||||
!isTransformingBoundingBox &&
|
||||
!isMovingBoundingBox &&
|
||||
setIsMouseOverBoundingBoxOutline(false);
|
||||
};
|
||||
}, [isMovingBoundingBox, isTransformingBoundingBox]);
|
||||
|
||||
const handleMouseEnterBoundingBox = () => {
|
||||
const handleMouseEnterBoundingBox = useCallback(() => {
|
||||
dispatch(setIsMouseOverBoundingBox(true));
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
const handleMouseLeaveBoundingBox = () => {
|
||||
const handleMouseLeaveBoundingBox = useCallback(() => {
|
||||
dispatch(setIsMouseOverBoundingBox(false));
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Group {...rest}>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => {
|
||||
dispatch(setShouldPreserveMaskedArea(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<IAIColorPicker
|
||||
sx={{ paddingTop: 2, paddingBottom: 2 }}
|
||||
pickerColor={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
|
||||
<IAIColorPicker
|
||||
color={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||
Save Mask
|
||||
</IAIButton>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Box } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => {
|
||||
sliderNumberInputProps={{ max: 500 }}
|
||||
/>
|
||||
</Flex>
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
</ButtonGroup>
|
||||
|
||||
@@ -6,7 +6,7 @@ export const canvasSelector = (state: RootState): CanvasState => state.canvas;
|
||||
|
||||
export const isStagingSelector = createSelector(
|
||||
[stateSelector],
|
||||
({ canvas }) => canvas.layerState.stagingArea.images.length > 0
|
||||
({ canvas }) => canvas.batchIds.length > 0
|
||||
);
|
||||
|
||||
export const initialCanvasImageSelector = (
|
||||
|
||||
@@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
||||
import { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import calculateCoordinates from '../util/calculateCoordinates';
|
||||
import calculateScale from '../util/calculateScale';
|
||||
@@ -187,7 +186,7 @@ export const canvasSlice = createSlice({
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
state.layerState = {
|
||||
...initialLayerState,
|
||||
...cloneDeep(initialLayerState),
|
||||
objects: [
|
||||
{
|
||||
kind: 'image',
|
||||
@@ -201,6 +200,7 @@ export const canvasSlice = createSlice({
|
||||
],
|
||||
};
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
|
||||
const newScale = calculateScale(
|
||||
stageDimensions.width,
|
||||
@@ -350,11 +350,14 @@ export const canvasSlice = createSlice({
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
|
||||
state.layerState.stagingArea = { ...initialLayerState.stagingArea };
|
||||
state.layerState.stagingArea = cloneDeep(
|
||||
cloneDeep(initialLayerState)
|
||||
).stagingArea;
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
addFillRect: (state) => {
|
||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } =
|
||||
@@ -491,8 +494,9 @@ export const canvasSlice = createSlice({
|
||||
resetCanvas: (state) => {
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
state.layerState = initialLayerState;
|
||||
state.layerState = cloneDeep(initialLayerState);
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
},
|
||||
canvasResized: (
|
||||
state,
|
||||
@@ -617,25 +621,22 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||
const length = state.layerState.stagingArea.images.length;
|
||||
const nextIndex = state.layerState.stagingArea.selectedImageIndex + 1;
|
||||
const lastIndex = state.layerState.stagingArea.images.length - 1;
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex = Math.min(
|
||||
currentIndex + 1,
|
||||
length - 1
|
||||
);
|
||||
state.layerState.stagingArea.selectedImageIndex =
|
||||
nextIndex > lastIndex ? 0 : nextIndex;
|
||||
},
|
||||
prevStagingAreaImage: (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||
const prevIndex = state.layerState.stagingArea.selectedImageIndex - 1;
|
||||
const lastIndex = state.layerState.stagingArea.images.length - 1;
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex = Math.max(
|
||||
currentIndex - 1,
|
||||
0
|
||||
);
|
||||
state.layerState.stagingArea.selectedImageIndex =
|
||||
prevIndex < 0 ? lastIndex : prevIndex;
|
||||
},
|
||||
commitStagingAreaImage: (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
@@ -657,13 +658,12 @@ export const canvasSlice = createSlice({
|
||||
...imageToCommit,
|
||||
});
|
||||
}
|
||||
state.layerState.stagingArea = {
|
||||
...initialLayerState.stagingArea,
|
||||
};
|
||||
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
fitBoundingBoxToStage: (state) => {
|
||||
const {
|
||||
@@ -786,11 +786,6 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(sessionCanceled.pending, (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||
}
|
||||
});
|
||||
builder.addCase(setAspectRatio, (state, action) => {
|
||||
const ratio = action.payload;
|
||||
if (ratio) {
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
||||
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||
|
||||
/**
|
||||
* Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave`
|
||||
*/
|
||||
export const getBaseLayerBlob = async (state: RootState) => {
|
||||
export const getBaseLayerBlob = async (
|
||||
state: RootState,
|
||||
alwaysUseBoundingBox: boolean = false
|
||||
) => {
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
|
||||
if (!canvasBaseLayer) {
|
||||
@@ -24,14 +27,15 @@ export const getBaseLayerBlob = async (state: RootState) => {
|
||||
|
||||
const absPos = clonedBaseLayer.getAbsolutePosition();
|
||||
|
||||
const boundingBox = shouldCropToBoundingBoxOnSave
|
||||
? {
|
||||
x: boundingBoxCoordinates.x + absPos.x,
|
||||
y: boundingBoxCoordinates.y + absPos.y,
|
||||
width: boundingBoxDimensions.width,
|
||||
height: boundingBoxDimensions.height,
|
||||
}
|
||||
: clonedBaseLayer.getClientRect();
|
||||
const boundingBox =
|
||||
shouldCropToBoundingBoxOnSave || alwaysUseBoundingBox
|
||||
? {
|
||||
x: boundingBoxCoordinates.x + absPos.x,
|
||||
y: boundingBoxCoordinates.y + absPos.y,
|
||||
width: boundingBoxDimensions.width,
|
||||
height: boundingBoxDimensions.height,
|
||||
}
|
||||
: clonedBaseLayer.getClientRect();
|
||||
|
||||
return konvaNodeToBlob(clonedBaseLayer, boundingBox);
|
||||
};
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { components } from 'services/api/schema';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
@@ -99,6 +98,9 @@ export const controlNetSlice = createSlice({
|
||||
isControlNetEnabledToggled: (state) => {
|
||||
state.isEnabled = !state.isEnabled;
|
||||
},
|
||||
controlNetEnabled: (state) => {
|
||||
state.isEnabled = true;
|
||||
},
|
||||
controlNetAdded: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -112,6 +114,12 @@ export const controlNetSlice = createSlice({
|
||||
controlNetId,
|
||||
};
|
||||
},
|
||||
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
|
||||
const controlNet = action.payload;
|
||||
state.controlNets[controlNet.controlNetId] = {
|
||||
...controlNet,
|
||||
};
|
||||
},
|
||||
controlNetDuplicated: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -418,10 +426,6 @@ export const controlNetSlice = createSlice({
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
||||
builder.addMatcher(isAnySessionRejected, (state) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
||||
builder.addMatcher(
|
||||
imagesApi.endpoints.deleteImage.matchFulfilled,
|
||||
(state, action) => {
|
||||
@@ -444,7 +448,9 @@ export const controlNetSlice = createSlice({
|
||||
|
||||
export const {
|
||||
isControlNetEnabledToggled,
|
||||
controlNetEnabled,
|
||||
controlNetAdded,
|
||||
controlNetRecalled,
|
||||
controlNetDuplicated,
|
||||
controlNetAddedFromImage,
|
||||
controlNetRemoved,
|
||||
|
||||
@@ -93,7 +93,7 @@ const GalleryBoard = ({
|
||||
const [localBoardName, setLocalBoardName] = useState(board_name);
|
||||
|
||||
const handleSelectBoard = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
dispatch(boardIdSelected({ boardId: board_id }));
|
||||
if (autoAssignBoardOnClick) {
|
||||
dispatch(autoAddBoardIdChanged(board_id));
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector);
|
||||
const boardName = useBoardName('none');
|
||||
const handleSelectBoard = useCallback(() => {
|
||||
dispatch(boardIdSelected('none'));
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
if (autoAssignBoardOnClick) {
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ const SystemBoardButton = ({ board_id }: Props) => {
|
||||
const boardName = useBoardName(board_id);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
dispatch(boardIdSelected({ boardId: board_id }));
|
||||
}, [board_id, dispatch]);
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useMemo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||
import {
|
||||
isValidControlNetModel,
|
||||
isValidLoRAModel,
|
||||
} from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
@@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
@@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
|
||||
[recallLoRA]
|
||||
);
|
||||
|
||||
const handleRecallControlNet = useCallback(
|
||||
(controlnet: ControlNetMetadataItem) => {
|
||||
recallControlNet(controlnet);
|
||||
},
|
||||
[recallControlNet]
|
||||
);
|
||||
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) =>
|
||||
isValidControlNetModel(controlnet.control_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.controlnets]);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
);
|
||||
}
|
||||
})}
|
||||
{validControlNets.map((controlnet, index) => (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="ControlNet"
|
||||
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
|
||||
onClick={() => handleRecallControlNet(controlnet)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -35,8 +35,11 @@ export const gallerySlice = createSlice({
|
||||
autoAssignBoardOnClickChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.autoAssignBoardOnClick = action.payload;
|
||||
},
|
||||
boardIdSelected: (state, action: PayloadAction<BoardId>) => {
|
||||
state.selectedBoardId = action.payload;
|
||||
boardIdSelected: (
|
||||
state,
|
||||
action: PayloadAction<{ boardId: BoardId; selectedImageName?: string }>
|
||||
) => {
|
||||
state.selectedBoardId = action.payload.boardId;
|
||||
state.galleryView = 'images';
|
||||
},
|
||||
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
OnConnect,
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgeUpdateFunc,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnInit,
|
||||
@@ -21,6 +22,7 @@ import {
|
||||
OnSelectionChangeFunc,
|
||||
ProOptions,
|
||||
ReactFlow,
|
||||
ReactFlowProps,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
|
||||
@@ -28,6 +30,8 @@ import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeAdded,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
@@ -167,6 +171,63 @@ export const Flow = () => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
// #region Updatable Edges
|
||||
|
||||
/**
|
||||
* Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/
|
||||
* and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/
|
||||
*
|
||||
* - Edges can be dragged from one handle to another.
|
||||
* - If the user drags the edge away from the node and drops it, delete the edge.
|
||||
* - Do not delete the edge if the cursor didn't move (resolves annoying behaviour
|
||||
* where the edge is deleted if you click it accidentally).
|
||||
*/
|
||||
|
||||
// We have a ref for cursor position, but it is the *projected* cursor position.
|
||||
// Easiest to just keep track of the last mouse event for this particular feature
|
||||
const edgeUpdateMouseEvent = useRef<MouseEvent>();
|
||||
|
||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> =
|
||||
useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
// update mouse event
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||
(_oldEdge, newConnection) => {
|
||||
// instead of updating the edge (we deleted it earlier), we instead create
|
||||
// a new one.
|
||||
dispatch(connectionMade(newConnection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> =
|
||||
useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
// Handle the case where user begins a drag but didn't move the cursor -
|
||||
// bc we deleted the edge, we need to add it back
|
||||
if (
|
||||
// ignore touch events
|
||||
!('touches' in e) &&
|
||||
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
|
||||
edgeUpdateMouseEvent.current?.clientY === e.clientY
|
||||
) {
|
||||
dispatch(edgeAdded(edge));
|
||||
}
|
||||
// reset mouse event
|
||||
edgeUpdateMouseEvent.current = undefined;
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
// #endregion
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@@ -196,6 +257,9 @@ export const Flow = () => {
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
onEdgeUpdate={onEdgeUpdate}
|
||||
onEdgeUpdateStart={onEdgeUpdateStart}
|
||||
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
||||
onNodesDelete={onNodesDelete}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
|
||||
@@ -53,13 +53,12 @@ export const useIsValidConnection = () => {
|
||||
}
|
||||
|
||||
if (
|
||||
edges
|
||||
.filter((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
.find((edge) => {
|
||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
edges.find((edge) => {
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
SelectionMode,
|
||||
updateEdge,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
@@ -182,6 +183,16 @@ const nodesSlice = createSlice({
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
edgeUpdated: (
|
||||
state,
|
||||
action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }>
|
||||
) => {
|
||||
const { oldEdge, newConnection } = action.payload;
|
||||
state.edges = updateEdge(oldEdge, newConnection, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
const { nodeId, handleId, handleType } = action.payload;
|
||||
@@ -366,6 +377,7 @@ const nodesSlice = createSlice({
|
||||
target: edge.target,
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -388,6 +400,7 @@ const nodesSlice = createSlice({
|
||||
target: edge.target,
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -400,6 +413,9 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
},
|
||||
edgeDeleted: (state, action: PayloadAction<string>) => {
|
||||
state.edges = state.edges.filter((e) => e.id !== action.payload);
|
||||
},
|
||||
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
||||
const edges = action.payload;
|
||||
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
||||
@@ -890,69 +906,72 @@ const nodesSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
nodesChanged,
|
||||
edgesChanged,
|
||||
nodeAdded,
|
||||
nodesDeleted,
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
addNodePopoverToggled,
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
nodeTemplatesBuilt,
|
||||
nodeEditorReset,
|
||||
imageCollectionFieldValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
edgeUpdated,
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldColorValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
imageCollectionFieldValueChanged,
|
||||
mouseOverFieldChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeAdded,
|
||||
nodeEditorReset,
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeExclusivelySelected,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
edgesDeleted,
|
||||
shouldValidateGraphChanged,
|
||||
shouldAnimateEdgesChanged,
|
||||
nodeOpacityChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldColorEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectedEdgesChanged,
|
||||
workflowNameChanged,
|
||||
workflowDescriptionChanged,
|
||||
workflowTagsChanged,
|
||||
workflowAuthorChanged,
|
||||
workflowNotesChanged,
|
||||
workflowVersionChanged,
|
||||
workflowContactChanged,
|
||||
workflowLoaded,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
nodeTemplatesBuilt,
|
||||
nodeUseCacheChanged,
|
||||
notesNodeValueChanged,
|
||||
selectedAll,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectionCopied,
|
||||
selectionModeChanged,
|
||||
selectionPasted,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
viewportChanged,
|
||||
workflowAuthorChanged,
|
||||
workflowContactChanged,
|
||||
workflowDescriptionChanged,
|
||||
workflowExposedFieldAdded,
|
||||
workflowExposedFieldRemoved,
|
||||
fieldLabelChanged,
|
||||
viewportChanged,
|
||||
mouseOverFieldChanged,
|
||||
selectionCopied,
|
||||
selectionPasted,
|
||||
selectedAll,
|
||||
addNodePopoverOpened,
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverToggled,
|
||||
selectionModeChanged,
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeExclusivelySelected,
|
||||
nodeUseCacheChanged,
|
||||
workflowLoaded,
|
||||
workflowNameChanged,
|
||||
workflowNotesChanged,
|
||||
workflowTagsChanged,
|
||||
workflowVersionChanged,
|
||||
edgeAdded,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
||||
@@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = (
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const target = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetHandle =
|
||||
handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const source = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceHandle =
|
||||
handleType === 'source' ? fieldName : connectionFieldName;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
|
||||
@@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
|
||||
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
@@ -1222,6 +1226,7 @@ export const zInvocationNodeData = z.object({
|
||||
notes: z.string(),
|
||||
embedWorkflow: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean().optional(),
|
||||
version: zSemVer.optional(),
|
||||
});
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ export const addSDXLRefinerToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId?: string,
|
||||
canvasInitImage?: ImageDTO
|
||||
canvasInitImage?: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): void => {
|
||||
const {
|
||||
refinerModel,
|
||||
@@ -257,8 +258,30 @@ export const addSDXLRefinerToGraph = (
|
||||
};
|
||||
}
|
||||
|
||||
graph.edges.push(
|
||||
{
|
||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK
|
||||
] as CreateDenoiseMaskInvocation),
|
||||
mask: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||
field: 'image',
|
||||
@@ -267,18 +290,19 @@ export const addSDXLRefinerToGraph = (
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
}
|
||||
);
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
|
||||
@@ -663,7 +663,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
graph,
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
modelLoaderNodeId,
|
||||
canvasInitImage
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
|
||||
@@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
ControlNetMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@@ -18,9 +22,18 @@ import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
controlNetModelsAdapter,
|
||||
loraModelsAdapter,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetEnabled,
|
||||
controlNetRecalled,
|
||||
controlNetReset,
|
||||
initialControlNet,
|
||||
} from '../../controlNet/store/controlNetSlice';
|
||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
@@ -38,6 +51,7 @@ import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidLoRAModel,
|
||||
isValidControlNetModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
@@ -53,6 +67,11 @@ import {
|
||||
isValidStrength,
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import {
|
||||
CONTROLNET_PROCESSORS,
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||
const { model } = generation;
|
||||
@@ -390,6 +409,121 @@ export const useRecallParameters = () => {
|
||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall ControlNet with toast
|
||||
*/
|
||||
|
||||
const { controlnets } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: (result) => ({
|
||||
controlnets: result.data
|
||||
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
|
||||
: [],
|
||||
}),
|
||||
});
|
||||
|
||||
const prepareControlNetMetadataItem = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
const {
|
||||
image,
|
||||
control_model,
|
||||
control_weight,
|
||||
begin_step_percent,
|
||||
end_step_percent,
|
||||
control_mode,
|
||||
resize_mode,
|
||||
} = controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlnets.find(
|
||||
(c) =>
|
||||
c.base_model === control_model.base_model &&
|
||||
c.model_name === control_model.model_name
|
||||
);
|
||||
|
||||
if (!matchingControlNetModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingControlNetModel?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
controlnet: null,
|
||||
error: 'ControlNet incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
const controlNetId = uuidv4();
|
||||
|
||||
let processorType = initialControlNet.processorType;
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
|
||||
processorType =
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
|
||||
initialControlNet.processorType;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
isEnabled: true,
|
||||
model: matchingControlNetModel,
|
||||
weight:
|
||||
typeof control_weight === 'number'
|
||||
? control_weight
|
||||
: initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent || initialControlNet.endStepPct,
|
||||
controlMode: control_mode || initialControlNet.controlMode,
|
||||
resizeMode: resize_mode || initialControlNet.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode:
|
||||
processorNode.type !== 'none'
|
||||
? processorNode
|
||||
: initialControlNet.processorNode,
|
||||
shouldAutoConfig: true,
|
||||
controlNetId,
|
||||
};
|
||||
|
||||
return { controlnet, error: null };
|
||||
},
|
||||
[controlnets, model?.base_model]
|
||||
);
|
||||
|
||||
const recallControlNet = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
|
||||
|
||||
if (!result.controlnet) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlNetRecalled({
|
||||
...result.controlnet,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(controlNetEnabled());
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[
|
||||
prepareControlNetMetadataItem,
|
||||
dispatch,
|
||||
parameterSetToast,
|
||||
parameterNotSetToast,
|
||||
]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
@@ -428,6 +562,7 @@ export const useRecallParameters = () => {
|
||||
refiner_negative_aesthetic_score,
|
||||
refiner_start,
|
||||
loras,
|
||||
controlnets,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@@ -517,6 +652,15 @@ export const useRecallParameters = () => {
|
||||
}
|
||||
});
|
||||
|
||||
dispatch(controlNetReset());
|
||||
dispatch(controlNetEnabled());
|
||||
controlnets?.forEach((controlnet) => {
|
||||
const result = prepareControlNetMetadataItem(controlnet);
|
||||
if (result.controlnet) {
|
||||
dispatch(controlNetRecalled(result.controlnet));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[
|
||||
@@ -524,6 +668,7 @@ export const useRecallParameters = () => {
|
||||
allParameterSetToast,
|
||||
dispatch,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -542,6 +687,7 @@ export const useRecallParameters = () => {
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { ButtonGroup } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
|
||||
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
|
||||
import StatusStatGroup from './common/StatusStatGroup';
|
||||
@@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem';
|
||||
|
||||
const InvocationCacheStatus = () => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
|
||||
pollingInterval:
|
||||
isConnected &&
|
||||
queueStatus?.processor.is_started &&
|
||||
queueStatus?.queue.pending > 0
|
||||
? 5000
|
||||
: 0,
|
||||
});
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined);
|
||||
|
||||
return (
|
||||
<StatusStatGroup>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -40,7 +41,7 @@ export const useCancelCurrentQueueItem = () => {
|
||||
}, [currentQueueItemId, dispatch, t, trigger]);
|
||||
|
||||
const isDisabled = useMemo(
|
||||
() => !isConnected || !currentQueueItemId,
|
||||
() => !isConnected || isNil(currentQueueItemId),
|
||||
[isConnected, currentQueueItemId]
|
||||
);
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { t } from 'i18next';
|
||||
import { get, startCase, truncate, upperFirst } from 'lodash-es';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import {
|
||||
appSocketConnected,
|
||||
appSocketDisconnected,
|
||||
@@ -20,8 +19,7 @@ import {
|
||||
} from 'services/events/actions';
|
||||
import { calculateStepPercentage } from '../util/calculateStepPercentage';
|
||||
import { makeToast } from '../util/makeToast';
|
||||
import { SystemState, LANGUAGES } from './types';
|
||||
import { zPydanticValidationError } from './zodSchemas';
|
||||
import { LANGUAGES, SystemState } from './types';
|
||||
|
||||
export const initialSystemState: SystemState = {
|
||||
isInitialized: false,
|
||||
@@ -175,50 +173,6 @@ export const systemSlice = createSlice({
|
||||
|
||||
// *** Matchers - must be after all cases ***
|
||||
|
||||
/**
|
||||
* Session Invoked - REJECTED
|
||||
* Session Created - REJECTED
|
||||
*/
|
||||
builder.addMatcher(isAnySessionRejected, (state, action) => {
|
||||
let errorDescription = undefined;
|
||||
const duration = 5000;
|
||||
|
||||
if (action.payload?.status === 422) {
|
||||
const result = zPydanticValidationError.safeParse(action.payload);
|
||||
if (result.success) {
|
||||
result.data.error.detail.map((e) => {
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
`Path:
|
||||
${e.loc.join('.')}`,
|
||||
{ length: 128 }
|
||||
),
|
||||
duration,
|
||||
})
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
} else if (action.payload?.error) {
|
||||
errorDescription = action.payload?.error;
|
||||
}
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
get(errorDescription, 'detail', 'Unknown Error'),
|
||||
{ length: 128 }
|
||||
),
|
||||
duration,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Any server error
|
||||
*/
|
||||
|
||||
@@ -2,7 +2,7 @@ import { z } from 'zod';
|
||||
|
||||
export const zPydanticValidationError = z.object({
|
||||
status: z.literal(422),
|
||||
error: z.object({
|
||||
data: z.object({
|
||||
detail: z.array(
|
||||
z.object({
|
||||
loc: z.array(z.string()),
|
||||
|
||||
@@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
||||
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
||||
|
||||
const InvokeTabs = () => {
|
||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
||||
const activeTabIndex = useAppSelector(activeTabIndexSelector);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const enabledTabs = useAppSelector(enabledTabsSelector);
|
||||
const { t } = useTranslation();
|
||||
@@ -150,13 +150,13 @@ const InvokeTabs = () => {
|
||||
|
||||
const handleTabChange = useCallback(
|
||||
(index: number) => {
|
||||
const activeTabName = tabMap[index];
|
||||
if (!activeTabName) {
|
||||
const tab = enabledTabs[index];
|
||||
if (!tab) {
|
||||
return;
|
||||
}
|
||||
dispatch(setActiveTab(activeTabName));
|
||||
dispatch(setActiveTab(tab.id));
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, enabledTabs]
|
||||
);
|
||||
|
||||
const {
|
||||
@@ -216,8 +216,8 @@ const InvokeTabs = () => {
|
||||
return (
|
||||
<Tabs
|
||||
variant="appTabs"
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
defaultIndex={activeTabIndex}
|
||||
index={activeTabIndex}
|
||||
onChange={handleTabChange}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
|
||||
@@ -95,26 +95,32 @@ export default function UnifiedCanvasColorPicker() {
|
||||
>
|
||||
<Flex minWidth={60} direction="column" gap={4} width="100%">
|
||||
{layer === 'base' && (
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{layer === 'mask' && (
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
|
||||
export const setActiveTabReducer = (
|
||||
state: UIState,
|
||||
newActiveTab: number | InvokeTabName
|
||||
) => {
|
||||
if (typeof newActiveTab === 'number') {
|
||||
state.activeTab = newActiveTab;
|
||||
} else {
|
||||
state.activeTab = tabMap.indexOf(newActiveTab);
|
||||
}
|
||||
};
|
||||
@@ -1,27 +1,23 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
import { isEqual, isString } from 'lodash-es';
|
||||
import { tabMap } from './tabMap';
|
||||
|
||||
export const activeTabNameSelector = createSelector(
|
||||
(state: RootState) => state.ui,
|
||||
(ui: UIState) => tabMap[ui.activeTab] as InvokeTabName,
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
(state: RootState) => state,
|
||||
/**
|
||||
* Previously `activeTab` was an integer, but now it's a string.
|
||||
* Default to first tab in case user has integer.
|
||||
*/
|
||||
({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img')
|
||||
);
|
||||
|
||||
export const activeTabIndexSelector = createSelector(
|
||||
(state: RootState) => state.ui,
|
||||
(ui: UIState) => ui.activeTab,
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
},
|
||||
(state: RootState) => state,
|
||||
({ ui, config }) => {
|
||||
const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t));
|
||||
const idx = tabs.indexOf(ui.activeTab);
|
||||
return idx === -1 ? 0 : idx;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
activeTab: 'txt2img',
|
||||
shouldShowImageDetails: false,
|
||||
shouldUseCanvasBetaLayout: false,
|
||||
shouldShowExistingModelsInSearch: false,
|
||||
@@ -26,7 +25,7 @@ export const uiSlice = createSlice({
|
||||
initialState: initialUIState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||
setActiveTabReducer(state, action.payload);
|
||||
state.activeTab = action.payload;
|
||||
},
|
||||
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowImageDetails = action.payload;
|
||||
@@ -73,7 +72,7 @@ export const uiSlice = createSlice({
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
setActiveTabReducer(state, 'img2img');
|
||||
state.activeTab = 'img2img';
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
|
||||
export type Coordinates = {
|
||||
x: number;
|
||||
@@ -13,7 +14,7 @@ export type Dimensions = {
|
||||
export type Rect = Coordinates & Dimensions;
|
||||
|
||||
export interface UIState {
|
||||
activeTab: number;
|
||||
activeTab: InvokeTabName;
|
||||
shouldShowImageDetails: boolean;
|
||||
shouldUseCanvasBetaLayout: boolean;
|
||||
shouldShowExistingModelsInSearch: boolean;
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { $queueId } from 'features/queue/store/queueNanoStore';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { $client } from 'services/api/client';
|
||||
import { paths } from 'services/api/schema';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
type CreateSessionArg = {
|
||||
graph: NonNullable<
|
||||
paths['/api/v1/sessions/']['post']['requestBody']
|
||||
>['content']['application/json'];
|
||||
};
|
||||
|
||||
type CreateSessionResponse = O.Required<
|
||||
NonNullable<
|
||||
paths['/api/v1/sessions/']['post']['requestBody']
|
||||
>['content']['application/json'],
|
||||
'id'
|
||||
>;
|
||||
|
||||
type CreateSessionThunkConfig = {
|
||||
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.createSession()` thunk
|
||||
*/
|
||||
export const sessionCreated = createAsyncThunk<
|
||||
CreateSessionResponse,
|
||||
CreateSessionArg,
|
||||
CreateSessionThunkConfig
|
||||
>('api/sessionCreated', async (arg, { rejectWithValue }) => {
|
||||
const { graph } = arg;
|
||||
const { POST } = $client.get();
|
||||
const { data, error, response } = await POST('/api/v1/sessions/', {
|
||||
body: graph,
|
||||
params: { query: { queue_id: $queueId.get() } },
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
type InvokedSessionArg = {
|
||||
session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id'];
|
||||
};
|
||||
|
||||
type InvokedSessionResponse =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type InvokedSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: InvokedSessionArg;
|
||||
error: unknown;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
|
||||
const isErrorWithStatus = (error: unknown): error is { status: number } =>
|
||||
isObject(error) && 'status' in error;
|
||||
|
||||
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
|
||||
isObject(error) && 'detail' in error;
|
||||
|
||||
/**
|
||||
* `SessionsService.invokeSession()` thunk
|
||||
*/
|
||||
export const sessionInvoked = createAsyncThunk<
|
||||
InvokedSessionResponse,
|
||||
InvokedSessionArg,
|
||||
InvokedSessionThunkConfig
|
||||
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
|
||||
const { session_id } = arg;
|
||||
const { PUT } = $client.get();
|
||||
const { error, response } = await PUT(
|
||||
'/api/v1/sessions/{session_id}/invoke',
|
||||
{
|
||||
params: {
|
||||
query: { queue_id: $queueId.get(), all: true },
|
||||
path: { session_id },
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (error) {
|
||||
if (isErrorWithStatus(error) && error.status === 403) {
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
error: (error as any).body.detail,
|
||||
});
|
||||
}
|
||||
if (isErrorWithDetail(error) && response.status === 403) {
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
error: error.detail,
|
||||
});
|
||||
}
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
type CancelSessionArg =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path'];
|
||||
|
||||
type CancelSessionResponse =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json'];
|
||||
|
||||
type CancelSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: CancelSessionArg;
|
||||
error: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.cancelSession()` thunk
|
||||
*/
|
||||
export const sessionCanceled = createAsyncThunk<
|
||||
CancelSessionResponse,
|
||||
CancelSessionArg,
|
||||
CancelSessionThunkConfig
|
||||
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
|
||||
const { session_id } = arg;
|
||||
const { DELETE } = $client.get();
|
||||
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
|
||||
params: {
|
||||
path: { session_id },
|
||||
},
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
type ListSessionsArg = {
|
||||
params: paths['/api/v1/sessions/']['get']['parameters'];
|
||||
};
|
||||
|
||||
type ListSessionsResponse =
|
||||
paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListSessionsThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: ListSessionsArg;
|
||||
error: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.listSessions()` thunk
|
||||
*/
|
||||
export const listedSessions = createAsyncThunk<
|
||||
ListSessionsResponse,
|
||||
ListSessionsArg,
|
||||
ListSessionsThunkConfig
|
||||
>('api/listSessions', async (arg, { rejectWithValue }) => {
|
||||
const { params } = arg;
|
||||
const { GET } = $client.get();
|
||||
const { data, error } = await GET('/api/v1/sessions/', {
|
||||
params,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
export const isAnySessionRejected = isAnyOf(
|
||||
sessionCreated.rejected,
|
||||
sessionInvoked.rejected
|
||||
);
|
||||
@@ -1,5 +1,4 @@
|
||||
import { ThemeOverride } from '@chakra-ui/react';
|
||||
|
||||
import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react';
|
||||
import { InvokeAIColors } from './colors/colors';
|
||||
import { accordionTheme } from './components/accordion';
|
||||
import { buttonTheme } from './components/button';
|
||||
@@ -149,3 +148,7 @@ export const theme: ThemeOverride = {
|
||||
Tooltip: tooltipTheme,
|
||||
},
|
||||
};
|
||||
|
||||
export const TOAST_OPTIONS: ToastProviderProps = {
|
||||
defaultOptions: { isClosable: true },
|
||||
};
|
||||
|
||||
@@ -37,7 +37,7 @@ dependencies = [
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel~=2.0.2",
|
||||
"controlnet-aux>=0.0.6",
|
||||
"controlnet-aux>=0.0.7",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.21.0",
|
||||
@@ -52,6 +52,10 @@ dependencies = [
|
||||
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||
"mmcv>=2.0.1",
|
||||
"mmdet>=3.1.0",
|
||||
"mmengine",
|
||||
"mmpose>=1.1.0",
|
||||
"numpy",
|
||||
"npyscreen",
|
||||
"omegaconf",
|
||||
|
||||
Reference in New Issue
Block a user