Merge branch 'main' into refactor/rename-get-logger

This commit is contained in:
psychedelicious
2023-09-05 10:37:53 +10:00
committed by GitHub
232 changed files with 8608 additions and 11111 deletions

View File

@@ -60,7 +60,7 @@ class Config:
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
"""Find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
@@ -70,7 +70,7 @@ class Config:
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
"""Validate a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
@@ -337,33 +337,24 @@ class InvokeAIMetadataParser:
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
# this was more elegant as a case statement, but that's not available in python 3.9
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
return None # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
scheduler_map = dict(
ddim="ddim",
plms="pnmd",
k_lms="lms",
k_dpm_2="kdpm_2",
k_dpm_2_a="kdpm_2_a",
dpmpp_2="dpmpp_2s",
k_dpmpp_2="dpmpp_2m",
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
k_euler="euler",
k_euler_a="euler_a",
k_heun="heun",
)
return scheduler_map.get(old_scheduler)
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
@@ -524,27 +515,27 @@ class MediaImportProcessor:
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
)
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
# This was more elegant as a case statement, but not supported in python 3.9
if input_option == "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
elif input_option == "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
elif input_option == "3":
return "IMPORT"
elif input_option == "4":
return f"IMPORT_{timestamp_string}"
elif input_option == "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""

View File

@@ -7,5 +7,4 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts

View File

@@ -7,8 +7,7 @@ index.html
.yarn/
.yalc/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts
docs/
static/
src/theme/css/overlayscrollbars.css

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1E00-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0304,U+0308,U+0329,U+1E00-1E9F,U+1EF2-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
* OverlayScrollbars
* Version: 2.2.1
*

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-2c171c8f.js"></script>
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
</head>
<body dir="ltr">

View File

@@ -19,7 +19,7 @@
"toggleAutoscroll": "Toggle autoscroll",
"toggleLogViewer": "Toggle Log Viewer",
"showGallery": "Show Gallery",
"showOptionsPanel": "Show Options Panel",
"showOptionsPanel": "Show Side Panel",
"menu": "Menu"
},
"common": {
@@ -52,7 +52,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Node Editor",
"nodes": "Workflow Editor",
"batch": "Batch Manager",
"modelManager": "Model Manager",
"postprocessing": "Post Processing",
@@ -95,7 +95,6 @@
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged",
"pinOptionsPanel": "Pin Options Panel",
"loading": "Loading",
"loadingInvokeAI": "Loading Invoke AI",
"random": "Random",
@@ -116,7 +115,6 @@
"maintainAspectRatio": "Maintain Aspect Ratio",
"autoSwitchNewImages": "Auto-Switch to New Images",
"singleColumnLayout": "Single Column Layout",
"pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded",
"loadMore": "Load More",
"noImagesInGallery": "No Images to Display",
@@ -133,6 +131,7 @@
"generalHotkeys": "General Hotkeys",
"galleryHotkeys": "Gallery Hotkeys",
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
"nodesHotkeys": "Nodes Hotkeys",
"invoke": {
"title": "Invoke",
"desc": "Generate an image"
@@ -332,6 +331,10 @@
"acceptStagingImage": {
"title": "Accept Staging Image",
"desc": "Accept Current Staging Area Image"
},
"addNodes": {
"title": "Add Nodes",
"desc": "Opens the add node menu"
}
},
"modelManager": {
@@ -503,13 +506,15 @@
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
"codeformerFidelity": "Fidelity",
"compositingSettingsHeader": "Compositing Settings",
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamStrength": "Seam Strength",
"seamSteps": "Seam Steps",
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",
@@ -565,10 +570,11 @@
"useSlidersForAll": "Use Sliders For All Options",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
"resetComplete": "Web UI has been reset.",
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
@@ -708,14 +714,16 @@
"ui": {
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
"swapSizes": "Swap Sizes",
"lockRatio": "Lock Ratio"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveGraph": "Save Graph",
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"reloadNodeTemplates": "Reload Node Templates",
"downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",

View File

@@ -74,6 +74,8 @@
"@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"@stevebel/png": "^1.5.1",
"compare-versions": "^6.1.0",
"dateformat": "^5.0.3",
"formik": "^2.4.3",
"framer-motion": "^10.16.1",
@@ -110,6 +112,7 @@
"roarr": "^7.15.1",
"serialize-error": "^11.0.1",
"socket.io-client": "^4.7.2",
"type-fest": "^4.2.0",
"use-debounce": "^9.0.4",
"use-image": "^1.1.1",
"uuid": "^9.0.0",

View File

@@ -506,12 +506,13 @@
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
"codeformerFidelity": "Fidelity",
"compositingSettingsHeader": "Compositing Settings",
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceSteps": "Coherence Pass Steps",
"coherenceStrength": "Coherence Pass Strength",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
@@ -569,6 +570,7 @@
"useSlidersForAll": "Use Sliders For All Options",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@@ -712,11 +714,12 @@
"ui": {
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
"swapSizes": "Swap Sizes",
"lockRatio": "Lock Ratio"
},
"nodes": {
"reloadNodeTemplates": "Reload Node Templates",
"saveWorkflow": "Save Workflow",
"downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",

View File

@@ -14,6 +14,7 @@ import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
@@ -23,13 +24,22 @@ const DEFAULT_CONFIG = {};
interface Props {
config?: PartialAppConfig;
headerComponent?: ReactNode;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
}
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
const App = ({
config = DEFAULT_CONFIG,
headerComponent,
selectedImage,
}: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const { handlePreselectedImage } = usePreselectedImage();
const handleReset = useCallback(() => {
localStorage.clear();
location.reload();
@@ -51,6 +61,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
handlePreselectedImage(selectedImage);
}, [handlePreselectedImage, selectedImage]);
return (
<ErrorBoundary
onReset={handleReset}

View File

@@ -26,6 +26,10 @@ interface Props extends PropsWithChildren {
headerComponent?: ReactNode;
middleware?: Middleware[];
projectId?: string;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
}
const InvokeAIUI = ({
@@ -35,6 +39,7 @@ const InvokeAIUI = ({
headerComponent,
middleware,
projectId,
selectedImage,
}: Props) => {
useEffect(() => {
// configure API client token
@@ -81,7 +86,11 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} headerComponent={headerComponent} />
<App
config={config}
headerComponent={headerComponent}
selectedImage={selectedImage}
/>
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
@@ -41,6 +43,8 @@ import {
addImageUploadedFulfilledListener,
addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
@@ -80,8 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
export const listenerMiddleware = createListenerMiddleware();
@@ -137,6 +140,8 @@ addSessionReadyToInvokeListener();
// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener();
addCanvasImageToControlNetListener();
addCanvasMaskToControlNetListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
@@ -198,6 +203,9 @@ addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();
// Workflows
addWorkflowLoadedListener();
// DND
addImageDroppedListener();

View File

@@ -0,0 +1,58 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
log.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};

View File

@@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: 'Problem Importing Mask',
description: 'Unable to export mask',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};

View File

@@ -1,9 +1,12 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
modelChanged,
setHeight,
setWidth,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
@@ -74,6 +77,22 @@ export const addModelSelectedListener = () => {
}
}
// Update Width / Height / Bounding Box Dimensions on Model Change
if (
state.generation.model?.base_model !== newModel.base_model &&
state.ui.shouldAutoChangeDimensions
) {
if (['sdxl', 'sdxl-refiner'].includes(newModel.base_model)) {
dispatch(setWidth(1024));
dispatch(setHeight(1024));
dispatch(setBoundingBoxDimensions({ width: 1024, height: 1024 }));
} else {
dispatch(setWidth(512));
dispatch(setHeight(512));
dispatch(setBoundingBoxDimensions({ width: 512, height: 512 }));
}
}
dispatch(modelChanged(newModel));
},
});

View File

@@ -0,0 +1,55 @@
import { logger } from 'app/logging/logger';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { startAppListening } from '..';
export const addWorkflowLoadedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
const { workflow: validatedWorkflow, errors } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!errors.length) {
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded with Warnings',
status: 'warning',
})
)
);
errors.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
},
});
};

View File

@@ -6,11 +6,11 @@ import {
configureStore,
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';

View File

@@ -86,8 +86,8 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box
sx={{
p: 2,
pt: 3,
p: 4,
pb: 4,
borderBottomRadius: 'base',
bg: 'base.150',
_dark: {

View File

@@ -1,10 +1,12 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { AnimatePresence, motion } from 'framer-motion';
import {
KeyboardEvent,
ReactNode,
@@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import ImageUploadOverlay from './ImageUploadOverlay';
import { AnimatePresence, motion } from 'framer-motion';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
[stateSelector, activeTabNameSelector],

View File

@@ -0,0 +1,56 @@
import { Box } from '@chakra-ui/react';
import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.light';
}
if (isSelected) {
return 'nodeSelected.light';
}
if (isHovered) {
return 'nodeHovered.light';
}
return undefined;
}, [isHovered, isSelected]);
const shadowDark = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.dark';
}
if (isSelected) {
return 'nodeSelected.dark';
}
if (isHovered) {
return 'nodeHovered.dark';
}
return undefined;
}, [isHovered, isSelected]);
return (
<Box
className="selection-box"
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
opacity: isSelected || isHovered ? 1 : 0.5,
transitionProperty: 'common',
transitionDuration: '0.1s',
pointerEvents: 'none',
shadow,
_dark: {
shadow: shadowDark,
},
}}
/>
);
};
export default memo(SelectionOverlay);

View File

@@ -63,7 +63,11 @@ const selector = createSelector(
return;
}
if (fieldTemplate.required && !field.value && !hasConnection) {
if (
fieldTemplate.required &&
field.value === undefined &&
!hasConnection
) {
reasons.push(
`${node.data.label || nodeTemplate.title} -> ${
field.label || fieldTemplate.title

View File

@@ -1,2 +1,2 @@
export const colorTokenToCssVar = (colorToken: string) =>
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');

View File

@@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
@@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
const { controlNetId } = controlNet;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
const selector = createSelector(
stateSelector,
({ controlNet }) => {
@@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
>
<ParamControlNetModel controlNet={controlNet} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
)}
<IAIIconButton
size="sm"
tooltip="Duplicate"
@@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex

View File

@@ -5,13 +5,21 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import {
@@ -26,11 +34,13 @@ type Props = {
const selector = createSelector(
stateSelector,
({ controlNet }) => {
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
const { autoAddBoardId } = gallery;
return {
pendingControlImages,
autoAddBoardId,
};
},
defaultSelectorOptions
@@ -47,7 +57,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
const activeTabName = useAppSelector(activeTabNameSelector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@@ -59,9 +70,57 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
processedControlImageName ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
return;
}
await changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
}).unwrap();
if (autoAddBoardId !== 'none') {
addToBoard({
imageDTO: processedControlImage,
board_id: autoAddBoardId,
});
} else {
removeFromBoard({ imageDTO: processedControlImage });
}
}, [
processedControlImage,
changeIsIntermediate,
autoAddBoardId,
addToBoard,
removeFromBoard,
]);
const handleSetControlImageToDimensions = useCallback(() => {
if (!controlImage) {
return;
}
if (activeTabName === 'unifiedCanvas') {
dispatch(
setBoundingBoxDimensions({
width: controlImage.width,
height: controlImage.height,
})
);
} else {
dispatch(setWidth(controlImage.width));
dispatch(setHeight(controlImage.height));
}
}, [controlImage, activeTabName, dispatch]);
const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true);
}, []);
@@ -121,13 +180,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
postUploadAction={postUploadAction}
>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
</IAIDndImage>
/>
<Box
sx={{
@@ -148,14 +201,29 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
</IAIDndImage>
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <FaSave size={16} /> : undefined}
tooltip="Save Control Image"
styleOverrides={{ marginTop: 6 }}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <FaRulerVertical size={16} /> : undefined}
tooltip="Set Control Image Dimensions To W/H"
styleOverrides={{ marginTop: 12 }}
/>
</>
{pendingControlImages.includes(controlNetId) && (
<Flex
sx={{

View File

@@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';
type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
};
const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const dispatch = useAppDispatch();
const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);
return (
<Flex
sx={{
gap: 2,
}}
>
<IAIIconButton
size="sm"
icon={<FaImage />}
tooltip="Import Image From Canvas"
aria-label="Import Image From Canvas"
onClick={handleImportImageFromCanvas}
/>
<IAIIconButton
size="sm"
icon={<FaMask />}
tooltip="Import Mask From Canvas"
aria-label="Import Mask From Canvas"
onClick={handleImportMaskFromCanvas}
/>
</Flex>
);
};
export default memo(ControlNetCanvasImageImports);

View File

@@ -4,11 +4,11 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import { memo } from 'react';
const selector = createSelector(
stateSelector,

View File

@@ -15,6 +15,7 @@ import { BoardDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type Props = {
board?: BoardDTO;
@@ -33,12 +34,16 @@ const BoardContextMenu = ({
const selector = useMemo(
() =>
createSelector(stateSelector, ({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id;
const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
}),
createSelector(
stateSelector,
({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id;
const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
},
defaultSelectorOptions
),
[board_id]
);

View File

@@ -9,14 +9,15 @@ import {
MenuButton,
MenuList,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
@@ -37,12 +38,12 @@ import {
FaSeedling,
FaShareAlt,
} from 'react-icons/fa';
import { MdDeviceHub } from 'react-icons/md';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
} from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
@@ -101,22 +102,27 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: imageDTO } = useGetImageDTOQuery(
lastSelectedImage?.image_name ?? skipToken
);
const { currentData: metadataData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg?.image_name ?? skipToken
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const metadata = metadataData?.metadata;
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@@ -153,6 +159,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [imageDTO]);
useHotkeys('w', handleLoadWorkflow, [workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@@ -259,22 +267,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton
isLoading={isLoading}
icon={<MdDeviceHub />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!workflow}
onClick={handleLoadWorkflow}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!metadata?.seed}
onClick={handleUseSeed}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}

View File

@@ -1,5 +1,4 @@
import { Flex, MenuItem, Text } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@@ -26,15 +25,15 @@ import {
FaShare,
FaTrash,
} from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';
import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
import {
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { workflowLoadRequested } from 'features/nodes/store/actions';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
@@ -50,15 +49,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
imageDTO.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const [starImages] = useStarImagesMutation();
@@ -67,8 +66,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
if (!imageDTO) {
return;
@@ -99,6 +96,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@@ -118,7 +122,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
}, [dispatch, imageDTO, t, toaster]);
const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
@@ -169,27 +172,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem
icon={<FaQuoteRight />}
icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
onClickCapture={handleLoadWorkflow}
isDisabled={isLoading || !workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>
<MenuItem
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined
isLoading ||
(metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined)
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<FaSeedling />}
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
isDisabled={isLoading || metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<FaAsterisk />}
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
isDisabled={isLoading || !metadata}
>
{t('parameters.useAll')}
</MenuItem>
@@ -228,20 +238,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
>
{t('gallery.deleteImage')}
</MenuItem>
{metadata?.created_by && (
<Flex
sx={{
padding: '5px 10px',
marginTop: '5px',
}}
>
<Text fontSize="xs" fontWeight="bold">
Created by {metadata?.created_by}
</Text>
</Flex>
)}
</>
);
};
export default memo(SingleSelectionMenuItems);
const SpinnerIcon = () => (
<Flex w="14px" alignItems="center" justifyContent="center">
<Spinner size="xs" />
</Flex>
);

View File

@@ -39,7 +39,7 @@ const ImageGalleryContent = () => {
const { galleryView } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { isOpen: isBoardListOpen, onToggle: onToggleBoardList } =
useDisclosure();
useDisclosure({ defaultIsOpen: true });
const handleClickImages = useCallback(() => {
dispatch(galleryViewChanged('images'));

View File

@@ -8,7 +8,7 @@ import {
ImageDraggableData,
TypesafeDraggableData,
} from 'features/dnd/types';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
import { FaTrash } from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';

View File

@@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa';
import { FaCopy, FaDownload } from 'react-icons/fa';
type Props = {
label: string;
@@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
navigator.clipboard.writeText(dataString);
}, [dataString]);
const handleSave = useCallback(() => {
const handleDownload = useCallback(() => {
const blob = new Blob([dataString]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
@@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
{withDownload && (
<Tooltip label={`Save ${label} JSON`}>
<Tooltip label={`Download ${label} JSON`}>
<IconButton
aria-label={`Save ${label} JSON`}
icon={<FaSave />}
aria-label={`Download ${label} JSON`}
icon={<FaDownload />}
variant="ghost"
opacity={0.7}
onClick={handleSave}
onClick={handleDownload}
/>
</Tooltip>
)}

View File

@@ -1,10 +1,10 @@
import { CoreMetadata } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/types';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
metadata?: CoreMetadata;
};
const ImageMetadataActions = (props: Props) => {
@@ -94,20 +94,22 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.model !== undefined &&
metadata.model !== null &&
metadata.model.model_name && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<ImageMetadataItem
label="Width"
@@ -150,7 +152,7 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
<ImageMetadataItem
label="CFG scale"
value={metadata.cfg_scale}

View File

@@ -9,14 +9,12 @@ import {
Tabs,
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import ImageMetadataActions from './ImageMetadataActions';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
type ImageMetadataViewerProps = {
image: ImageDTO;
@@ -29,18 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false));
// });
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
image.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
});
return (
<Flex
@@ -71,17 +63,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
<Tab>Core Metadata</Tab>
<Tab>Metadata</Tab>
<Tab>Image Details</Tab>
<Tab>Graph</Tab>
<Tab>Workflow</Tab>
</TabList>
<TabPanels>
<TabPanel>
{metadata ? (
<DataViewer data={metadata} label="Core Metadata" />
<DataViewer data={metadata} label="Metadata" />
) : (
<IAINoContentFallback label="No core metadata found" />
<IAINoContentFallback label="No metadata found" />
)}
</TabPanel>
<TabPanel>
@@ -92,10 +84,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
)}
</TabPanel>
<TabPanel>
{graph ? (
<DataViewer data={graph} label="Graph" />
{workflow ? (
<DataViewer data={workflow} label="Workflow" />
) : (
<IAINoContentFallback label="No graph found" />
<IAINoContentFallback label="No workflow found" />
)}
</TabPanel>
</TabPanels>

View File

@@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
@@ -13,6 +14,7 @@ import {
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnInit,
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
@@ -147,6 +149,11 @@ export const Flow = () => {
dispatch(contextMenusClosed());
}, [dispatch]);
const onInit: OnInit = useCallback((flow) => {
$flow.set(flow);
flow.fitView();
}, []);
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
@@ -170,6 +177,7 @@ export const Flow = () => {
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onInit={onInit}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}

View File

@@ -1,13 +1,15 @@
import { Flex, Image, Text } from '@chakra-ui/react';
import { useState, PropsWithChildren, memo } from 'react';
import { useSelector } from 'react-redux';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { Flex, Image, Text } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { PropsWithChildren, memo } from 'react';
import { useSelector } from 'react-redux';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';
import { stateSelector } from 'app/store/store';
const selector = createSelector(stateSelector, ({ system, gallery }) => {
const imageDTO = gallery.selection[gallery.selection.length - 1];
@@ -54,44 +56,90 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
flexDirection: 'column',
}}
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
const [isHovering, setIsHovering] = useState(false);
const handleMouseEnter = () => {
setIsHovering(true);
};
const handleMouseLeave = () => {
setIsHovering(false);
};
return (
<NodeWrapper
nodeId={props.nodeProps.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex
layerStyle="nodeHeader"
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
className={DRAG_HANDLE_CLASSNAME}
sx={{
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
h: 8,
position: 'relative',
flexDirection: 'column',
}}
>
<Text
<Flex
layerStyle="nodeHeader"
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
_dark: { color: 'base.200' },
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
h: 8,
}}
>
Current Image
</Text>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
_dark: { color: 'base.200' },
}}
>
Current Image
</Text>
</Flex>
<Flex
layerStyle="nodeBody"
sx={{
w: 'full',
h: 'full',
borderBottomRadius: 'base',
p: 2,
}}
>
{props.children}
{isHovering && (
<motion.div
key="nextPrevButtons"
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 40,
left: -2,
right: -2,
bottom: 0,
pointerEvents: 'none',
}}
>
<NextPrevImageButtons />
</motion.div>
)}
</Flex>
</Flex>
<Flex
layerStyle="nodeBody"
sx={{ w: 'full', h: 'full', borderBottomRadius: 'base', p: 2 }}
>
{props.children}
</Flex>
</Flex>
</NodeWrapper>
);
</NodeWrapper>
);
};

View File

@@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const embedWorkflow = useEmbedWorkflow(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeEmbedWorkflowChanged({
nodeId,
embedWorkflow: e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={embedWorkflow}
/>
</FormControl>
);
};
export default memo(EmbedWorkflowCheckbox);

View File

@@ -41,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
flexDirection: 'column',
w: 'full',
h: 'full',
py: 1,
py: 2,
gap: 1,
borderBottomRadius: withFooter ? 0 : 'base',
}}

View File

@@ -1,16 +1,8 @@
import {
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { Flex } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
type Props = {
nodeId: string;
@@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
px: 2,
py: 0,
h: 6,
justifyContent: 'space-between',
}}
>
<Spacer />
<SaveImageCheckbox nodeId={nodeId} />
<EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveToGalleryCheckbox nodeId={nodeId} />
</Flex>
);
};
export default memo(InvocationNodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@@ -1,7 +1,5 @@
import {
Flex,
FormControl,
FormLabel,
Icon,
Modal,
ModalBody,
@@ -14,16 +12,16 @@ import {
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useMemo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import NotesTextarea from './NotesTextarea';
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
interface Props {
nodeId: string;
@@ -33,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const label = useNodeLabel(nodeId);
const title = useNodeTemplateTitle(nodeId);
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
return (
<>
@@ -54,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
>
<Icon
as={FaInfoCircle}
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
sx={{
boxSize: 4,
w: 8,
color: doVersionsMatch ? 'base.400' : 'error.400',
}}
/>
</Flex>
</Tooltip>
@@ -80,45 +83,78 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
}
if (!data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return 'Unknown Node';
}, [data, nodeTemplate]);
const versionComponent = useMemo(() => {
if (!isInvocationNodeData(data) || !nodeTemplate) {
return null;
}
if (!data.version) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version unknown
</Text>
);
}
if (!nodeTemplate.version) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (unknown template)
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '<')) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (update node)
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '>')) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (update app)
</Text>
);
}
return <Text as="span">Version {data.version}</Text>;
}, [data, nodeTemplate]);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text as="span" sx={{ fontWeight: 600 }}>
{title}
</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description}
</Text>
{versionComponent}
{data?.notes && <Text>{data.notes}</Text>}
</Flex>
);
});
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@@ -0,0 +1,33 @@
import { FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
};
export default memo(NotesTextarea);

View File

@@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const isIntermediate = useIsIntermediate(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeIsIntermediateChanged({
nodeId,
isIntermediate: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={!isIntermediate}
/>
</FormControl>
);
};
export default memo(SaveToGalleryCheckbox);

View File

@@ -0,0 +1,167 @@
import {
Editable,
EditableInput,
EditablePreview,
Flex,
Tooltip,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
import FieldTooltipContent from './FieldTooltipContent';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
interface Props {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
isMissingInput?: boolean;
withTooltip?: boolean;
}
const EditableFieldTitle = forwardRef((props: Props, ref) => {
const {
nodeId,
fieldName,
kind,
isMissingInput = false,
withTooltip = false,
} = props;
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return (
<Tooltip
label={
withTooltip ? (
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
) : undefined
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
>
<Flex
ref={ref}
sx={{
position: 'relative',
overflow: 'hidden',
alignItems: 'center',
justifyContent: 'flex-start',
gap: 1,
h: 'full',
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
w: 'full',
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex>
</Tooltip>
);
});
export default memo(EditableFieldTitle);
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@@ -1,8 +1,11 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
@@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
borderWidth: 0,
zIndex: 1,
};
``;
export const inputHandleStyles: CSSProperties = {
left: '-1rem',
@@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color, title } = FIELDS[type];
const { color: typeColor, title } = FIELDS[type];
const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const s: CSSProperties = {
backgroundColor: colorTokenToCssVar(color),
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
: color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
zIndex: 1,
};
@@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
return s;
}, [
color,
connectionError,
handleType,
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);
const tooltip = useMemo(() => {

View File

@@ -1,16 +1,7 @@
import {
Editable,
EditableInput,
EditablePreview,
Flex,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { Flex, Text, forwardRef } from '@chakra-ui/react';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
import { memo } from 'react';
interface Props {
nodeId: string;
@@ -24,31 +15,6 @@ const FieldTitle = forwardRef((props: Props, ref) => {
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return (
<Flex
ref={ref}
@@ -62,82 +28,11 @@ const FieldTitle = forwardRef((props: Props, ref) => {
w: 'full',
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
w: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
<Text sx={{ fontWeight: isMissingInput ? 600 : 400 }}>
{label || fieldTemplateTitle}
</Text>
</Flex>
);
});
export default memo(FieldTitle);
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@@ -34,6 +34,8 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
}
return 'Unknown Field';
} else {
return fieldTemplate?.title || 'Unknown Field';
}
}, [field, fieldTemplate]);

View File

@@ -1,16 +1,11 @@
import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo, useMemo } from 'react';
import EditableFieldTitle from './EditableFieldTitle';
import FieldContextMenu from './FieldContextMenu';
import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
interface Props {
@@ -21,7 +16,6 @@ interface Props {
const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const input = useFieldInputKind(nodeId, fieldName);
const {
isConnected,
@@ -51,11 +45,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
if (fieldTemplate?.fieldKind !== 'input') {
return (
<InputFieldWrapper
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
>
@@ -66,19 +56,14 @@ const InputField = ({ nodeId, fieldName }: Props) => {
}
return (
<InputFieldWrapper
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
as={Flex}
isInvalid={isMissingInput}
isDisabled={isConnected}
sx={{
alignItems: 'stretch',
justifyContent: 'space-between',
ps: 2,
ps: fieldTemplate.input === 'direct' ? 0 : 2,
gap: 2,
h: 'full',
w: 'full',
@@ -86,42 +71,28 @@ const InputField = ({ nodeId, fieldName }: Props) => {
>
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
{(ref) => (
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
<FormLabel
sx={{
display: 'flex',
alignItems: 'center',
h: 'full',
mb: 0,
px: 1,
gap: 2,
}}
>
<FormLabel
sx={{
mb: 0,
width: input === 'connection' ? 'auto' : '25%',
flexShrink: 0,
flexGrow: 0,
}}
>
<FieldTitle
ref={ref}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
isMissingInput={isMissingInput}
/>
</FormLabel>
</Tooltip>
<EditableFieldTitle
ref={ref}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
isMissingInput={isMissingInput}
withTooltip
/>
</FormLabel>
)}
</FieldContextMenu>
<Box
sx={{
width: input === 'connection' ? 'auto' : '75%',
}}
>
<Box>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Box>
</FormControl>
@@ -143,19 +114,12 @@ export default memo(InputField);
type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
nodeId: string;
fieldName: string;
}>;
const InputFieldWrapper = memo(
({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => {
const { isMouseOverField, handleMouseOver, handleMouseOut } =
useIsMouseOverField(nodeId, fieldName);
({ shouldDim, children }: InputFieldWrapperProps) => {
return (
<Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
sx={{
position: 'relative',
minH: 8,
@@ -169,7 +133,6 @@ const InputFieldWrapper = memo(
}}
>
{children}
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
</Flex>
);
}

View File

@@ -3,17 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { memo } from 'react';
import BooleanInputField from './inputs/BooleanInputField';
import ClipInputField from './inputs/ClipInputField';
import CollectionInputField from './inputs/CollectionInputField';
import CollectionItemInputField from './inputs/CollectionItemInputField';
import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField';
import LatentsInputField from './inputs/LatentsInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField';
import NumberInputField from './inputs/NumberInputField';
@@ -21,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import UnetInputField from './inputs/UnetInputField';
import VaeInputField from './inputs/VaeInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
type InputFieldProps = {
@@ -30,7 +21,6 @@ type InputFieldProps = {
fieldName: string;
};
// build an individual input element based on the schema
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
@@ -92,75 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'LatentsField' &&
fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return (
<UnetInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return (
<ClipInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return (
<VaeInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlField' &&
fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
@@ -226,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return (
<CollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'CollectionItem' &&
fieldTemplate?.type === 'CollectionItem'
) {
return (
<CollectionItemInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
@@ -259,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
@@ -295,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field && fieldTemplate) {
// Fallback for when there is no component for the type
return null;
}
return (
<Box p={1}>
<Text

View File

@@ -1,13 +1,20 @@
import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react';
import {
Flex,
FormControl,
FormLabel,
Icon,
Spacer,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { memo, useCallback } from 'react';
import { FaInfoCircle, FaTrash } from 'react-icons/fa';
import FieldTitle from './FieldTitle';
import EditableFieldTitle from './EditableFieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
@@ -18,8 +25,8 @@ type Props = {
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { isMouseOverField, handleMouseOut, handleMouseOver } =
useIsMouseOverField(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
useMouseOverNode(nodeId);
const handleRemoveField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
@@ -27,8 +34,8 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
onMouseEnter={handleMouseOver}
onMouseLeave={handleMouseOut}
layerStyle="second"
sx={{
position: 'relative',
@@ -42,11 +49,15 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
sx={{
display: 'flex',
alignItems: 'center',
justifyContent: 'space-between',
mb: 0,
}}
>
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
<EditableFieldTitle
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
<Spacer />
<Tooltip
label={
<FieldTooltipContent
@@ -74,7 +85,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
</FormLabel>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</FormControl>
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
</Flex>
);
};

View File

@@ -1,12 +1,17 @@
import {
ControlInputFieldTemplate,
ControlInputFieldValue,
ControlPolymorphicInputFieldTemplate,
ControlPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const ControlInputFieldComponent = (
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
_props: FieldComponentProps<
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
>
) => {
return null;
};

View File

@@ -92,6 +92,7 @@ const ControlNetModelInputFieldComponent = (
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};

View File

@@ -0,0 +1,17 @@
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(DenoiseMaskInputFieldComponent);

View File

@@ -9,9 +9,9 @@ import {
} from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
ImageInputFieldTemplate,
ImageInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react';
import { FaUndo } from 'react-icons/fa';

View File

@@ -2,11 +2,16 @@ import {
LatentsInputFieldTemplate,
LatentsInputFieldValue,
FieldComponentProps,
LatentsPolymorphicInputFieldValue,
LatentsPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
const LatentsInputFieldComponent = (
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
_props: FieldComponentProps<
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
>
) => {
return null;
};

View File

@@ -101,8 +101,10 @@ const LoRAModelInputFieldComponent = (
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
error={!selectedLoRAModel}
onChange={handleChange}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@@ -134,6 +134,7 @@ const MainModelInputFieldComponent = (
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
FloatInputFieldTemplate,
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo, useEffect, useMemo, useState } from 'react';

View File

@@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react';
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
SDXLRefinerModelInputFieldTemplate,
SDXLRefinerModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
@@ -101,20 +101,17 @@ const RefinerModelInputFieldComponent = (
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
error={!selectedModel}
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},
}}
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton className="nodrag" iconMode />
</Box>
)}
{isSyncModelEnabled && <SyncModelsButton className="nodrag" iconMode />}
</Flex>
);
};

View File

@@ -128,10 +128,11 @@ const ModelInputFieldComponent = (
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
error={!selectedModel}
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@@ -4,9 +4,9 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
@@ -88,17 +88,15 @@ const VaeModelInputFieldComponent = (
className="nowheel nodrag"
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={
selectedVaeModel?.base_model &&
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
}
value={selectedVaeModel?.id ?? 'default'}
placeholder="Default"
data={data}
onChange={handleChangeModel}
disabled={data.length === 0}
error={!selectedVaeModel}
clearable
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@@ -27,9 +27,11 @@ const NodeTitle = ({ nodeId, title }: Props) => {
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title || 'Problem Setting Title');
setLocalTitle(
newTitle || title || templateTitle || 'Problem Setting Title'
);
},
[nodeId, dispatch, title]
[dispatch, nodeId, title, templateTitle]
);
const handleChange = useCallback((newTitle: string) => {

View File

@@ -7,13 +7,22 @@ import {
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import {
DRAG_HANDLE_CLASSNAME,
NODE_WIDTH,
} from 'features/nodes/types/constants';
import { NodeStatus } from 'features/nodes/types/types';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { PropsWithChildren, memo, useCallback, useMemo } from 'react';
import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@@ -23,6 +32,8 @@ type NodeWrapperProps = PropsWithChildren & {
const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
useMouseOverNode(nodeId);
const selectIsInProgress = useMemo(
() =>
@@ -36,25 +47,16 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const isInProgress = useAppSelector(selectIsInProgress);
const [
nodeSelectedLight,
nodeSelectedDark,
nodeInProgressLight,
nodeInProgressDark,
shadowsXl,
shadowsBase,
] = useToken('shadows', [
'nodeSelected.light',
'nodeSelected.dark',
'nodeInProgress.light',
'nodeInProgress.dark',
'shadows.xl',
'shadows.base',
]);
const [nodeInProgressLight, nodeInProgressDark, shadowsXl, shadowsBase] =
useToken('shadows', [
'nodeInProgress.light',
'nodeInProgress.dark',
'shadows.xl',
'shadows.base',
]);
const dispatch = useAppDispatch();
const selectedShadow = useColorModeValue(nodeSelectedLight, nodeSelectedDark);
const inProgressShadow = useColorModeValue(
nodeInProgressLight,
nodeInProgressDark
@@ -62,13 +64,21 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
const handleClick = useCallback(() => {
dispatch(contextMenusClosed());
}, [dispatch]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
dispatch(nodeExclusivelySelected(nodeId));
}
dispatch(contextMenusClosed());
},
[dispatch, nodeId]
);
return (
<Box
onClick={handleClick}
onMouseEnter={handleMouseOver}
onMouseLeave={handleMouseOut}
className={DRAG_HANDLE_CLASSNAME}
sx={{
h: 'full',
@@ -77,11 +87,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
w: width ?? NODE_WIDTH,
transitionProperty: 'common',
transitionDuration: '0.1s',
shadow: selected
? isInProgress
? undefined
: selectedShadow
: undefined,
cursor: 'grab',
opacity,
}}
@@ -116,6 +121,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
}}
/>
{children}
<NodeSelectionOverlay isSelected={selected} isHovered={isMouseOverNode} />
</Box>
);
};

View File

@@ -2,12 +2,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSave } from 'react-icons/fa';
import { FaDownload } from 'react-icons/fa';
const SaveWorkflowButton = () => {
const DownloadWorkflowButton = () => {
const { t } = useTranslation();
const workflow = useWorkflow();
const handleSave = useCallback(() => {
const handleDownload = useCallback(() => {
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
@@ -18,12 +18,12 @@ const SaveWorkflowButton = () => {
}, [workflow]);
return (
<IAIIconButton
icon={<FaSave />}
tooltip={t('nodes.saveWorkflow')}
aria-label={t('nodes.saveWorkflow')}
onClick={handleSave}
icon={<FaDownload />}
tooltip={t('nodes.downloadWorkflow')}
aria-label={t('nodes.downloadWorkflow')}
onClick={handleDownload}
/>
);
};
export default memo(SaveWorkflowButton);
export default memo(DownloadWorkflowButton);

View File

@@ -2,7 +2,7 @@ import { Flex } from '@chakra-ui/layout';
import { memo } from 'react';
import LoadWorkflowButton from './LoadWorkflowButton';
import ResetWorkflowButton from './ResetWorkflowButton';
import SaveWorkflowButton from './SaveWorkflowButton';
import DownloadWorkflowButton from './DownloadWorkflowButton';
const TopCenterPanel = () => {
return (
@@ -15,7 +15,7 @@ const TopCenterPanel = () => {
transform: 'translate(-50%)',
}}
>
<SaveWorkflowButton />
<DownloadWorkflowButton />
<LoadWorkflowButton />
<ResetWorkflowButton />
</Flex>

View File

@@ -0,0 +1,74 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvocationTemplate, NodeData } from 'features/nodes/types/types';
import { memo } from 'react';
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
import NodeTitle from '../../flow/nodes/common/NodeTitle';
import ScrollableContent from '../ScrollableContent';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
data: lastSelectedNode?.data,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorDetailsTab = () => {
const { data, template } = useAppSelector(selector);
if (!template || !data) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <Content data={data} template={template} />;
};
export default memo(InspectorDetailsTab);
const Content = (props: { data: NodeData; template: InvocationTemplate }) => {
const { data } = props;
return (
<Box
sx={{
position: 'relative',
w: 'full',
h: 'full',
}}
>
<ScrollableContent>
<Flex
sx={{
flexDir: 'column',
position: 'relative',
p: 1,
gap: 2,
w: 'full',
}}
>
<NodeTitle nodeId={data.id} />
<NotesTextarea nodeId={data.id} />
</Flex>
</ScrollableContent>
</Box>
);
};

View File

@@ -4,12 +4,13 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { isInvocationNode } from 'features/nodes/types/types';
import { memo } from 'react';
import ImageOutputPreview from './outputs/ImageOutputPreview';
import ScrollableContent from '../ScrollableContent';
import { ImageOutput } from 'services/api/types';
import { AnyResult } from 'services/events/types';
import StringOutputPreview from './outputs/StringOutputPreview';
import NumberOutputPreview from './outputs/NumberOutputPreview';
import ScrollableContent from '../ScrollableContent';
import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createSelector(
stateSelector,
@@ -21,11 +22,16 @@ const selector = createSelector(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
const nes =
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
return {
node: lastSelectedNode,
template: lastSelectedNodeTemplate,
nes,
};
},
@@ -33,9 +39,9 @@ const selector = createSelector(
);
const InspectorOutputsTab = () => {
const { node, nes } = useAppSelector(selector);
const { node, template, nes } = useAppSelector(selector);
if (!node || !nes) {
if (!node || !nes || !isInvocationNode(node)) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
@@ -63,33 +69,16 @@ const InspectorOutputsTab = () => {
w: 'full',
}}
>
{nes.outputs.map((result, i) => {
if (result.type === 'string_output') {
return (
<StringOutputPreview key={getKey(result, i)} output={result} />
);
}
if (result.type === 'float_output') {
return (
<NumberOutputPreview key={getKey(result, i)} output={result} />
);
}
if (result.type === 'integer_output') {
return (
<NumberOutputPreview key={getKey(result, i)} output={result} />
);
}
if (result.type === 'image_output') {
return (
<ImageOutputPreview key={getKey(result, i)} output={result} />
);
}
return (
<pre key={getKey(result, i)}>
{JSON.stringify(result, null, 2)}
</pre>
);
})}
{template?.outputType === 'image_output' ? (
nes.outputs.map((result, i) => (
<ImageOutputPreview
key={getKey(result, i)}
output={result as ImageOutput}
/>
))
) : (
<DataViewer data={nes.outputs} label="Node Outputs" />
)}
</Flex>
</ScrollableContent>
</Box>

View File

@@ -10,6 +10,7 @@ import { memo } from 'react';
import InspectorDataTab from './InspectorDataTab';
import InspectorOutputsTab from './InspectorOutputsTab';
import InspectorTemplateTab from './InspectorTemplateTab';
// import InspectorDetailsTab from './InspectorDetailsTab';
const InspectorPanel = () => {
return (
@@ -29,12 +30,16 @@ const InspectorPanel = () => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
{/* <Tab>Details</Tab> */}
<Tab>Outputs</Tab>
<Tab>Data</Tab>
<Tab>Template</Tab>
</TabList>
<TabPanels>
{/* <TabPanel>
<InspectorDetailsTab />
</TabPanel> */}
<TabPanel>
<InspectorOutputsTab />
</TabPanel>

View File

@@ -1,13 +0,0 @@
import { Text } from '@chakra-ui/react';
import { memo } from 'react';
import { FloatOutput, IntegerOutput } from 'services/api/types';
type Props = {
output: IntegerOutput | FloatOutput;
};
const NumberOutputPreview = ({ output }: Props) => {
return <Text>{output.value}</Text>;
};
export default memo(NumberOutputPreview);

View File

@@ -1,13 +0,0 @@
import { Text } from '@chakra-ui/react';
import { memo } from 'react';
import { StringOutput } from 'services/api/types';
type Props = {
output: StringOutput;
};
const StringOutputPreview = ({ output }: Props) => {
return <Text>{output.value}</Text>;
};
export default memo(StringOutputPreview);

View File

@@ -22,6 +22,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
}
return map(nodeTemplate.inputs)
.filter((field) => ['any', 'direct'].includes(field.input))
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@@ -138,11 +138,14 @@ export const useBuildNodeData = () => {
data: {
id: nodeId,
type,
inputs,
outputs,
isOpen: true,
version: template.version,
label: '',
notes: '',
isOpen: true,
embedWorkflow: false,
isIntermediate: true,
inputs,
outputs,
},
};

View File

@@ -22,6 +22,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
}
return map(nodeTemplate.inputs)
.filter((field) => field.input === 'connection')
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@@ -0,0 +1,33 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { compareVersions } from 'compare-versions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useDoNodeVersionsMatch = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
if (!nodeTemplate?.version || !node.data?.version) {
return false;
}
return compareVersions(nodeTemplate.version, node.data.version) === 0;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};

View File

@@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
return node?.data.inputs[fieldName]?.value !== undefined;
},
defaultSelectorOptions
),

View File

@@ -0,0 +1,27 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.embedWorkflow;
},
defaultSelectorOptions
),
[nodeId]
);
const embedWorkflow = useAppSelector(selector);
return embedWorkflow;
};

View File

@@ -15,7 +15,7 @@ export const useIsIntermediate = (nodeId: string) => {
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
return node.data.isIntermediate;
},
defaultSelectorOptions
),

View File

@@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
import { COLLECTION_TYPES } from '../types/constants';
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from '../types/constants';
import { InvocationNodeData } from '../types/types';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const useIsValidConnection = () => {
const flow = useReactFlow();
const shouldValidateGraph = useAppSelector(
@@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
return false;
}
if (
edges
.filter((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
.find((edge) => {
edge.source === source && edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
@@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
return false;
}
// Connection types must be the same for a connection
if (
sourceType !== targetType &&
sourceType !== 'CollectionItem' &&
targetType !== 'CollectionItem'
) {
if (
!(
COLLECTION_TYPES.includes(targetType) &&
COLLECTION_TYPES.includes(sourceType)
)
) {
return false;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/
if (sourceType !== targetType) {
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(targetType);
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
);
}
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},

View File

@@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { zWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { ZodError } from 'zod';
import { fromZodError, fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
export const useLoadWorkflowFromFile = () => {
const dispatch = useAppDispatch();
@@ -27,49 +27,38 @@ export const useLoadWorkflowFromFile = () => {
const result = zWorkflow.safeParse(parsedJSON);
if (!result.success) {
const message = fromZodError(result.error, {
const { message } = fromZodError(result.error, {
prefix: 'Workflow Validation Error',
}).toString();
});
logger.error({ error: parseify(result.error) }, message);
dispatch(
addToast(
makeToast({
title: 'Unable to Validate Workflow',
description: (
<WorkflowValidationErrorContent error={result.error} />
),
status: 'error',
duration: 5000,
})
)
);
reader.abort();
return;
}
dispatch(workflowLoaded(result.data));
dispatch(workflowLoadRequested(result.data));
reader.abort();
} catch {
// file reader error
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
title: 'Unable to Load Workflow',
status: 'error',
})
)
);
reader.abort();
} catch (error) {
// file reader error
if (error) {
dispatch(
addToast(
makeToast({
title: 'Unable to Load Workflow',
status: 'error',
})
)
);
}
}
};

View File

@@ -0,0 +1,31 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback, useMemo } from 'react';
import { mouseOverNodeChanged } from '../store/nodesSlice';
export const useMouseOverNode = (nodeId: string) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => nodes.mouseOverNode === nodeId,
defaultSelectorOptions
),
[nodeId]
);
const isMouseOverNode = useAppSelector(selector);
const handleMouseOver = useCallback(() => {
!isMouseOverNode && dispatch(mouseOverNodeChanged(nodeId));
}, [dispatch, nodeId, isMouseOverNode]);
const handleMouseOut = useCallback(() => {
isMouseOverNode && dispatch(mouseOverNodeChanged(null));
}, [dispatch, isMouseOverNode]);
return { isMouseOverNode, handleMouseOver, handleMouseOut };
};

View File

@@ -21,6 +21,7 @@ export const useOutputFieldNames = (nodeId: string) => {
return [];
}
return map(nodeTemplate.outputs)
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@@ -1,5 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api/types';
import { Workflow } from '../types/types';
export const textToImageGraphBuilt = createAction<Graph>(
'nodes/textToImageGraphBuilt'
@@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
canvasGraphBuilt,
nodesGraphBuilt
);
export const workflowLoadRequested = createAction<Workflow>(
'nodes/workflowLoadRequested'
);

View File

@@ -1,5 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
import {
addEdge,
applyEdgeChanges,
@@ -18,7 +18,7 @@ import {
Viewport,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { sessionInvoked } from 'services/api/thunks/session';
import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
import { ImageField } from 'services/api/types';
import {
appSocketGeneratorProgress,
@@ -102,6 +102,7 @@ export const initialNodesState: NodesState = {
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
mouseOverField: null,
mouseOverNode: null,
nodesToCopy: [],
edgesToCopy: [],
selectionMode: SelectionMode.Partial,
@@ -245,6 +246,34 @@ const nodesSlice = createSlice({
}
field.label = label;
},
nodeEmbedWorkflowChanged: (
state,
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
) => {
const { nodeId, embedWorkflow } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.embedWorkflow = embedWorkflow;
},
nodeIsIntermediateChanged: (
state,
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
) => {
const { nodeId, isIntermediate } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.isIntermediate = isIntermediate;
},
nodeIsOpenChanged: (
state,
action: PayloadAction<{ nodeId: string; isOpen: boolean }>
@@ -414,6 +443,17 @@ const nodesSlice = createSlice({
}
node.data.notes = notes;
},
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
const nodeId = action.payload;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({
id: n.id,
type: 'select',
selected: n.id === nodeId ? true : false,
})),
state.nodes
);
},
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedNodes = action.payload;
},
@@ -561,7 +601,7 @@ const nodesSlice = createSlice({
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
state.workflow.exposedFields = [];
state.workflow = cloneDeep(initialWorkflow);
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
@@ -637,6 +677,9 @@ const nodesSlice = createSlice({
) => {
state.mouseOverField = action.payload;
},
mouseOverNodeChanged: (state, action: PayloadAction<string | null>) => {
state.mouseOverNode = action.payload;
},
selectedAll: (state) => {
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
@@ -790,6 +833,13 @@ const nodesSlice = createSlice({
nes.outputs = [];
});
});
builder.addCase(sessionCanceled.fulfilled, (state) => {
map(state.nodeExecutionStates, (nes) => {
if (nes.status === NodeStatus.IN_PROGRESS) {
nes.status = NodeStatus.PENDING;
}
});
});
},
});
@@ -850,6 +900,10 @@ export const {
addNodePopoverClosed,
addNodePopoverToggled,
selectionModeChanged,
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@@ -0,0 +1,4 @@
import { atom } from 'nanostores';
import { ReactFlowInstance } from 'reactflow';
export const $flow = atom<ReactFlowInstance | null>(null);

View File

@@ -35,6 +35,7 @@ export type NodesState = {
viewport: Viewport;
isReady: boolean;
mouseOverField: FieldIdentifier | null;
mouseOverNode: string | null;
nodesToCopy: Node<NodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean;

View File

@@ -1,10 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import { HandleType } from 'reactflow';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
@@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes;
if (!state.nodes.shouldValidateGraph) {
// manual override!
return null;
}
if (!connectionStartParams || !currentConnectionFieldType) {
return 'No connection in progress';
}
@@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
return 'No connection data';
}
const targetFieldType =
const targetType =
handleType === 'target' ? fieldType : currentConnectionFieldType;
const sourceFieldType =
const sourceType =
handleType === 'source' ? fieldType : currentConnectionFieldType;
if (nodeId === connectionNodeId) {
@@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
}
if (
fieldType !== currentConnectionFieldType &&
fieldType !== 'CollectionItem' &&
currentConnectionFieldType !== 'CollectionItem'
) {
if (
!(
COLLECTION_TYPES.includes(targetFieldType) &&
COLLECTION_TYPES.includes(sourceFieldType)
)
) {
// except for collection items, field types must match
return 'Field types must match';
}
}
if (
handleType === 'target' &&
edges.find((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName;
}) &&
// except CollectionItem inputs can have multiples
targetFieldType !== 'CollectionItem'
targetType !== 'CollectionItem'
) {
return 'Inputs may only have one connection';
return 'Input may only have one connection';
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/
if (sourceType !== targetType) {
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(targetType);
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
if (
!(
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
)
) {
return 'Field types must match';
}
}
const isGraphAcyclic = getIsGraphAcyclic(

View File

@@ -17,176 +17,297 @@ export const KIND_MAP = {
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',
'BooleanCollection',
'FloatCollection',
'StringCollection',
'BooleanCollection',
'ImageCollection',
'LatentsCollection',
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
];
export const POLYMORPHIC_TYPES = [
'IntegerPolymorphic',
'BooleanPolymorphic',
'FloatPolymorphic',
'StringPolymorphic',
'ImagePolymorphic',
'LatentsPolymorphic',
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
];
export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'VaeField',
'ClipField',
];
export const COLLECTION_MAP = {
integer: 'IntegerCollection',
boolean: 'BooleanCollection',
number: 'FloatCollection',
float: 'FloatCollection',
string: 'StringCollection',
ImageField: 'ImageCollection',
LatentsField: 'LatentsCollection',
ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection',
ColorField: 'ColorCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
): itemType is keyof typeof COLLECTION_MAP =>
Boolean(itemType && itemType in COLLECTION_MAP);
export const SINGLE_TO_POLYMORPHIC_MAP = {
integer: 'IntegerPolymorphic',
boolean: 'BooleanPolymorphic',
number: 'FloatPolymorphic',
float: 'FloatPolymorphic',
string: 'StringPolymorphic',
ImageField: 'ImagePolymorphic',
LatentsField: 'LatentsPolymorphic',
ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP = {
IntegerPolymorphic: 'integer',
BooleanPolymorphic: 'boolean',
FloatPolymorphic: 'float',
StringPolymorphic: 'string',
ImagePolymorphic: 'ImageField',
LatentsPolymorphic: 'LatentsField',
ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField',
};
export const isPolymorphicItemType = (
itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
color: 'red.500',
},
float: {
title: 'Float',
description: 'Floats are numbers with a decimal point.',
color: 'orange.500',
},
string: {
title: 'String',
description: 'Strings are text.',
color: 'yellow.500',
},
boolean: {
title: 'Boolean',
color: 'green.500',
description: 'Booleans are true or false.',
title: 'Boolean',
},
enum: {
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
color: 'blue.500',
BooleanCollection: {
color: 'green.500',
description: 'A collection of booleans.',
title: 'Boolean Collection',
},
array: {
title: 'Array',
description: 'Enums are values that may be one of a number of options.',
color: 'base.500',
},
ImageField: {
title: 'Image',
description: 'Images may be passed between nodes.',
color: 'purple.500',
},
LatentsField: {
title: 'Latents',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
LatentsCollection: {
title: 'Latents Collection',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
ConditioningField: {
color: 'cyan.500',
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
ConditioningCollection: {
color: 'cyan.500',
title: 'Conditioning Collection',
description: 'Conditioning may be passed between nodes.',
},
ImageCollection: {
title: 'Image Collection',
description: 'A collection of images.',
color: 'base.300',
},
UNetField: {
color: 'red.500',
title: 'UNet',
description: 'UNet submodel.',
BooleanPolymorphic: {
color: 'green.500',
description: 'A collection of booleans.',
title: 'Boolean Polymorphic',
},
ClipField: {
color: 'green.500',
title: 'Clip',
description: 'Tokenizer and text_encoder submodels.',
},
VaeField: {
color: 'blue.500',
title: 'Vae',
description: 'Vae submodel.',
},
ControlField: {
color: 'cyan.500',
title: 'Control',
description: 'Control info passed between nodes.',
},
MainModelField: {
color: 'teal.500',
title: 'Model',
description: 'TODO',
},
SDXLRefinerModelField: {
color: 'teal.500',
title: 'Refiner Model',
description: 'TODO',
},
VaeModelField: {
color: 'teal.500',
title: 'VAE',
description: 'TODO',
},
LoRAModelField: {
color: 'teal.500',
title: 'LoRA',
description: 'TODO',
},
ControlNetModelField: {
color: 'teal.500',
title: 'ControlNet',
description: 'TODO',
},
Scheduler: {
color: 'base.500',
title: 'Scheduler',
description: 'TODO',
title: 'Clip',
},
Collection: {
color: 'base.500',
title: 'Collection',
description: 'TODO',
title: 'Collection',
},
CollectionItem: {
color: 'base.500',
title: 'Collection Item',
description: 'TODO',
title: 'Collection Item',
},
ColorCollection: {
color: 'pink.300',
description: 'A collection of colors.',
title: 'Color Collection',
},
ColorField: {
title: 'Color',
color: 'pink.300',
description: 'A RGBA color.',
color: 'base.500',
title: 'Color',
},
BooleanCollection: {
title: 'Boolean Collection',
description: 'A collection of booleans.',
color: 'green.500',
ColorPolymorphic: {
color: 'pink.300',
description: 'A collection of colors.',
title: 'Color Polymorphic',
},
IntegerCollection: {
title: 'Integer Collection',
description: 'A collection of integers.',
color: 'red.500',
ConditioningCollection: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning Collection',
},
ConditioningField: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning',
},
ConditioningPolymorphic: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning Polymorphic',
},
ControlCollection: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Collection',
},
ControlField: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control',
},
ControlNetModelField: {
color: 'teal.500',
description: 'TODO',
title: 'ControlNet',
},
ControlPolymorphic: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Polymorphic',
},
DenoiseMaskField: {
color: 'blue.300',
description: 'Denoise Mask may be passed between nodes',
title: 'Denoise Mask',
},
enum: {
color: 'blue.500',
description: 'Enums are values that may be one of a number of options.',
title: 'Enum',
},
float: {
color: 'orange.500',
description: 'Floats are numbers with a decimal point.',
title: 'Float',
},
FloatCollection: {
color: 'orange.500',
title: 'Float Collection',
description: 'A collection of floats.',
title: 'Float Collection',
},
ColorCollection: {
color: 'base.500',
title: 'Color Collection',
description: 'A collection of colors.',
FloatPolymorphic: {
color: 'orange.500',
description: 'A collection of floats.',
title: 'Float Polymorphic',
},
FilePath: {
color: 'base.500',
title: 'File Path',
description: 'A path to a file.',
ImageCollection: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Collection',
},
ImageField: {
color: 'purple.500',
description: 'Images may be passed between nodes.',
title: 'Image',
},
ImagePolymorphic: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Polymorphic',
},
integer: {
color: 'red.500',
description: 'Integers are whole numbers, without a decimal point.',
title: 'Integer',
},
IntegerCollection: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Collection',
},
IntegerPolymorphic: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Polymorphic',
},
LatentsCollection: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Collection',
},
LatentsField: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents',
},
LatentsPolymorphic: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Polymorphic',
},
LoRAModelField: {
color: 'teal.500',
description: 'TODO',
title: 'LoRA',
},
MainModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Model',
},
ONNXModelField: {
color: 'base.500',
title: 'ONNX Model',
color: 'teal.500',
description: 'ONNX model field.',
title: 'ONNX Model',
},
Scheduler: {
color: 'base.500',
description: 'TODO',
title: 'Scheduler',
},
SDXLMainModelField: {
color: 'base.500',
title: 'SDXL Model',
color: 'teal.500',
description: 'SDXL model field.',
title: 'SDXL Model',
},
SDXLRefinerModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Refiner Model',
},
string: {
color: 'yellow.500',
description: 'Strings are text.',
title: 'String',
},
StringCollection: {
color: 'yellow.500',
title: 'String Collection',
description: 'A collection of strings.',
title: 'String Collection',
},
StringPolymorphic: {
color: 'yellow.500',
description: 'A collection of strings.',
title: 'String Polymorphic',
},
UNetField: {
color: 'red.500',
description: 'UNet submodel.',
title: 'UNet',
},
VaeField: {
color: 'blue.500',
description: 'Vae submodel.',
title: 'Vae',
},
VaeModelField: {
color: 'teal.500',
description: 'TODO',
title: 'VAE',
},
};

View File

@@ -1,19 +1,24 @@
import {
SchedulerParam,
zBaseModel,
zMainModel,
zMainOrOnnxModel,
zOnnxModel,
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { Graph, _InputField, _OutputField } from 'services/api/types';
import {
AnyInvocationType,
AnyResult,
ProgressImage,
} from 'services/events/types';
import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@@ -47,6 +52,10 @@ export type InvocationTemplate = {
* The type of this node's output
*/
outputType: string; // TODO: generate a union of output types
/**
* The invocation's version.
*/
version?: string;
};
export type FieldUIConfig = {
@@ -57,87 +66,63 @@ export type FieldUIConfig = {
// TODO: Get this from the OpenAPI schema? may be tricky...
export const zFieldType = z.enum([
// region Primitives
'integer',
'float',
'boolean',
'string',
'array',
'ImageField',
'LatentsField',
'ConditioningField',
'ControlField',
'ColorField',
'ImageCollection',
'ConditioningCollection',
'ColorCollection',
'LatentsCollection',
'IntegerCollection',
'FloatCollection',
'StringCollection',
'BooleanCollection',
// endregion
// region Models
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'ONNXModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'UNetField',
'VaeField',
'BooleanPolymorphic',
'ClipField',
// endregion
// region Iterate/Collect
'Collection',
'CollectionItem',
// endregion
// region Misc
'FilePath',
'ColorCollection',
'ColorField',
'ColorPolymorphic',
'ConditioningCollection',
'ConditioningField',
'ConditioningPolymorphic',
'ControlCollection',
'ControlField',
'ControlNetModelField',
'ControlPolymorphic',
'DenoiseMaskField',
'enum',
'float',
'FloatCollection',
'FloatPolymorphic',
'ImageCollection',
'ImageField',
'ImagePolymorphic',
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'Scheduler',
// endregion
'SDXLMainModelField',
'SDXLRefinerModelField',
'string',
'StringCollection',
'StringPolymorphic',
'UNetField',
'VaeField',
'VaeModelField',
]);
export type FieldType = z.infer<typeof zFieldType>;
export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success;
export const zReservedFieldType = z.enum([
'WorkflowField',
'IsIntermediate',
'MetadataField',
]);
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| IntegerInputFieldTemplate
| FloatInputFieldTemplate
| StringInputFieldTemplate
| BooleanInputFieldTemplate
| ImageInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| UNetInputFieldTemplate
| ClipInputFieldTemplate
| VaeInputFieldTemplate
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| MainModelInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate
| SchedulerInputFieldTemplate;
export type ReservedFieldType = z.infer<typeof zReservedFieldType>;
export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success ||
zReservedFieldType.safeParse(value).success;
/**
* Indicates the kind of input(s) this field may have.
@@ -205,30 +190,100 @@ export const zConditioningField = z.object({
});
export type ConditioningField = z.infer<typeof zConditioningField>;
export const zDenoiseMaskField = z.object({
mask_name: z.string().trim().min(1),
masked_latents_name: z.string().trim().min(1).optional(),
});
export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('integer'),
value: z.number().optional(),
value: z.number().int().optional(),
});
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerCollection'),
value: z.array(z.number().int()).optional(),
});
export type IntegerCollectionInputFieldValue = z.infer<
typeof zIntegerCollectionInputFieldValue
>;
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerPolymorphic'),
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
});
export type IntegerPolymorphicInputFieldValue = z.infer<
typeof zIntegerPolymorphicInputFieldValue
>;
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('float'),
value: z.number().optional(),
});
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatCollection'),
value: z.array(z.number()).optional(),
});
export type FloatCollectionInputFieldValue = z.infer<
typeof zFloatCollectionInputFieldValue
>;
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatPolymorphic'),
value: z.union([z.number(), z.array(z.number())]).optional(),
});
export type FloatPolymorphicInputFieldValue = z.infer<
typeof zFloatPolymorphicInputFieldValue
>;
export const zStringInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('string'),
value: z.string().optional(),
});
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringCollection'),
value: z.array(z.string()).optional(),
});
export type StringCollectionInputFieldValue = z.infer<
typeof zStringCollectionInputFieldValue
>;
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringPolymorphic'),
value: z.union([z.string(), z.array(z.string())]).optional(),
});
export type StringPolymorphicInputFieldValue = z.infer<
typeof zStringPolymorphicInputFieldValue
>;
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('boolean'),
value: z.boolean().optional(),
});
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanCollection'),
value: z.array(z.boolean()).optional(),
});
export type BooleanCollectionInputFieldValue = z.infer<
typeof zBooleanCollectionInputFieldValue
>;
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanPolymorphic'),
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
});
export type BooleanPolymorphicInputFieldValue = z.infer<
typeof zBooleanPolymorphicInputFieldValue
>;
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('enum'),
value: z.union([z.string(), z.number()]).optional(),
@@ -241,6 +296,30 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
});
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsCollection'),
value: z.array(zLatentsField).optional(),
});
export type LatentsCollectionInputFieldValue = z.infer<
typeof zLatentsCollectionInputFieldValue
>;
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsPolymorphic'),
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
});
export type LatentsPolymorphicInputFieldValue = z.infer<
typeof zLatentsPolymorphicInputFieldValue
>;
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('DenoiseMaskField'),
value: zDenoiseMaskField.optional(),
});
export type DenoiseMaskInputFieldValue = z.infer<
typeof zDenoiseMaskInputFieldValue
>;
export const zConditioningInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ConditioningField'),
value: zConditioningField.optional(),
@@ -249,10 +328,30 @@ export type ConditioningInputFieldValue = z.infer<
typeof zConditioningInputFieldValue
>;
export const zConditioningCollectionInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningCollection'),
value: z.array(zConditioningField).optional(),
});
export type ConditioningCollectionInputFieldValue = z.infer<
typeof zConditioningCollectionInputFieldValue
>;
export const zConditioningPolymorphicInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningPolymorphic'),
value: z
.union([zConditioningField, z.array(zConditioningField)])
.optional(),
});
export type ConditioningPolymorphicInputFieldValue = z.infer<
typeof zConditioningPolymorphicInputFieldValue
>;
export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>;
export const zControlField = zInputFieldValueBase.extend({
export const zControlField = z.object({
image: zImageField,
control_model: zControlNetModel,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
@@ -267,11 +366,27 @@ export const zControlField = zInputFieldValueBase.extend({
});
export type ControlField = z.infer<typeof zControlField>;
export const zControlInputFieldTemplate = zInputFieldValueBase.extend({
export const zControlInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlField'),
value: zControlField.optional(),
});
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldTemplate>;
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlPolymorphic'),
value: z.union([zControlField, z.array(zControlField)]).optional(),
});
export type ControlPolymorphicInputFieldValue = z.infer<
typeof zControlPolymorphicInputFieldValue
>;
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlCollection'),
value: z.array(zControlField).optional(),
});
export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
@@ -352,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
});
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImagePolymorphic'),
value: z.union([zImageField, z.array(zImageField)]).optional(),
});
export type ImagePolymorphicInputFieldValue = z.infer<
typeof zImagePolymorphicInputFieldValue
>;
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImageCollection'),
value: z.array(zImageField).optional(),
@@ -444,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
});
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorCollection'),
value: z.array(zColorField).optional(),
});
export type ColorCollectionInputFieldValue = z.infer<
typeof zColorCollectionInputFieldValue
>;
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorPolymorphic'),
value: z.union([zColorField, z.array(zColorField)]).optional(),
});
export type ColorPolymorphicInputFieldValue = z.infer<
typeof zColorPolymorphicInputFieldValue
>;
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Scheduler'),
value: zScheduler.optional(),
@@ -453,29 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
>;
export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerInputFieldValue,
zFloatInputFieldValue,
zStringInputFieldValue,
zBooleanCollectionInputFieldValue,
zBooleanInputFieldValue,
zImageInputFieldValue,
zLatentsInputFieldValue,
zConditioningInputFieldValue,
zUNetInputFieldValue,
zBooleanPolymorphicInputFieldValue,
zClipInputFieldValue,
zVaeInputFieldValue,
zControlInputFieldTemplate,
zEnumInputFieldValue,
zMainModelInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zVaeModelInputFieldValue,
zLoRAModelInputFieldValue,
zControlNetModelInputFieldValue,
zCollectionInputFieldValue,
zCollectionItemInputFieldValue,
zColorInputFieldValue,
zColorCollectionInputFieldValue,
zColorPolymorphicInputFieldValue,
zConditioningInputFieldValue,
zConditioningCollectionInputFieldValue,
zConditioningPolymorphicInputFieldValue,
zControlInputFieldValue,
zControlNetModelInputFieldValue,
zControlCollectionInputFieldValue,
zControlPolymorphicInputFieldValue,
zDenoiseMaskInputFieldValue,
zEnumInputFieldValue,
zFloatCollectionInputFieldValue,
zFloatInputFieldValue,
zFloatPolymorphicInputFieldValue,
zImageCollectionInputFieldValue,
zImagePolymorphicInputFieldValue,
zImageInputFieldValue,
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
zLoRAModelInputFieldValue,
zMainModelInputFieldValue,
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,
zUNetInputFieldValue,
zVaeInputFieldValue,
zVaeModelInputFieldValue,
]);
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@@ -484,7 +641,6 @@ export type InputFieldTemplateBase = {
name: string;
title: string;
description: string;
type: FieldType;
required: boolean;
fieldKind: 'input';
} & _InputField;
@@ -499,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean;
};
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'IntegerCollection';
default: number[];
item_default?: number;
};
export type IntegerPolymorphicInputFieldTemplate = Omit<
IntegerInputFieldTemplate,
'type'
> & {
type: 'IntegerPolymorphic';
};
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
type: 'float';
default: number;
@@ -509,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean;
};
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'FloatCollection';
default: number[];
item_default?: number;
};
export type FloatPolymorphicInputFieldTemplate = Omit<
FloatInputFieldTemplate,
'type'
> & {
type: 'FloatPolymorphic';
};
export type StringInputFieldTemplate = InputFieldTemplateBase & {
type: 'string';
default: string;
@@ -517,31 +699,95 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
pattern?: string;
};
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'StringCollection';
default: string[];
item_default?: string;
};
export type StringPolymorphicInputFieldTemplate = Omit<
StringInputFieldTemplate,
'type'
> & {
type: 'StringPolymorphic';
};
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
default: boolean;
type: 'boolean';
};
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'BooleanCollection';
default: boolean[];
item_default?: boolean;
};
export type BooleanPolymorphicInputFieldTemplate = Omit<
BooleanInputFieldTemplate,
'type'
> & {
type: 'BooleanPolymorphic';
};
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: ImageDTO;
default: ImageField;
type: 'ImageField';
};
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: ImageField[];
type: 'ImageCollection';
item_default?: ImageField;
};
export type ImagePolymorphicInputFieldTemplate = Omit<
ImageInputFieldTemplate,
'type'
> & {
type: 'ImagePolymorphic';
};
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'DenoiseMaskField';
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: string;
default: LatentsField;
type: 'LatentsField';
};
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField[];
type: 'LatentsCollection';
item_default?: LatentsField;
};
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField;
type: 'LatentsPolymorphic';
};
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'ConditioningField';
};
export type ConditioningCollectionInputFieldTemplate =
InputFieldTemplateBase & {
default: ConditioningField[];
type: 'ConditioningCollection';
item_default?: ConditioningField;
};
export type ConditioningPolymorphicInputFieldTemplate = Omit<
ConditioningInputFieldTemplate,
'type'
> & {
type: 'ConditioningPolymorphic';
};
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'UNetField';
@@ -562,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlField';
};
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'ControlCollection';
item_default?: ControlField;
};
export type ControlPolymorphicInputFieldTemplate = Omit<
ControlInputFieldTemplate,
'type'
> & {
type: 'ControlPolymorphic';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
@@ -614,11 +873,77 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
type: 'ColorField';
};
export type ColorPolymorphicInputFieldTemplate = Omit<
ColorInputFieldTemplate,
'type'
> & {
type: 'ColorPolymorphic';
};
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'ColorCollection';
};
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
default: SchedulerParam;
type: 'Scheduler';
};
export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'WorkflowField';
};
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| BooleanCollectionInputFieldTemplate
| BooleanPolymorphicInputFieldTemplate
| BooleanInputFieldTemplate
| ClipInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ColorCollectionInputFieldTemplate
| ColorPolymorphicInputFieldTemplate
| ConditioningInputFieldTemplate
| ConditioningCollectionInputFieldTemplate
| ConditioningPolymorphicInputFieldTemplate
| ControlInputFieldTemplate
| ControlCollectionInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ControlPolymorphicInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| EnumInputFieldTemplate
| FloatCollectionInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
| ImageCollectionInputFieldTemplate
| ImagePolymorphicInputFieldTemplate
| ImageInputFieldTemplate
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
| LoRAModelInputFieldTemplate
| MainModelInputFieldTemplate
| SchedulerInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| StringCollectionInputFieldTemplate
| StringPolymorphicInputFieldTemplate
| StringInputFieldTemplate
| UNetInputFieldTemplate
| VaeInputFieldTemplate
| VaeModelInputFieldTemplate;
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@@ -639,7 +964,9 @@ export type TypeHints = {
export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
title: string;
category?: string;
tags?: string[];
version?: string;
properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']> &
(_InputField | _OutputField),
@@ -690,8 +1017,22 @@ export type InvocationSchemaObject = (
) & { class: 'invocation' };
export const isSchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
export const isArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
export const isNonArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.NonArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
export const isRefObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
export const isInvocationSchemaObject = (
obj:
@@ -715,6 +1056,73 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
export const zCoreMetadata = z
.object({
app_version: z.string().nullish(),
generation_mode: z.string().nullish(),
created_by: z.string().nullish(),
positive_prompt: z.string().nullish(),
negative_prompt: z.string().nullish(),
width: z.number().int().nullish(),
height: z.number().int().nullish(),
seed: z.number().int().nullish(),
rand_device: z.string().nullish(),
cfg_scale: z.number().nullish(),
steps: z.number().int().nullish(),
scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(),
model: z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
.nullish(),
vae: zVaeModelField.nullish(),
strength: z.number().nullish(),
init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_score: z.number().nullish(),
refiner_negative_aesthetic_score: z.number().nullish(),
refiner_start: z.number().nullish(),
})
.passthrough();
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
export const zSemVer = z.string().refine((val) => {
const [major, minor, patch] = val.split('.');
return (
major !== undefined &&
Number.isInteger(Number(major)) &&
minor !== undefined &&
Number.isInteger(Number(minor)) &&
patch !== undefined &&
Number.isInteger(Number(patch))
);
});
export const zParsedSemver = zSemVer.transform((val) => {
const [major, minor, patch] = val.split('.');
return {
major: Number(major),
minor: Number(minor),
patch: Number(patch),
};
});
export type SemVer = z.infer<typeof zSemVer>;
export const zInvocationNodeData = z.object({
id: z.string().trim().min(1),
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
@@ -725,6 +1133,9 @@ export const zInvocationNodeData = z.object({
label: z.string(),
isOpen: z.boolean(),
notes: z.string(),
embedWorkflow: z.boolean(),
isIntermediate: z.boolean(),
version: zSemVer.optional(),
});
// Massage this to get better type safety while developing
@@ -745,28 +1156,38 @@ export const zNotesNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
const zPosition = z
.object({
x: z.number(),
y: z.number(),
})
.default({ x: 0, y: 0 });
const zDimension = z.number().gt(0).nullish();
export const zWorkflowInvocationNode = z.object({
id: z.string().trim().min(1),
type: z.literal('invocation'),
data: zInvocationNodeData,
width: z.number().gt(0),
height: z.number().gt(0),
position: z.object({
x: z.number(),
y: z.number(),
}),
width: zDimension,
height: zDimension,
position: zPosition,
});
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
export const isWorkflowInvocationNode = (
val: unknown
): val is WorkflowInvocationNode =>
zWorkflowInvocationNode.safeParse(val).success;
export const zWorkflowNotesNode = z.object({
id: z.string().trim().min(1),
type: z.literal('notes'),
data: zNotesNodeData,
width: z.number().gt(0),
height: z.number().gt(0),
position: z.object({
x: z.number(),
y: z.number(),
}),
width: zDimension,
height: zDimension,
position: zPosition,
});
export const zWorkflowNode = z.discriminatedUnion('type', [
@@ -776,14 +1197,25 @@ export const zWorkflowNode = z.discriminatedUnion('type', [
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
export const zWorkflowEdge = z.object({
export const zDefaultWorkflowEdge = z.object({
source: z.string().trim().min(1),
sourceHandle: z.string().trim().min(1),
target: z.string().trim().min(1),
targetHandle: z.string().trim().min(1),
id: z.string().trim().min(1),
type: z.enum(['default', 'collapsed']),
type: z.literal('default'),
});
export const zCollapsedWorkflowEdge = z.object({
source: z.string().trim().min(1),
target: z.string().trim().min(1),
id: z.string().trim().min(1),
type: z.literal('collapsed'),
});
export const zWorkflowEdge = z.union([
zDefaultWorkflowEdge,
zCollapsedWorkflowEdge,
]);
export const zFieldIdentifier = z.object({
nodeId: z.string().trim().min(1),
@@ -792,35 +1224,80 @@ export const zFieldIdentifier = z.object({
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
export const zSemVer = z.string().refine((val) => {
const [major, minor, patch] = val.split('.');
return (
major !== undefined &&
minor !== undefined &&
patch !== undefined &&
Number.isInteger(Number(major)) &&
Number.isInteger(Number(minor)) &&
Number.isInteger(Number(patch))
);
});
export type SemVer = z.infer<typeof zSemVer>;
export type WorkflowWarning = {
message: string;
issues: string[];
data: JsonObject;
};
export const zWorkflow = z.object({
name: z.string(),
author: z.string(),
description: z.string(),
version: z.string(),
contact: z.string(),
tags: z.string(),
notes: z.string(),
nodes: z.array(zWorkflowNode),
edges: z.array(zWorkflowEdge),
exposedFields: z.array(zFieldIdentifier),
name: z.string().default(''),
author: z.string().default(''),
description: z.string().default(''),
version: z.string().default(''),
contact: z.string().default(''),
tags: z.string().default(''),
notes: z.string().default(''),
nodes: z.array(zWorkflowNode).default([]),
edges: z.array(zWorkflowEdge).default([]),
exposedFields: z.array(zFieldIdentifier).default([]),
meta: z
.object({
version: zSemVer,
})
.default({ version: '1.0.0' }),
});
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
const { nodes, edges } = workflow;
const warnings: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id');
edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target];
const issues: string[] = [];
if (!sourceNode) {
issues.push(`Output node ${edge.source} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.sourceHandle in sourceNode.data.outputs)
) {
issues.push(
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
);
}
if (!targetNode) {
issues.push(`Input node ${edge.target} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.targetHandle in targetNode.data.inputs)
) {
issues.push(
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
);
}
if (issues.length) {
delete edges[i];
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
warnings.push({
message: `Edge "${src} -> ${tgt}" skipped`,
issues,
data: edge,
});
}
});
return { workflow, warnings };
});
export type Workflow = z.infer<typeof zWorkflow>;
export type ImageMetadataAndWorkflow = {
metadata?: CoreMetadata;
workflow?: Workflow;
};
export type CurrentImageNodeData = {
id: string;
type: 'current_image';

View File

@@ -1,7 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { logger } from 'app/logging/logger';
import { NodesState } from '../store/types';
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
import { fromZodError } from 'zod-validation-error';
import { parseify } from 'common/util/serialize';
export const buildWorkflow = (nodesState: NodesState): Workflow => {
const { workflow: workflowMeta, nodes, edges } = nodesState;
@@ -11,17 +12,29 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
edges: [],
};
nodes.forEach((node) => {
const result = zWorkflowNode.safeParse(node);
if (!result.success) {
return;
}
workflow.nodes.push(result.data);
});
nodes
.filter((n) =>
['invocation', 'notes'].includes(n.type ?? '__UNKNOWN_NODE_TYPE__')
)
.forEach((node) => {
const result = zWorkflowNode.safeParse(node);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse node',
});
logger('nodes').warn({ node: parseify(node) }, message);
return;
}
workflow.nodes.push(result.data);
});
edges.forEach((edge) => {
const result = zWorkflowEdge.safeParse(edge);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse edge',
});
logger('nodes').warn({ edge: parseify(edge) }, message);
return;
}
workflow.edges.push(result.data);
@@ -29,7 +42,3 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
return workflow;
};
export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
buildWorkflow(nodes)
);

View File

@@ -1,5 +1,14 @@
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import {
COLLECTION_MAP,
POLYMORPHIC_TYPES,
SINGLE_TO_POLYMORPHIC_MAP,
isCollectionItemType,
isPolymorphicItemType,
} from '../types/constants';
import {
BooleanCollectionInputFieldTemplate,
BooleanInputFieldTemplate,
ClipInputFieldTemplate,
CollectionInputFieldTemplate,
@@ -8,12 +17,16 @@ import {
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatCollectionInputFieldTemplate,
FloatPolymorphicInputFieldTemplate,
FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate,
InputFieldTemplateBase,
IntegerCollectionInputFieldTemplate,
IntegerInputFieldTemplate,
InvocationFieldSchema,
InvocationSchemaObject,
@@ -23,12 +36,32 @@ import {
SDXLMainModelInputFieldTemplate,
SDXLRefinerModelInputFieldTemplate,
SchedulerInputFieldTemplate,
StringCollectionInputFieldTemplate,
StringInputFieldTemplate,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
isFieldType,
isArraySchemaObject,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
ControlPolymorphicInputFieldTemplate,
ColorPolymorphicInputFieldTemplate,
ColorCollectionInputFieldTemplate,
IntegerPolymorphicInputFieldTemplate,
StringPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldTemplate,
LatentsPolymorphicInputFieldTemplate,
LatentsCollectionInputFieldTemplate,
ConditioningPolymorphicInputFieldTemplate,
ConditioningCollectionInputFieldTemplate,
ControlCollectionInputFieldTemplate,
ImageField,
LatentsField,
ConditioningField,
} from '../types/types';
import { ControlField } from 'services/api/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
* @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/
export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
): FieldType => {
const name = refObject.$ref.split('/').slice(-1)[0];
if (!name) {
throw `Unknown field type: ${name}`;
}
return name as FieldType;
};
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
refObject.$ref.split('/').slice(-1)[0];
const buildIntegerInputFieldTemplate = ({
schemaObject,
@@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
return template;
};
const buildIntegerPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
const template: IntegerPolymorphicInputFieldTemplate = {
...baseField,
type: 'IntegerPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildIntegerCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
const item_default =
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: IntegerCollectionInputFieldTemplate = {
...baseField,
type: 'IntegerCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildFloatInputFieldTemplate = ({
schemaObject,
baseField,
@@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
return template;
};
const buildFloatPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
const template: FloatPolymorphicInputFieldTemplate = {
...baseField,
type: 'FloatPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
const item_default = isNumber(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: FloatCollectionInputFieldTemplate = {
...baseField,
type: 'FloatCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildStringInputFieldTemplate = ({
schemaObject,
baseField,
@@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
return template;
};
const buildStringPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
const template: StringPolymorphicInputFieldTemplate = {
...baseField,
type: 'StringPolymorphic',
default: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.pattern !== undefined) {
template.pattern = schemaObject.pattern;
}
return template;
};
const buildStringCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
const item_default = isString(schemaObject.item_default)
? schemaObject.item_default
: '';
const template: StringCollectionInputFieldTemplate = {
...baseField,
type: 'StringCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildBooleanInputFieldTemplate = ({
schemaObject,
baseField,
@@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
return template;
};
const buildBooleanPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
const template: BooleanPolymorphicInputFieldTemplate = {
...baseField,
type: 'BooleanPolymorphic',
default: schemaObject.default ?? false,
};
return template;
};
const buildBooleanCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
const item_default =
schemaObject.item_default && isBoolean(schemaObject.item_default)
? schemaObject.item_default
: false;
const template: BooleanCollectionInputFieldTemplate = {
...baseField,
type: 'BooleanCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildMainModelInputFieldTemplate = ({
schemaObject,
baseField,
@@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
return template;
};
const buildImagePolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
const template: ImagePolymorphicInputFieldTemplate = {
...baseField,
type: 'ImagePolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageCollectionInputFieldTemplate = ({
schemaObject,
baseField,
@@ -257,6 +468,20 @@ const buildImageCollectionInputFieldTemplate = ({
const template: ImageCollectionInputFieldTemplate = {
...baseField,
type: 'ImageCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ImageField) ?? undefined,
};
return template;
};
const buildDenoiseMaskInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => {
const template: DenoiseMaskInputFieldTemplate = {
...baseField,
type: 'DenoiseMaskField',
default: schemaObject.default ?? undefined,
};
@@ -276,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
return template;
};
const buildLatentsPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
const template: LatentsPolymorphicInputFieldTemplate = {
...baseField,
type: 'LatentsPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLatentsCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
const template: LatentsCollectionInputFieldTemplate = {
...baseField,
type: 'LatentsCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
};
return template;
};
const buildConditioningInputFieldTemplate = ({
schemaObject,
baseField,
@@ -289,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
return template;
};
const buildConditioningPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
const template: ConditioningPolymorphicInputFieldTemplate = {
...baseField,
type: 'ConditioningPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildConditioningCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
const template: ConditioningCollectionInputFieldTemplate = {
...baseField,
type: 'ConditioningCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
};
return template;
};
const buildUNetInputFieldTemplate = ({
schemaObject,
baseField,
@@ -342,6 +621,33 @@ const buildControlInputFieldTemplate = ({
return template;
};
const buildControlPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
const template: ControlPolymorphicInputFieldTemplate = {
...baseField,
type: 'ControlPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
const template: ControlCollectionInputFieldTemplate = {
...baseField,
type: 'ControlCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ControlField) ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -395,6 +701,32 @@ const buildColorInputFieldTemplate = ({
return template;
};
const buildColorPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
const template: ColorPolymorphicInputFieldTemplate = {
...baseField,
type: 'ColorPolymorphic',
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
return template;
};
const buildColorCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
const template: ColorCollectionInputFieldTemplate = {
...baseField,
type: 'ColorCollection',
default: schemaObject.default ?? [],
};
return template;
};
const buildSchedulerInputFieldTemplate = ({
schemaObject,
baseField,
@@ -410,49 +742,136 @@ const buildSchedulerInputFieldTemplate = ({
export const getFieldType = (
schemaObject: InvocationFieldSchema
): FieldType => {
let fieldType = '';
const { ui_type } = schemaObject;
if (ui_type) {
fieldType = ui_type;
): string | undefined => {
if (schemaObject?.ui_type) {
return schemaObject.ui_type;
} else if (!schemaObject.type) {
// console.log('refObject', schemaObject);
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
const allOf = schemaObject.allOf;
if (allOf && allOf[0] && isRefObject(allOf[0])) {
return refObjectToSchemaName(allOf[0]);
}
} else if (schemaObject.anyOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
const anyOf = schemaObject.anyOf;
/**
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
* - an `anyOf` with two items
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
* - the other is a `SchemaObject` or `ReferenceObject` of type T
*
* Any other cases we ignore.
*/
let firstType: string | undefined;
let secondType: string | undefined;
if (isArraySchemaObject(anyOf[0])) {
// first is array, second is not
const first = anyOf[0].items;
const second = anyOf[1];
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
} else if (isArraySchemaObject(anyOf[1])) {
// first is not array, second is
const first = anyOf[0];
const second = anyOf[1].items;
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
}
if (firstType === secondType && isPolymorphicItemType(firstType)) {
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
}
}
} else if (schemaObject.enum) {
fieldType = 'enum';
return 'enum';
} else if (schemaObject.type) {
if (schemaObject.type === 'number') {
// floats are "number" in OpenAPI, while ints are "integer"
fieldType = 'float';
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
return 'float';
} else if (schemaObject.type === 'array') {
const itemType = isSchemaObject(schemaObject.items)
? schemaObject.items.type
: refObjectToSchemaName(schemaObject.items);
if (isCollectionItemType(itemType)) {
return COLLECTION_MAP[itemType];
}
return;
} else {
fieldType = schemaObject.type;
return schemaObject.type;
}
}
if (!isFieldType(fieldType)) {
throw `Field type "${fieldType}" is unknown!`;
}
return fieldType;
return;
};
const TEMPLATE_BUILDER_MAP = {
boolean: buildBooleanInputFieldTemplate,
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
ClipField: buildClipInputFieldTemplate,
Collection: buildCollectionInputFieldTemplate,
CollectionItem: buildCollectionItemInputFieldTemplate,
ColorCollection: buildColorCollectionInputFieldTemplate,
ColorField: buildColorInputFieldTemplate,
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
ConditioningField: buildConditioningInputFieldTemplate,
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
ControlCollection: buildControlCollectionInputFieldTemplate,
ControlField: buildControlInputFieldTemplate,
ControlNetModelField: buildControlNetModelInputFieldTemplate,
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
enum: buildEnumInputFieldTemplate,
float: buildFloatInputFieldTemplate,
FloatCollection: buildFloatCollectionInputFieldTemplate,
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
ImageCollection: buildImageCollectionInputFieldTemplate,
ImageField: buildImageInputFieldTemplate,
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
LoRAModelField: buildLoRAModelInputFieldTemplate,
MainModelField: buildMainModelInputFieldTemplate,
Scheduler: buildSchedulerInputFieldTemplate,
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
string: buildStringInputFieldTemplate,
StringCollection: buildStringCollectionInputFieldTemplate,
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate,
};
const isTemplatedFieldType = (
fieldType: string | undefined
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
/**
* Builds an input field from an invocation schema property.
* @param fieldSchema The schema object
@@ -461,16 +880,14 @@ export const getFieldType = (
export const buildInputFieldTemplate = (
nodeSchema: InvocationSchemaObject,
fieldSchema: InvocationFieldSchema,
name: string
name: string,
fieldType: FieldType
) => {
// console.log('input', schemaObject);
const fieldType = getFieldType(fieldSchema);
// console.log('input fieldType', fieldType);
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
const extra = {
input,
// TODO: Can we support polymorphic inputs in the UI?
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
ui_hidden,
ui_component,
ui_type,
@@ -486,140 +903,12 @@ export const buildInputFieldTemplate = (
...extra,
};
if (fieldType === 'ImageField') {
return buildImageInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
if (!isTemplatedFieldType(fieldType)) {
return;
}
if (fieldType === 'ImageCollection') {
return buildImageCollectionInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LatentsField') {
return buildLatentsInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ConditioningField') {
return buildConditioningInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'UNetField') {
return buildUNetInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ClipField') {
return buildClipInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeField') {
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
}
if (fieldType === 'ControlField') {
return buildControlInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'MainModelField') {
return buildMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLRefinerModelField') {
return buildRefinerModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLMainModelField') {
return buildSDXLMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeModelField') {
return buildVaeModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LoRAModelField') {
return buildLoRAModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ControlNetModelField') {
return buildControlNetModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'enum') {
return buildEnumInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'integer') {
return buildIntegerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'float') {
return buildFloatInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'string') {
return buildStringInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'boolean') {
return buildBooleanInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Collection') {
return buildCollectionInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'CollectionItem') {
return buildCollectionItemInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ColorField') {
return buildColorInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Scheduler') {
return buildSchedulerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
return;
return TEMPLATE_BUILDER_MAP[fieldType]({
schemaObject: fieldSchema,
baseField,
});
};

View File

@@ -1,100 +1,79 @@
import { InputFieldTemplate, InputFieldValue } from '../types/types';
const FIELD_VALUE_FALLBACK_MAP = {
'enum.number': 0,
'enum.string': '',
boolean: false,
BooleanCollection: [],
BooleanPolymorphic: false,
ClipField: undefined,
Collection: [],
CollectionItem: undefined,
ColorCollection: [],
ColorField: undefined,
ColorPolymorphic: undefined,
ConditioningCollection: [],
ConditioningField: undefined,
ConditioningPolymorphic: undefined,
ControlCollection: [],
ControlField: undefined,
ControlNetModelField: undefined,
ControlPolymorphic: undefined,
DenoiseMaskField: undefined,
float: 0,
FloatCollection: [],
FloatPolymorphic: 0,
ImageCollection: [],
ImageField: undefined,
ImagePolymorphic: undefined,
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,
LoRAModelField: undefined,
MainModelField: undefined,
ONNXModelField: undefined,
Scheduler: 'euler',
SDXLMainModelField: undefined,
SDXLRefinerModelField: undefined,
string: '',
StringCollection: [],
StringPolymorphic: '',
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
};
export const buildInputFieldValue = (
id: string,
template: InputFieldTemplate
): InputFieldValue => {
const fieldValue: InputFieldValue = {
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
// `InputFieldValue` union, but TS doesn't seem to like it...
const fieldValue = {
id,
name: template.name,
type: template.type,
label: '',
fieldKind: 'input',
};
if (template.type === 'string') {
fieldValue.value = template.default ?? '';
}
if (template.type === 'integer') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'float') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'boolean') {
fieldValue.value = template.default ?? false;
}
} as InputFieldValue;
if (template.type === 'enum') {
if (template.enumType === 'number') {
fieldValue.value = template.default ?? 0;
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
}
if (template.enumType === 'string') {
fieldValue.value = template.default ?? '';
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
}
}
if (template.type === 'Collection') {
fieldValue.value = template.default ?? 1;
}
if (template.type === 'ImageField') {
fieldValue.value = undefined;
}
if (template.type === 'ImageCollection') {
fieldValue.value = [];
}
if (template.type === 'LatentsField') {
fieldValue.value = undefined;
}
if (template.type === 'ConditioningField') {
fieldValue.value = undefined;
}
if (template.type === 'UNetField') {
fieldValue.value = undefined;
}
if (template.type === 'ClipField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlField') {
fieldValue.value = undefined;
}
if (template.type === 'MainModelField') {
fieldValue.value = undefined;
}
if (template.type === 'SDXLRefinerModelField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeModelField') {
fieldValue.value = undefined;
}
if (template.type === 'LoRAModelField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlNetModelField') {
fieldValue.value = undefined;
}
if (template.type === 'Scheduler') {
fieldValue.value = 'euler';
} else {
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
}
return fieldValue;

View File

@@ -0,0 +1,45 @@
import * as png from '@stevebel/png';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
zWorkflow,
} from 'features/nodes/types/types';
import { get } from 'lodash-es';
export const getMetadataAndWorkflowFromImageBlob = async (
image: Blob
): Promise<ImageMetadataAndWorkflow> => {
const data: ImageMetadataAndWorkflow = {};
const buffer = await image.arrayBuffer();
const text = png.decode(buffer).text;
const rawMetadata = get(text, 'invokeai_metadata');
if (rawMetadata) {
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
if (metadataResult.success) {
data.metadata = metadataResult.data;
} else {
logger('system').error(
{ error: parseify(metadataResult.error) },
'Problem reading metadata from image'
);
}
}
const rawWorkflow = get(text, 'invokeai_workflow');
if (rawWorkflow) {
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
if (workflowResult.success) {
data.workflow = workflowResult.data;
} else {
logger('system').error(
{ error: parseify(workflowResult.error) },
'Problem reading workflow from image'
);
}
}
return data;
};

Some files were not shown because too many files have changed in this diff Show More