mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-06 08:35:10 -05:00
Merge branch 'main' into refactor/rename-get-logger
This commit is contained in:
@@ -60,7 +60,7 @@ class Config:
|
||||
thumbnail_path = None
|
||||
|
||||
def find_and_load(self):
|
||||
"""find the yaml config file and load"""
|
||||
"""Find the yaml config file and load"""
|
||||
root = app_config.root_path
|
||||
if not self.confirm_and_load(os.path.abspath(root)):
|
||||
print("\r\nSpecify custom database and outputs paths:")
|
||||
@@ -70,7 +70,7 @@ class Config:
|
||||
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
|
||||
|
||||
def confirm_and_load(self, invoke_root):
|
||||
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
|
||||
"""Validate a yaml path exists, confirms the user wants to use it and loads config."""
|
||||
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
|
||||
if os.path.exists(yaml_path):
|
||||
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
|
||||
@@ -337,33 +337,24 @@ class InvokeAIMetadataParser:
|
||||
|
||||
def map_scheduler(self, old_scheduler):
|
||||
"""Convert the legacy sampler names to matching 3.0 schedulers"""
|
||||
|
||||
# this was more elegant as a case statement, but that's not available in python 3.9
|
||||
if old_scheduler is None:
|
||||
return None
|
||||
|
||||
match (old_scheduler):
|
||||
case "ddim":
|
||||
return "ddim"
|
||||
case "plms":
|
||||
return "pnmd"
|
||||
case "k_lms":
|
||||
return "lms"
|
||||
case "k_dpm_2":
|
||||
return "kdpm_2"
|
||||
case "k_dpm_2_a":
|
||||
return "kdpm_2_a"
|
||||
case "dpmpp_2":
|
||||
return "dpmpp_2s"
|
||||
case "k_dpmpp_2":
|
||||
return "dpmpp_2m"
|
||||
case "k_dpmpp_2_a":
|
||||
return None # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
||||
case "k_euler":
|
||||
return "euler"
|
||||
case "k_euler_a":
|
||||
return "euler_a"
|
||||
case "k_heun":
|
||||
return "heun"
|
||||
return None
|
||||
scheduler_map = dict(
|
||||
ddim="ddim",
|
||||
plms="pnmd",
|
||||
k_lms="lms",
|
||||
k_dpm_2="kdpm_2",
|
||||
k_dpm_2_a="kdpm_2_a",
|
||||
dpmpp_2="dpmpp_2s",
|
||||
k_dpmpp_2="dpmpp_2m",
|
||||
k_dpmpp_2_a=None, # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
|
||||
k_euler="euler",
|
||||
k_euler_a="euler_a",
|
||||
k_heun="heun",
|
||||
)
|
||||
return scheduler_map.get(old_scheduler)
|
||||
|
||||
def split_prompt(self, raw_prompt: str):
|
||||
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
|
||||
@@ -524,27 +515,27 @@ class MediaImportProcessor:
|
||||
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
|
||||
)
|
||||
input_option = input("Specify desired board option: ")
|
||||
match (input_option):
|
||||
case "1":
|
||||
if len(board_names) < 1:
|
||||
print("\r\nThere are no existing board names to choose from. Select another option!")
|
||||
continue
|
||||
board_name = self.select_item_from_list(
|
||||
board_names, "board name", True, "Cancel, go back and choose a different board option."
|
||||
)
|
||||
if board_name is not None:
|
||||
# This was more elegant as a case statement, but not supported in python 3.9
|
||||
if input_option == "1":
|
||||
if len(board_names) < 1:
|
||||
print("\r\nThere are no existing board names to choose from. Select another option!")
|
||||
continue
|
||||
board_name = self.select_item_from_list(
|
||||
board_names, "board name", True, "Cancel, go back and choose a different board option."
|
||||
)
|
||||
if board_name is not None:
|
||||
return board_name
|
||||
elif input_option == "2":
|
||||
while True:
|
||||
board_name = input("Specify new/existing board name: ")
|
||||
if board_name:
|
||||
return board_name
|
||||
case "2":
|
||||
while True:
|
||||
board_name = input("Specify new/existing board name: ")
|
||||
if board_name:
|
||||
return board_name
|
||||
case "3":
|
||||
return "IMPORT"
|
||||
case "4":
|
||||
return f"IMPORT_{timestamp_string}"
|
||||
case "5":
|
||||
return "IMPORT_APPVERSION"
|
||||
elif input_option == "3":
|
||||
return "IMPORT"
|
||||
elif input_option == "4":
|
||||
return f"IMPORT_{timestamp_string}"
|
||||
elif input_option == "5":
|
||||
return "IMPORT_APPVERSION"
|
||||
|
||||
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
|
||||
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
|
||||
|
||||
@@ -7,5 +7,4 @@ stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
src/services/api/schema.d.ts
|
||||
|
||||
@@ -7,8 +7,7 @@ index.html
|
||||
.yarn/
|
||||
.yalc/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
src/services/api/schema.d.ts
|
||||
docs/
|
||||
static/
|
||||
src/theme/css/overlayscrollbars.css
|
||||
|
||||
171
invokeai/frontend/web/dist/assets/App-78495256.js
vendored
Normal file
171
invokeai/frontend/web/dist/assets/App-78495256.js
vendored
Normal file
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-7d912410.js
vendored
169
invokeai/frontend/web/dist/assets/App-7d912410.js
vendored
File diff suppressed because one or more lines are too long
310
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-707a230a.js
vendored
Normal file
310
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-707a230a.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
||||
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1E00-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
|
||||
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0304,U+0308,U+0329,U+1E00-1E9F,U+1EF2-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
|
||||
* OverlayScrollbars
|
||||
* Version: 2.2.1
|
||||
*
|
||||
File diff suppressed because one or more lines are too long
126
invokeai/frontend/web/dist/assets/index-08cda350.js
vendored
Normal file
126
invokeai/frontend/web/dist/assets/index-08cda350.js
vendored
Normal file
File diff suppressed because one or more lines are too long
151
invokeai/frontend/web/dist/assets/index-2c171c8f.js
vendored
151
invokeai/frontend/web/dist/assets/index-2c171c8f.js
vendored
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/menu-3d10c968.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/menu-3d10c968.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-2c171c8f.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
||||
42
invokeai/frontend/web/dist/locales/en.json
vendored
42
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -19,7 +19,7 @@
|
||||
"toggleAutoscroll": "Toggle autoscroll",
|
||||
"toggleLogViewer": "Toggle Log Viewer",
|
||||
"showGallery": "Show Gallery",
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"showOptionsPanel": "Show Side Panel",
|
||||
"menu": "Menu"
|
||||
},
|
||||
"common": {
|
||||
@@ -52,7 +52,7 @@
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"nodes": "Workflow Editor",
|
||||
"batch": "Batch Manager",
|
||||
"modelManager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
@@ -95,7 +95,6 @@
|
||||
"statusModelConverted": "Model Converted",
|
||||
"statusMergingModels": "Merging Models",
|
||||
"statusMergedModels": "Models Merged",
|
||||
"pinOptionsPanel": "Pin Options Panel",
|
||||
"loading": "Loading",
|
||||
"loadingInvokeAI": "Loading Invoke AI",
|
||||
"random": "Random",
|
||||
@@ -116,7 +115,6 @@
|
||||
"maintainAspectRatio": "Maintain Aspect Ratio",
|
||||
"autoSwitchNewImages": "Auto-Switch to New Images",
|
||||
"singleColumnLayout": "Single Column Layout",
|
||||
"pinGallery": "Pin Gallery",
|
||||
"allImagesLoaded": "All Images Loaded",
|
||||
"loadMore": "Load More",
|
||||
"noImagesInGallery": "No Images to Display",
|
||||
@@ -133,6 +131,7 @@
|
||||
"generalHotkeys": "General Hotkeys",
|
||||
"galleryHotkeys": "Gallery Hotkeys",
|
||||
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
|
||||
"nodesHotkeys": "Nodes Hotkeys",
|
||||
"invoke": {
|
||||
"title": "Invoke",
|
||||
"desc": "Generate an image"
|
||||
@@ -332,6 +331,10 @@
|
||||
"acceptStagingImage": {
|
||||
"title": "Accept Staging Image",
|
||||
"desc": "Accept Current Staging Area Image"
|
||||
},
|
||||
"addNodes": {
|
||||
"title": "Add Nodes",
|
||||
"desc": "Opens the add node menu"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@@ -503,13 +506,15 @@
|
||||
"hiresStrength": "High Res Strength",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"codeformerFidelity": "Fidelity",
|
||||
"compositingSettingsHeader": "Compositing Settings",
|
||||
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||
"maskBlur": "Mask Blur",
|
||||
"maskBlurMethod": "Mask Blur Method",
|
||||
"seamSize": "Seam Size",
|
||||
"seamBlur": "Seam Blur",
|
||||
"seamStrength": "Seam Strength",
|
||||
"seamSteps": "Seam Steps",
|
||||
"maskBlur": "Blur",
|
||||
"maskBlurMethod": "Blur Method",
|
||||
"coherencePassHeader": "Coherence Pass",
|
||||
"coherenceSteps": "Steps",
|
||||
"coherenceStrength": "Strength",
|
||||
"seamLowThreshold": "Low",
|
||||
"seamHighThreshold": "High",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
"scaledWidth": "Scaled W",
|
||||
"scaledHeight": "Scaled H",
|
||||
@@ -565,10 +570,11 @@
|
||||
"useSlidersForAll": "Use Sliders For All Options",
|
||||
"showProgressInViewer": "Show Progress Images in Viewer",
|
||||
"antialiasProgressImages": "Antialias Progress Images",
|
||||
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
|
||||
"resetComplete": "Web UI has been reset.",
|
||||
"consoleLogLevel": "Log Level",
|
||||
"shouldLogToConsole": "Console Logging",
|
||||
"developer": "Developer",
|
||||
@@ -708,14 +714,16 @@
|
||||
"ui": {
|
||||
"showProgressImages": "Show Progress Images",
|
||||
"hideProgressImages": "Hide Progress Images",
|
||||
"swapSizes": "Swap Sizes"
|
||||
"swapSizes": "Swap Sizes",
|
||||
"lockRatio": "Lock Ratio"
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveGraph": "Save Graph",
|
||||
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
||||
"clearGraph": "Clear Graph",
|
||||
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
||||
"reloadNodeTemplates": "Reload Node Templates",
|
||||
"downloadWorkflow": "Download Workflow JSON",
|
||||
"loadWorkflow": "Load Workflow",
|
||||
"resetWorkflow": "Reset Workflow",
|
||||
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
||||
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
|
||||
"zoomInNodes": "Zoom In",
|
||||
"zoomOutNodes": "Zoom Out",
|
||||
"fitViewportNodes": "Fit View",
|
||||
|
||||
@@ -74,6 +74,8 @@
|
||||
"@nanostores/react": "^0.7.1",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"@stevebel/png": "^1.5.1",
|
||||
"compare-versions": "^6.1.0",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.4.3",
|
||||
"framer-motion": "^10.16.1",
|
||||
@@ -110,6 +112,7 @@
|
||||
"roarr": "^7.15.1",
|
||||
"serialize-error": "^11.0.1",
|
||||
"socket.io-client": "^4.7.2",
|
||||
"type-fest": "^4.2.0",
|
||||
"use-debounce": "^9.0.4",
|
||||
"use-image": "^1.1.1",
|
||||
"uuid": "^9.0.0",
|
||||
|
||||
@@ -506,12 +506,13 @@
|
||||
"hiresStrength": "High Res Strength",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"codeformerFidelity": "Fidelity",
|
||||
"compositingSettingsHeader": "Compositing Settings",
|
||||
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||
"maskBlur": "Mask Blur",
|
||||
"maskBlurMethod": "Mask Blur Method",
|
||||
"maskBlur": "Blur",
|
||||
"maskBlurMethod": "Blur Method",
|
||||
"coherencePassHeader": "Coherence Pass",
|
||||
"coherenceSteps": "Coherence Pass Steps",
|
||||
"coherenceStrength": "Coherence Pass Strength",
|
||||
"coherenceSteps": "Steps",
|
||||
"coherenceStrength": "Strength",
|
||||
"seamLowThreshold": "Low",
|
||||
"seamHighThreshold": "High",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
@@ -569,6 +570,7 @@
|
||||
"useSlidersForAll": "Use Sliders For All Options",
|
||||
"showProgressInViewer": "Show Progress Images in Viewer",
|
||||
"antialiasProgressImages": "Antialias Progress Images",
|
||||
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||
@@ -712,11 +714,12 @@
|
||||
"ui": {
|
||||
"showProgressImages": "Show Progress Images",
|
||||
"hideProgressImages": "Hide Progress Images",
|
||||
"swapSizes": "Swap Sizes"
|
||||
"swapSizes": "Swap Sizes",
|
||||
"lockRatio": "Lock Ratio"
|
||||
},
|
||||
"nodes": {
|
||||
"reloadNodeTemplates": "Reload Node Templates",
|
||||
"saveWorkflow": "Save Workflow",
|
||||
"downloadWorkflow": "Download Workflow JSON",
|
||||
"loadWorkflow": "Load Workflow",
|
||||
"resetWorkflow": "Reset Workflow",
|
||||
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
||||
|
||||
@@ -14,6 +14,7 @@ import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
import { ReactNode, memo, useCallback, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage';
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
@@ -23,13 +24,22 @@ const DEFAULT_CONFIG = {};
|
||||
interface Props {
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
selectedImage?: {
|
||||
imageName: string;
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
headerComponent,
|
||||
selectedImage,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
const { handlePreselectedImage } = usePreselectedImage();
|
||||
const handleReset = useCallback(() => {
|
||||
localStorage.clear();
|
||||
location.reload();
|
||||
@@ -51,6 +61,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
dispatch(appStarted());
|
||||
}, [dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
handlePreselectedImage(selectedImage);
|
||||
}, [handlePreselectedImage, selectedImage]);
|
||||
|
||||
return (
|
||||
<ErrorBoundary
|
||||
onReset={handleReset}
|
||||
|
||||
@@ -26,6 +26,10 @@ interface Props extends PropsWithChildren {
|
||||
headerComponent?: ReactNode;
|
||||
middleware?: Middleware[];
|
||||
projectId?: string;
|
||||
selectedImage?: {
|
||||
imageName: string;
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({
|
||||
@@ -35,6 +39,7 @@ const InvokeAIUI = ({
|
||||
headerComponent,
|
||||
middleware,
|
||||
projectId,
|
||||
selectedImage,
|
||||
}: Props) => {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
@@ -81,7 +86,11 @@ const InvokeAIUI = ({
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<AppDndContext>
|
||||
<App config={config} headerComponent={headerComponent} />
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
selectedImage={selectedImage}
|
||||
/>
|
||||
</AppDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
|
||||
@@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
|
||||
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
|
||||
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
@@ -41,6 +43,8 @@ import {
|
||||
addImageUploadedFulfilledListener,
|
||||
addImageUploadedRejectedListener,
|
||||
} from './listeners/imageUploaded';
|
||||
import { addImagesStarredListener } from './listeners/imagesStarred';
|
||||
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||
@@ -80,8 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addImagesStarredListener } from './listeners/imagesStarred';
|
||||
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@@ -137,6 +140,8 @@ addSessionReadyToInvokeListener();
|
||||
// Canvas actions
|
||||
addCanvasSavedToGalleryListener();
|
||||
addCanvasMaskSavedToGalleryListener();
|
||||
addCanvasImageToControlNetListener();
|
||||
addCanvasMaskToControlNetListener();
|
||||
addCanvasDownloadedAsImageListener();
|
||||
addCanvasCopiedToClipboardListener();
|
||||
addCanvasMergedListener();
|
||||
@@ -198,6 +203,9 @@ addBoardIdSelectedListener();
|
||||
// Node schemas
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Workflows
|
||||
addWorkflowLoadedListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { canvasImageToControlNet } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addCanvasImageToControlNetListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasImageToControlNet,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
log.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Saving Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
|
||||
const imageDTO = await dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({
|
||||
file: new File([blob], 'savedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'mask',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: true,
|
||||
postUploadAction: {
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
|
||||
},
|
||||
})
|
||||
).unwrap();
|
||||
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId: action.payload.controlNet.controlNetId,
|
||||
controlImage: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,70 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addCanvasMaskToControlNetListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasMaskToControlNet,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
|
||||
const canvasBlobsAndImageData = await getCanvasData(
|
||||
state.canvas.layerState,
|
||||
state.canvas.boundingBoxCoordinates,
|
||||
state.canvas.boundingBoxDimensions,
|
||||
state.canvas.isMaskEnabled,
|
||||
state.canvas.shouldPreserveMaskedArea
|
||||
);
|
||||
|
||||
if (!canvasBlobsAndImageData) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { maskBlob } = canvasBlobsAndImageData;
|
||||
|
||||
if (!maskBlob) {
|
||||
log.error('Problem getting mask layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Importing Mask',
|
||||
description: 'Unable to export mask',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
|
||||
const imageDTO = await dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({
|
||||
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'mask',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: true,
|
||||
postUploadAction: {
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
|
||||
},
|
||||
})
|
||||
).unwrap();
|
||||
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId: action.payload.controlNet.controlNetId,
|
||||
controlImage: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,9 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
modelChanged,
|
||||
setHeight,
|
||||
setWidth,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
|
||||
@@ -74,6 +77,22 @@ export const addModelSelectedListener = () => {
|
||||
}
|
||||
}
|
||||
|
||||
// Update Width / Height / Bounding Box Dimensions on Model Change
|
||||
if (
|
||||
state.generation.model?.base_model !== newModel.base_model &&
|
||||
state.ui.shouldAutoChangeDimensions
|
||||
) {
|
||||
if (['sdxl', 'sdxl-refiner'].includes(newModel.base_model)) {
|
||||
dispatch(setWidth(1024));
|
||||
dispatch(setHeight(1024));
|
||||
dispatch(setBoundingBoxDimensions({ width: 1024, height: 1024 }));
|
||||
} else {
|
||||
dispatch(setWidth(512));
|
||||
dispatch(setHeight(512));
|
||||
dispatch(setBoundingBoxDimensions({ width: 512, height: 512 }));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(modelChanged(newModel));
|
||||
},
|
||||
});
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addWorkflowLoadedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
|
||||
if (!errors.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded with Warnings',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
errors.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -6,11 +6,11 @@ import {
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
|
||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
|
||||
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
|
||||
import loraReducer from 'features/lora/store/loraSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
|
||||
@@ -86,8 +86,8 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
pt: 3,
|
||||
p: 4,
|
||||
pb: 4,
|
||||
borderBottomRadius: 'base',
|
||||
bg: 'base.150',
|
||||
_dark: {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import {
|
||||
KeyboardEvent,
|
||||
ReactNode,
|
||||
@@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isSelected: boolean;
|
||||
isHovered: boolean;
|
||||
};
|
||||
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
const shadow = useMemo(() => {
|
||||
if (isSelected && isHovered) {
|
||||
return 'nodeHoveredSelected.light';
|
||||
}
|
||||
if (isSelected) {
|
||||
return 'nodeSelected.light';
|
||||
}
|
||||
if (isHovered) {
|
||||
return 'nodeHovered.light';
|
||||
}
|
||||
return undefined;
|
||||
}, [isHovered, isSelected]);
|
||||
const shadowDark = useMemo(() => {
|
||||
if (isSelected && isHovered) {
|
||||
return 'nodeHoveredSelected.dark';
|
||||
}
|
||||
if (isSelected) {
|
||||
return 'nodeSelected.dark';
|
||||
}
|
||||
if (isHovered) {
|
||||
return 'nodeHovered.dark';
|
||||
}
|
||||
return undefined;
|
||||
}, [isHovered, isSelected]);
|
||||
return (
|
||||
<Box
|
||||
className="selection-box"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
bottom: 0,
|
||||
insetInlineStart: 0,
|
||||
borderRadius: 'base',
|
||||
opacity: isSelected || isHovered ? 1 : 0.5,
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
pointerEvents: 'none',
|
||||
shadow,
|
||||
_dark: {
|
||||
shadow: shadowDark,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SelectionOverlay);
|
||||
@@ -63,7 +63,11 @@ const selector = createSelector(
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && !field.value && !hasConnection) {
|
||||
if (
|
||||
fieldTemplate.required &&
|
||||
field.value === undefined &&
|
||||
!hasConnection
|
||||
) {
|
||||
reasons.push(
|
||||
`${node.data.label || nodeTemplate.title} -> ${
|
||||
field.label || fieldTemplate.title
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
export const colorTokenToCssVar = (colorToken: string) =>
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
||||
@@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
|
||||
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
|
||||
'canvas/stagingAreaImageSaved'
|
||||
);
|
||||
|
||||
export const canvasMaskToControlNet = createAction<{
|
||||
controlNet: ControlNetConfig;
|
||||
}>('canvas/canvasMaskToControlNet');
|
||||
|
||||
export const canvasImageToControlNet = createAction<{
|
||||
controlNet: ControlNetConfig;
|
||||
}>('canvas/canvasImageToControlNet');
|
||||
|
||||
@@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useToggle } from 'react-use';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import ControlNetImagePreview from './ControlNetImagePreview';
|
||||
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
|
||||
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||
@@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
const { controlNetId } = controlNet;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
@@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
>
|
||||
<ParamControlNetModel controlNet={controlNet} />
|
||||
</Box>
|
||||
{activeTabName === 'unifiedCanvas' && (
|
||||
<ControlNetCanvasImageImports controlNet={controlNet} />
|
||||
)}
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
tooltip="Duplicate"
|
||||
@@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||
<Flex
|
||||
|
||||
@@ -5,13 +5,21 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
|
||||
import {
|
||||
useAddImageToBoardMutation,
|
||||
useChangeImageIsIntermediateMutation,
|
||||
useGetImageDTOQuery,
|
||||
useRemoveImageFromBoardMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
|
||||
import {
|
||||
@@ -26,11 +34,13 @@ type Props = {
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
({ controlNet, gallery }) => {
|
||||
const { pendingControlImages } = controlNet;
|
||||
const { autoAddBoardId } = gallery;
|
||||
|
||||
return {
|
||||
pendingControlImages,
|
||||
autoAddBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@@ -47,7 +57,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { pendingControlImages } = useAppSelector(selector);
|
||||
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
@@ -59,9 +70,57 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
processedControlImageName ?? skipToken
|
||||
);
|
||||
|
||||
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
|
||||
const [addToBoard] = useAddImageToBoardMutation();
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleSaveControlImage = useCallback(async () => {
|
||||
if (!processedControlImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
await changeIsIntermediate({
|
||||
imageDTO: processedControlImage,
|
||||
is_intermediate: false,
|
||||
}).unwrap();
|
||||
|
||||
if (autoAddBoardId !== 'none') {
|
||||
addToBoard({
|
||||
imageDTO: processedControlImage,
|
||||
board_id: autoAddBoardId,
|
||||
});
|
||||
} else {
|
||||
removeFromBoard({ imageDTO: processedControlImage });
|
||||
}
|
||||
}, [
|
||||
processedControlImage,
|
||||
changeIsIntermediate,
|
||||
autoAddBoardId,
|
||||
addToBoard,
|
||||
removeFromBoard,
|
||||
]);
|
||||
|
||||
const handleSetControlImageToDimensions = useCallback(() => {
|
||||
if (!controlImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: controlImage.width,
|
||||
height: controlImage.height,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
dispatch(setWidth(controlImage.width));
|
||||
dispatch(setHeight(controlImage.height));
|
||||
}
|
||||
}, [controlImage, activeTabName, dispatch]);
|
||||
|
||||
const handleMouseEnter = useCallback(() => {
|
||||
setIsMouseOverImage(true);
|
||||
}, []);
|
||||
@@ -121,13 +180,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||
postUploadAction={postUploadAction}
|
||||
>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <FaUndo /> : undefined}
|
||||
tooltip="Reset Control Image"
|
||||
/>
|
||||
</IAIDndImage>
|
||||
/>
|
||||
|
||||
<Box
|
||||
sx={{
|
||||
@@ -148,14 +201,29 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
isDropDisabled={!isEnabled}
|
||||
>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <FaUndo /> : undefined}
|
||||
tooltip="Reset Control Image"
|
||||
/>
|
||||
</IAIDndImage>
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <FaUndo /> : undefined}
|
||||
tooltip="Reset Control Image"
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <FaSave size={16} /> : undefined}
|
||||
tooltip="Save Control Image"
|
||||
styleOverrides={{ marginTop: 6 }}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <FaRulerVertical size={16} /> : undefined}
|
||||
tooltip="Set Control Image Dimensions To W/H"
|
||||
styleOverrides={{ marginTop: 12 }}
|
||||
/>
|
||||
</>
|
||||
|
||||
{pendingControlImages.includes(controlNetId) && (
|
||||
<Flex
|
||||
sx={{
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import {
|
||||
canvasImageToControlNet,
|
||||
canvasMaskToControlNet,
|
||||
} from 'features/canvas/store/actions';
|
||||
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaImage, FaMask } from 'react-icons/fa';
|
||||
|
||||
type ControlNetCanvasImageImportsProps = {
|
||||
controlNet: ControlNetConfig;
|
||||
};
|
||||
|
||||
const ControlNetCanvasImageImports = (
|
||||
props: ControlNetCanvasImageImportsProps
|
||||
) => {
|
||||
const { controlNet } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleImportImageFromCanvas = useCallback(() => {
|
||||
dispatch(canvasImageToControlNet({ controlNet }));
|
||||
}, [controlNet, dispatch]);
|
||||
|
||||
const handleImportMaskFromCanvas = useCallback(() => {
|
||||
dispatch(canvasMaskToControlNet({ controlNet }));
|
||||
}, [controlNet, dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
icon={<FaImage />}
|
||||
tooltip="Import Image From Canvas"
|
||||
aria-label="Import Image From Canvas"
|
||||
onClick={handleImportImageFromCanvas}
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
icon={<FaMask />}
|
||||
tooltip="Import Mask From Canvas"
|
||||
aria-label="Import Mask From Canvas"
|
||||
onClick={handleImportMaskFromCanvas}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetCanvasImageImports);
|
||||
@@ -4,11 +4,11 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import { memo } from 'react';
|
||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
|
||||
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
|
||||
@@ -15,6 +15,7 @@ import { BoardDTO } from 'services/api/types';
|
||||
import { menuListMotionProps } from 'theme/components/menu';
|
||||
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
||||
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
|
||||
type Props = {
|
||||
board?: BoardDTO;
|
||||
@@ -33,12 +34,16 @@ const BoardContextMenu = ({
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(stateSelector, ({ gallery, system }) => {
|
||||
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
||||
const isProcessing = system.isProcessing;
|
||||
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
||||
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
|
||||
}),
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ gallery, system }) => {
|
||||
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
||||
const isProcessing = system.isProcessing;
|
||||
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
||||
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[board_id]
|
||||
);
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ import {
|
||||
MenuButton,
|
||||
MenuList,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
@@ -37,12 +38,12 @@ import {
|
||||
FaSeedling,
|
||||
FaShareAlt,
|
||||
} from 'react-icons/fa';
|
||||
import { MdDeviceHub } from 'react-icons/md';
|
||||
import {
|
||||
useGetImageDTOQuery,
|
||||
useGetImageMetadataQuery,
|
||||
useGetImageMetadataFromFileQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { menuListMotionProps } from 'theme/components/menu';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToImg2Img } from '../../store/actions';
|
||||
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||
|
||||
@@ -101,22 +102,27 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
||||
lastSelectedImage,
|
||||
500
|
||||
);
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
||||
lastSelectedImage?.image_name ?? skipToken
|
||||
);
|
||||
|
||||
const { currentData: metadataData } = useGetImageMetadataQuery(
|
||||
debounceState.isPending()
|
||||
? skipToken
|
||||
: debouncedMetadataQueryArg?.image_name ?? skipToken
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
lastSelectedImage ?? skipToken,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const metadata = metadataData?.metadata;
|
||||
const handleLoadWorkflow = useCallback(() => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(metadata);
|
||||
@@ -153,6 +159,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
useHotkeys('p', handleUsePrompt, [imageDTO]);
|
||||
|
||||
useHotkeys('w', handleLoadWorkflow, [workflow]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(imageDTO));
|
||||
@@ -259,22 +267,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
icon={<MdDeviceHub />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
isDisabled={!workflow}
|
||||
onClick={handleLoadWorkflow}
|
||||
/>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!metadata?.positive_prompt}
|
||||
onClick={handleUsePrompt}
|
||||
/>
|
||||
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
icon={<FaSeedling />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!metadata?.seed}
|
||||
onClick={handleUseSeed}
|
||||
/>
|
||||
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
icon={<FaAsterisk />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Flex, MenuItem, Text } from '@chakra-ui/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
@@ -26,15 +25,15 @@ import {
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
|
||||
import {
|
||||
useGetImageMetadataQuery,
|
||||
useGetImageMetadataFromFileQuery,
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
@@ -50,15 +49,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
||||
imageDTO.image_name,
|
||||
500
|
||||
);
|
||||
|
||||
const { currentData } = useGetImageMetadataQuery(
|
||||
debounceState.isPending()
|
||||
? skipToken
|
||||
: debouncedMetadataQueryArg ?? skipToken
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
imageDTO,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
@@ -67,8 +66,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||
useCopyImageToClipboard();
|
||||
|
||||
const metadata = currentData?.metadata;
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
@@ -99,6 +96,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
recallSeed(metadata?.seed);
|
||||
}, [metadata?.seed, recallSeed]);
|
||||
|
||||
const handleLoadWorkflow = useCallback(() => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(imageDTO));
|
||||
@@ -118,7 +122,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
}, [dispatch, imageDTO, t, toaster]);
|
||||
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
console.log(metadata);
|
||||
recallAllParameters(metadata);
|
||||
}, [metadata, recallAllParameters]);
|
||||
|
||||
@@ -169,27 +172,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
{t('parameters.downloadImage')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaQuoteRight />}
|
||||
icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
|
||||
onClickCapture={handleLoadWorkflow}
|
||||
isDisabled={isLoading || !workflow}
|
||||
>
|
||||
{t('nodes.loadWorkflow')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={
|
||||
metadata?.positive_prompt === undefined &&
|
||||
metadata?.negative_prompt === undefined
|
||||
isLoading ||
|
||||
(metadata?.positive_prompt === undefined &&
|
||||
metadata?.negative_prompt === undefined)
|
||||
}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
icon={<FaSeedling />}
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={metadata?.seed === undefined}
|
||||
isDisabled={isLoading || metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaAsterisk />}
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={!metadata}
|
||||
isDisabled={isLoading || !metadata}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
@@ -228,20 +238,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
{metadata?.created_by && (
|
||||
<Flex
|
||||
sx={{
|
||||
padding: '5px 10px',
|
||||
marginTop: '5px',
|
||||
}}
|
||||
>
|
||||
<Text fontSize="xs" fontWeight="bold">
|
||||
Created by {metadata?.created_by}
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SingleSelectionMenuItems);
|
||||
|
||||
const SpinnerIcon = () => (
|
||||
<Flex w="14px" alignItems="center" justifyContent="center">
|
||||
<Spinner size="xs" />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -39,7 +39,7 @@ const ImageGalleryContent = () => {
|
||||
const { galleryView } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { isOpen: isBoardListOpen, onToggle: onToggleBoardList } =
|
||||
useDisclosure();
|
||||
useDisclosure({ defaultIsOpen: true });
|
||||
|
||||
const handleClickImages = useCallback(() => {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
ImageDraggableData,
|
||||
TypesafeDraggableData,
|
||||
} from 'features/dnd/types';
|
||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
|
||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
||||
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
||||
import { isString } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaCopy, FaSave } from 'react-icons/fa';
|
||||
import { FaCopy, FaDownload } from 'react-icons/fa';
|
||||
|
||||
type Props = {
|
||||
label: string;
|
||||
@@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
|
||||
navigator.clipboard.writeText(dataString);
|
||||
}, [dataString]);
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
const handleDownload = useCallback(() => {
|
||||
const blob = new Blob([dataString]);
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(blob);
|
||||
@@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
|
||||
</Box>
|
||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
||||
{withDownload && (
|
||||
<Tooltip label={`Save ${label} JSON`}>
|
||||
<Tooltip label={`Download ${label} JSON`}>
|
||||
<IconButton
|
||||
aria-label={`Save ${label} JSON`}
|
||||
icon={<FaSave />}
|
||||
aria-label={`Download ${label} JSON`}
|
||||
icon={<FaDownload />}
|
||||
variant="ghost"
|
||||
opacity={0.7}
|
||||
onClick={handleSave}
|
||||
onClick={handleDownload}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { CoreMetadata } from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { UnsafeImageMetadata } from 'services/api/types';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
metadata?: UnsafeImageMetadata['metadata'];
|
||||
metadata?: CoreMetadata;
|
||||
};
|
||||
|
||||
const ImageMetadataActions = (props: Props) => {
|
||||
@@ -94,20 +94,22 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallNegativePrompt}
|
||||
/>
|
||||
)}
|
||||
{metadata.seed !== undefined && (
|
||||
{metadata.seed !== undefined && metadata.seed !== null && (
|
||||
<ImageMetadataItem
|
||||
label="Seed"
|
||||
value={metadata.seed}
|
||||
onClick={handleRecallSeed}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined &&
|
||||
metadata.model !== null &&
|
||||
metadata.model.model_name && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem
|
||||
label="Width"
|
||||
@@ -150,7 +152,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallSteps}
|
||||
/>
|
||||
)}
|
||||
{metadata.cfg_scale !== undefined && (
|
||||
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
|
||||
<ImageMetadataItem
|
||||
label="CFG scale"
|
||||
value={metadata.cfg_scale}
|
||||
|
||||
@@ -9,14 +9,12 @@ import {
|
||||
Tabs,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo } from 'react';
|
||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import ImageMetadataActions from './ImageMetadataActions';
|
||||
import DataViewer from './DataViewer';
|
||||
import ImageMetadataActions from './ImageMetadataActions';
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
@@ -29,18 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
// dispatch(setShouldShowImageDetails(false));
|
||||
// });
|
||||
|
||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
||||
image.image_name,
|
||||
500
|
||||
);
|
||||
|
||||
const { currentData } = useGetImageMetadataQuery(
|
||||
debounceState.isPending()
|
||||
? skipToken
|
||||
: debouncedMetadataQueryArg ?? skipToken
|
||||
);
|
||||
const metadata = currentData?.metadata;
|
||||
const graph = currentData?.graph;
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -71,17 +63,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||
>
|
||||
<TabList>
|
||||
<Tab>Core Metadata</Tab>
|
||||
<Tab>Metadata</Tab>
|
||||
<Tab>Image Details</Tab>
|
||||
<Tab>Graph</Tab>
|
||||
<Tab>Workflow</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
{metadata ? (
|
||||
<DataViewer data={metadata} label="Core Metadata" />
|
||||
<DataViewer data={metadata} label="Metadata" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No core metadata found" />
|
||||
<IAINoContentFallback label="No metadata found" />
|
||||
)}
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
@@ -92,10 +84,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
)}
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
{graph ? (
|
||||
<DataViewer data={graph} label="Graph" />
|
||||
{workflow ? (
|
||||
<DataViewer data={workflow} label="Workflow" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No graph found" />
|
||||
<IAINoContentFallback label="No workflow found" />
|
||||
)}
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
|
||||
@@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@@ -13,6 +14,7 @@ import {
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
@@ -147,6 +149,11 @@ export const Flow = () => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
const onInit: OnInit = useCallback((flow) => {
|
||||
$flow.set(flow);
|
||||
flow.fitView();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@@ -170,6 +177,7 @@ export const Flow = () => {
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import { Flex, Image, Text } from '@chakra-ui/react';
|
||||
import { useState, PropsWithChildren, memo } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { Flex, Image, Text } from '@chakra-ui/react';
|
||||
import { motion } from 'framer-motion';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeWrapper from '../common/NodeWrapper';
|
||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { PropsWithChildren, memo } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { NodeProps } from 'reactflow';
|
||||
import NodeWrapper from '../common/NodeWrapper';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ system, gallery }) => {
|
||||
const imageDTO = gallery.selection[gallery.selection.length - 1];
|
||||
@@ -54,44 +56,90 @@ const CurrentImageNode = (props: NodeProps) => {
|
||||
|
||||
export default memo(CurrentImageNode);
|
||||
|
||||
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
|
||||
<NodeWrapper
|
||||
nodeId={props.nodeProps.data.id}
|
||||
selected={props.nodeProps.selected}
|
||||
width={384}
|
||||
>
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
}}
|
||||
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
const handleMouseEnter = () => {
|
||||
setIsHovering(true);
|
||||
};
|
||||
|
||||
const handleMouseLeave = () => {
|
||||
setIsHovering(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<NodeWrapper
|
||||
nodeId={props.nodeProps.id}
|
||||
selected={props.nodeProps.selected}
|
||||
width={384}
|
||||
>
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
onMouseEnter={handleMouseEnter}
|
||||
onMouseLeave={handleMouseLeave}
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
borderTopRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 8,
|
||||
position: 'relative',
|
||||
flexDirection: 'column',
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
color: 'base.700',
|
||||
_dark: { color: 'base.200' },
|
||||
borderTopRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 8,
|
||||
}}
|
||||
>
|
||||
Current Image
|
||||
</Text>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
color: 'base.700',
|
||||
_dark: { color: 'base.200' },
|
||||
}}
|
||||
>
|
||||
Current Image
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex
|
||||
layerStyle="nodeBody"
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
borderBottomRadius: 'base',
|
||||
p: 2,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
{isHovering && (
|
||||
<motion.div
|
||||
key="nextPrevButtons"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 40,
|
||||
left: -2,
|
||||
right: -2,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<NextPrevImageButtons />
|
||||
</motion.div>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex
|
||||
layerStyle="nodeBody"
|
||||
sx={{ w: 'full', h: 'full', borderBottomRadius: 'base', p: 2 }}
|
||||
>
|
||||
{props.children}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</NodeWrapper>
|
||||
);
|
||||
</NodeWrapper>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const embedWorkflow = useEmbedWorkflow(nodeId);
|
||||
const handleChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
nodeEmbedWorkflowChanged({
|
||||
nodeId,
|
||||
embedWorkflow: e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
if (!hasImageOutput) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
|
||||
<Checkbox
|
||||
className="nopan"
|
||||
size="sm"
|
||||
onChange={handleChange}
|
||||
isChecked={embedWorkflow}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(EmbedWorkflowCheckbox);
|
||||
@@ -41,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
flexDirection: 'column',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
py: 1,
|
||||
py: 2,
|
||||
gap: 1,
|
||||
borderBottomRadius: withFooter ? 0 : 'base',
|
||||
}}
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
import {
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Spacer,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
|
||||
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { memo } from 'react';
|
||||
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
||||
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
@@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
|
||||
px: 2,
|
||||
py: 0,
|
||||
h: 6,
|
||||
justifyContent: 'space-between',
|
||||
}}
|
||||
>
|
||||
<Spacer />
|
||||
<SaveImageCheckbox nodeId={nodeId} />
|
||||
<EmbedWorkflowCheckbox nodeId={nodeId} />
|
||||
<SaveToGalleryCheckbox nodeId={nodeId} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InvocationNodeFooter);
|
||||
|
||||
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const is_intermediate = useIsIntermediate(nodeId);
|
||||
const handleChangeIsIntermediate = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
fieldBooleanValueChanged({
|
||||
nodeId,
|
||||
fieldName: 'is_intermediate',
|
||||
value: !e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
if (!hasImageOutput) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
|
||||
<Checkbox
|
||||
className="nopan"
|
||||
size="sm"
|
||||
onChange={handleChangeIsIntermediate}
|
||||
isChecked={!is_intermediate}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
SaveImageCheckbox.displayName = 'SaveImageCheckbox';
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Icon,
|
||||
Modal,
|
||||
ModalBody,
|
||||
@@ -14,16 +12,16 @@ import {
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import NotesTextarea from './NotesTextarea';
|
||||
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
@@ -33,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const title = useNodeTemplateTitle(nodeId);
|
||||
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -54,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
>
|
||||
<Icon
|
||||
as={FaInfoCircle}
|
||||
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
||||
sx={{
|
||||
boxSize: 4,
|
||||
w: 8,
|
||||
color: doVersionsMatch ? 'base.400' : 'error.400',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
@@ -80,45 +83,78 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
|
||||
const title = useMemo(() => {
|
||||
if (data?.label && nodeTemplate?.title) {
|
||||
return `${data.label} (${nodeTemplate.title})`;
|
||||
}
|
||||
|
||||
if (data?.label && !nodeTemplate) {
|
||||
return data.label;
|
||||
}
|
||||
|
||||
if (!data?.label && nodeTemplate) {
|
||||
return nodeTemplate.title;
|
||||
}
|
||||
|
||||
return 'Unknown Node';
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version unknown
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (unknown template)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update node)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update app)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return <Text as="span">Version {data.version}</Text>;
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column' }}>
|
||||
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
|
||||
<Text as="span" sx={{ fontWeight: 600 }}>
|
||||
{title}
|
||||
</Text>
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipContent.displayName = 'TooltipContent';
|
||||
|
||||
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<IAITextarea
|
||||
value={data?.notes}
|
||||
onChange={handleNotesChanged}
|
||||
rows={10}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
NotesTextarea.displayName = 'NodesTextarea';
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import { FormControl, FormLabel } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<IAITextarea
|
||||
value={data?.notes}
|
||||
onChange={handleNotesChanged}
|
||||
rows={10}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(NotesTextarea);
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
|
||||
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const isIntermediate = useIsIntermediate(nodeId);
|
||||
const handleChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
nodeIsIntermediateChanged({
|
||||
nodeId,
|
||||
isIntermediate: !e.target.checked,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
if (!hasImageOutput) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
|
||||
<Checkbox
|
||||
className="nopan"
|
||||
size="sm"
|
||||
onChange={handleChange}
|
||||
isChecked={!isIntermediate}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SaveToGalleryCheckbox);
|
||||
@@ -0,0 +1,167 @@
|
||||
import {
|
||||
Editable,
|
||||
EditableInput,
|
||||
EditablePreview,
|
||||
Flex,
|
||||
Tooltip,
|
||||
forwardRef,
|
||||
useEditableControls,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
isMissingInput?: boolean;
|
||||
withTooltip?: boolean;
|
||||
}
|
||||
|
||||
const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
const {
|
||||
nodeId,
|
||||
fieldName,
|
||||
kind,
|
||||
isMissingInput = false,
|
||||
withTooltip = false,
|
||||
} = props;
|
||||
const label = useFieldLabel(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const [localTitle, setLocalTitle] = useState(
|
||||
label || fieldTemplateTitle || 'Unknown Field'
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
||||
return;
|
||||
}
|
||||
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
|
||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
||||
},
|
||||
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
setLocalTitle(newTitle);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Another component may change the title; sync local title with global state
|
||||
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
|
||||
}, [label, fieldTemplateTitle]);
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={
|
||||
withTooltip ? (
|
||||
<FieldTooltipContent
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
/>
|
||||
) : undefined
|
||||
}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
hasArrow
|
||||
>
|
||||
<Flex
|
||||
ref={ref}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
overflow: 'hidden',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'flex-start',
|
||||
gap: 1,
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<Editable
|
||||
value={localTitle}
|
||||
onChange={handleChange}
|
||||
onSubmit={handleSubmit}
|
||||
as={Flex}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
p: 0,
|
||||
fontWeight: isMissingInput ? 600 : 400,
|
||||
textAlign: 'left',
|
||||
_hover: {
|
||||
fontWeight: '600 !important',
|
||||
},
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
className="nodrag"
|
||||
sx={{
|
||||
p: 0,
|
||||
w: 'full',
|
||||
fontWeight: 600,
|
||||
color: 'base.900',
|
||||
_dark: {
|
||||
color: 'base.100',
|
||||
},
|
||||
_focusVisible: {
|
||||
p: 0,
|
||||
textAlign: 'left',
|
||||
boxShadow: 'none',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
export default memo(EditableFieldTitle);
|
||||
|
||||
const EditableControls = memo(() => {
|
||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
const { onClick } = getEditButtonProps();
|
||||
if (!onClick) {
|
||||
return;
|
||||
}
|
||||
onClick(e);
|
||||
e.preventDefault();
|
||||
},
|
||||
[getEditButtonProps]
|
||||
);
|
||||
|
||||
if (isEditing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
onClick={handleClick}
|
||||
position="absolute"
|
||||
w="full"
|
||||
h="full"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
cursor="text"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
EditableControls.displayName = 'EditableControls';
|
||||
@@ -1,8 +1,11 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
@@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
|
||||
borderWidth: 0,
|
||||
zIndex: 1,
|
||||
};
|
||||
``;
|
||||
|
||||
export const inputHandleStyles: CSSProperties = {
|
||||
left: '-1rem',
|
||||
@@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color, title } = FIELDS[type];
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor: colorTokenToCssVar(color),
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
: color,
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: 0,
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
zIndex: 1,
|
||||
};
|
||||
|
||||
@@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
|
||||
return s;
|
||||
}, [
|
||||
color,
|
||||
connectionError,
|
||||
handleType,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
import {
|
||||
Editable,
|
||||
EditableInput,
|
||||
EditablePreview,
|
||||
Flex,
|
||||
forwardRef,
|
||||
useEditableControls,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { Flex, Text, forwardRef } from '@chakra-ui/react';
|
||||
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
@@ -24,31 +15,6 @@ const FieldTitle = forwardRef((props: Props, ref) => {
|
||||
const label = useFieldLabel(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const [localTitle, setLocalTitle] = useState(
|
||||
label || fieldTemplateTitle || 'Unknown Field'
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
||||
return;
|
||||
}
|
||||
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
|
||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
||||
},
|
||||
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
setLocalTitle(newTitle);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Another component may change the title; sync local title with global state
|
||||
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
|
||||
}, [label, fieldTemplateTitle]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
@@ -62,82 +28,11 @@ const FieldTitle = forwardRef((props: Props, ref) => {
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<Editable
|
||||
value={localTitle}
|
||||
onChange={handleChange}
|
||||
onSubmit={handleSubmit}
|
||||
as={Flex}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
p: 0,
|
||||
fontWeight: isMissingInput ? 600 : 400,
|
||||
textAlign: 'left',
|
||||
_hover: {
|
||||
fontWeight: '600 !important',
|
||||
},
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
className="nodrag"
|
||||
sx={{
|
||||
p: 0,
|
||||
fontWeight: 600,
|
||||
color: 'base.900',
|
||||
_dark: {
|
||||
color: 'base.100',
|
||||
},
|
||||
_focusVisible: {
|
||||
p: 0,
|
||||
textAlign: 'left',
|
||||
boxShadow: 'none',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
<Text sx={{ fontWeight: isMissingInput ? 600 : 400 }}>
|
||||
{label || fieldTemplateTitle}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
export default memo(FieldTitle);
|
||||
|
||||
const EditableControls = memo(() => {
|
||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
const { onClick } = getEditButtonProps();
|
||||
if (!onClick) {
|
||||
return;
|
||||
}
|
||||
onClick(e);
|
||||
e.preventDefault();
|
||||
},
|
||||
[getEditButtonProps]
|
||||
);
|
||||
|
||||
if (isEditing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
onClick={handleClick}
|
||||
position="absolute"
|
||||
w="full"
|
||||
h="full"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
cursor="text"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
EditableControls.displayName = 'EditableControls';
|
||||
|
||||
@@ -34,6 +34,8 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
}
|
||||
|
||||
return 'Unknown Field';
|
||||
} else {
|
||||
return fieldTemplate?.title || 'Unknown Field';
|
||||
}
|
||||
}, [field, fieldTemplate]);
|
||||
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
|
||||
import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { PropsWithChildren, memo, useMemo } from 'react';
|
||||
import EditableFieldTitle from './EditableFieldTitle';
|
||||
import FieldContextMenu from './FieldContextMenu';
|
||||
import FieldHandle from './FieldHandle';
|
||||
import FieldTitle from './FieldTitle';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
import InputFieldRenderer from './InputFieldRenderer';
|
||||
|
||||
interface Props {
|
||||
@@ -21,7 +16,6 @@ interface Props {
|
||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||
const input = useFieldInputKind(nodeId, fieldName);
|
||||
|
||||
const {
|
||||
isConnected,
|
||||
@@ -51,11 +45,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
|
||||
if (fieldTemplate?.fieldKind !== 'input') {
|
||||
return (
|
||||
<InputFieldWrapper
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
shouldDim={shouldDim}
|
||||
>
|
||||
<InputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl
|
||||
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
|
||||
>
|
||||
@@ -66,19 +56,14 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<InputFieldWrapper
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
shouldDim={shouldDim}
|
||||
>
|
||||
<InputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl
|
||||
as={Flex}
|
||||
isInvalid={isMissingInput}
|
||||
isDisabled={isConnected}
|
||||
sx={{
|
||||
alignItems: 'stretch',
|
||||
justifyContent: 'space-between',
|
||||
ps: 2,
|
||||
ps: fieldTemplate.input === 'direct' ? 0 : 2,
|
||||
gap: 2,
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
@@ -86,42 +71,28 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
>
|
||||
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
|
||||
{(ref) => (
|
||||
<Tooltip
|
||||
label={
|
||||
<FieldTooltipContent
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
/>
|
||||
}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
hasArrow
|
||||
<FormLabel
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
mb: 0,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<FormLabel
|
||||
sx={{
|
||||
mb: 0,
|
||||
width: input === 'connection' ? 'auto' : '25%',
|
||||
flexShrink: 0,
|
||||
flexGrow: 0,
|
||||
}}
|
||||
>
|
||||
<FieldTitle
|
||||
ref={ref}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
isMissingInput={isMissingInput}
|
||||
/>
|
||||
</FormLabel>
|
||||
</Tooltip>
|
||||
<EditableFieldTitle
|
||||
ref={ref}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
/>
|
||||
</FormLabel>
|
||||
)}
|
||||
</FieldContextMenu>
|
||||
<Box
|
||||
sx={{
|
||||
width: input === 'connection' ? 'auto' : '75%',
|
||||
}}
|
||||
>
|
||||
<Box>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
@@ -143,19 +114,12 @@ export default memo(InputField);
|
||||
|
||||
type InputFieldWrapperProps = PropsWithChildren<{
|
||||
shouldDim: boolean;
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
}>;
|
||||
|
||||
const InputFieldWrapper = memo(
|
||||
({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => {
|
||||
const { isMouseOverField, handleMouseOver, handleMouseOut } =
|
||||
useIsMouseOverField(nodeId, fieldName);
|
||||
|
||||
({ shouldDim, children }: InputFieldWrapperProps) => {
|
||||
return (
|
||||
<Flex
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
minH: 8,
|
||||
@@ -169,7 +133,6 @@ const InputFieldWrapper = memo(
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,17 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { memo } from 'react';
|
||||
import BooleanInputField from './inputs/BooleanInputField';
|
||||
import ClipInputField from './inputs/ClipInputField';
|
||||
import CollectionInputField from './inputs/CollectionInputField';
|
||||
import CollectionItemInputField from './inputs/CollectionItemInputField';
|
||||
import ColorInputField from './inputs/ColorInputField';
|
||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
||||
import ControlInputField from './inputs/ControlInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import LatentsInputField from './inputs/LatentsInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
import NumberInputField from './inputs/NumberInputField';
|
||||
@@ -21,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import UnetInputField from './inputs/UnetInputField';
|
||||
import VaeInputField from './inputs/VaeInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
@@ -30,7 +21,6 @@ type InputFieldProps = {
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
@@ -92,75 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LatentsField' &&
|
||||
fieldTemplate?.type === 'LatentsField'
|
||||
) {
|
||||
return (
|
||||
<LatentsInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ConditioningField' &&
|
||||
fieldTemplate?.type === 'ConditioningField'
|
||||
) {
|
||||
return (
|
||||
<ConditioningInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
||||
return (
|
||||
<UnetInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
||||
return (
|
||||
<ClipInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
||||
return (
|
||||
<VaeInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlField' &&
|
||||
fieldTemplate?.type === 'ControlField'
|
||||
) {
|
||||
return (
|
||||
<ControlInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
@@ -226,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
||||
return (
|
||||
<CollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'CollectionItem' &&
|
||||
fieldTemplate?.type === 'CollectionItem'
|
||||
) {
|
||||
return (
|
||||
<CollectionItemInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
@@ -259,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ImageCollection' &&
|
||||
fieldTemplate?.type === 'ImageCollection'
|
||||
) {
|
||||
return (
|
||||
<ImageCollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
@@ -295,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field && fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react';
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Icon,
|
||||
Spacer,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaInfoCircle, FaTrash } from 'react-icons/fa';
|
||||
import FieldTitle from './FieldTitle';
|
||||
import EditableFieldTitle from './EditableFieldTitle';
|
||||
import FieldTooltipContent from './FieldTooltipContent';
|
||||
import InputFieldRenderer from './InputFieldRenderer';
|
||||
|
||||
@@ -18,8 +25,8 @@ type Props = {
|
||||
|
||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { isMouseOverField, handleMouseOut, handleMouseOver } =
|
||||
useIsMouseOverField(nodeId, fieldName);
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
|
||||
useMouseOverNode(nodeId);
|
||||
|
||||
const handleRemoveField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
||||
@@ -27,8 +34,8 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
|
||||
return (
|
||||
<Flex
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
onMouseEnter={handleMouseOver}
|
||||
onMouseLeave={handleMouseOut}
|
||||
layerStyle="second"
|
||||
sx={{
|
||||
position: 'relative',
|
||||
@@ -42,11 +49,15 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
mb: 0,
|
||||
}}
|
||||
>
|
||||
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
||||
<EditableFieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
/>
|
||||
<Spacer />
|
||||
<Tooltip
|
||||
label={
|
||||
<FieldTooltipContent
|
||||
@@ -74,7 +85,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
</FormLabel>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
</FormControl>
|
||||
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
|
||||
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import {
|
||||
ControlInputFieldTemplate,
|
||||
ControlInputFieldValue,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ControlPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const ControlInputFieldComponent = (
|
||||
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
|
||||
_props: FieldComponentProps<
|
||||
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
|
||||
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -92,6 +92,7 @@ const ControlNetModelInputFieldComponent = (
|
||||
error={!selectedModel}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
sx={{ width: '100%' }}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import {
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
DenoiseMaskInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const DenoiseMaskInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
DenoiseMaskInputFieldValue,
|
||||
DenoiseMaskInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(DenoiseMaskInputFieldComponent);
|
||||
@@ -9,9 +9,9 @@ import {
|
||||
} from 'features/dnd/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
|
||||
@@ -2,11 +2,16 @@ import {
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsInputFieldValue,
|
||||
FieldComponentProps,
|
||||
LatentsPolymorphicInputFieldValue,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const LatentsInputFieldComponent = (
|
||||
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
||||
_props: FieldComponentProps<
|
||||
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
|
||||
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -101,8 +101,10 @@ const LoRAModelInputFieldComponent = (
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
error={!selectedLoRAModel}
|
||||
onChange={handleChange}
|
||||
sx={{
|
||||
width: '100%',
|
||||
'.mantine-Select-dropdown': {
|
||||
width: '16rem !important',
|
||||
},
|
||||
|
||||
@@ -134,6 +134,7 @@ const MainModelInputFieldComponent = (
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
sx={{
|
||||
width: '100%',
|
||||
'.mantine-Select-dropdown': {
|
||||
width: '16rem !important',
|
||||
},
|
||||
|
||||
@@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
@@ -101,20 +101,17 @@ const RefinerModelInputFieldComponent = (
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
error={!selectedModel}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
sx={{
|
||||
width: '100%',
|
||||
'.mantine-Select-dropdown': {
|
||||
width: '16rem !important',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton className="nodrag" iconMode />
|
||||
</Box>
|
||||
)}
|
||||
{isSyncModelEnabled && <SyncModelsButton className="nodrag" iconMode />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -128,10 +128,11 @@ const ModelInputFieldComponent = (
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
error={!selectedModel}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
sx={{
|
||||
width: '100%',
|
||||
'.mantine-Select-dropdown': {
|
||||
width: '16rem !important',
|
||||
},
|
||||
|
||||
@@ -4,9 +4,9 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
@@ -88,17 +88,15 @@ const VaeModelInputFieldComponent = (
|
||||
className="nowheel nodrag"
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={
|
||||
selectedVaeModel?.base_model &&
|
||||
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
|
||||
}
|
||||
value={selectedVaeModel?.id ?? 'default'}
|
||||
placeholder="Default"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
disabled={data.length === 0}
|
||||
error={!selectedVaeModel}
|
||||
clearable
|
||||
sx={{
|
||||
width: '100%',
|
||||
'.mantine-Select-dropdown': {
|
||||
width: '16rem !important',
|
||||
},
|
||||
|
||||
@@ -27,9 +27,11 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const handleSubmit = useCallback(
|
||||
async (newTitle: string) => {
|
||||
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
|
||||
setLocalTitle(newTitle || title || 'Problem Setting Title');
|
||||
setLocalTitle(
|
||||
newTitle || title || templateTitle || 'Problem Setting Title'
|
||||
);
|
||||
},
|
||||
[nodeId, dispatch, title]
|
||||
[dispatch, nodeId, title, templateTitle]
|
||||
);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
|
||||
@@ -7,13 +7,22 @@ import {
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
DRAG_HANDLE_CLASSNAME,
|
||||
NODE_WIDTH,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { NodeStatus } from 'features/nodes/types/types';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { PropsWithChildren, memo, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
MouseEvent,
|
||||
PropsWithChildren,
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
@@ -23,6 +32,8 @@ type NodeWrapperProps = PropsWithChildren & {
|
||||
|
||||
const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
|
||||
useMouseOverNode(nodeId);
|
||||
|
||||
const selectIsInProgress = useMemo(
|
||||
() =>
|
||||
@@ -36,25 +47,16 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
|
||||
const isInProgress = useAppSelector(selectIsInProgress);
|
||||
|
||||
const [
|
||||
nodeSelectedLight,
|
||||
nodeSelectedDark,
|
||||
nodeInProgressLight,
|
||||
nodeInProgressDark,
|
||||
shadowsXl,
|
||||
shadowsBase,
|
||||
] = useToken('shadows', [
|
||||
'nodeSelected.light',
|
||||
'nodeSelected.dark',
|
||||
'nodeInProgress.light',
|
||||
'nodeInProgress.dark',
|
||||
'shadows.xl',
|
||||
'shadows.base',
|
||||
]);
|
||||
const [nodeInProgressLight, nodeInProgressDark, shadowsXl, shadowsBase] =
|
||||
useToken('shadows', [
|
||||
'nodeInProgress.light',
|
||||
'nodeInProgress.dark',
|
||||
'shadows.xl',
|
||||
'shadows.base',
|
||||
]);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selectedShadow = useColorModeValue(nodeSelectedLight, nodeSelectedDark);
|
||||
const inProgressShadow = useColorModeValue(
|
||||
nodeInProgressLight,
|
||||
nodeInProgressDark
|
||||
@@ -62,13 +64,21 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
|
||||
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
|
||||
dispatch(nodeExclusivelySelected(nodeId));
|
||||
}
|
||||
dispatch(contextMenusClosed());
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
onClick={handleClick}
|
||||
onMouseEnter={handleMouseOver}
|
||||
onMouseLeave={handleMouseOut}
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
h: 'full',
|
||||
@@ -77,11 +87,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
w: width ?? NODE_WIDTH,
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
shadow: selected
|
||||
? isInProgress
|
||||
? undefined
|
||||
: selectedShadow
|
||||
: undefined,
|
||||
cursor: 'grab',
|
||||
opacity,
|
||||
}}
|
||||
@@ -116,6 +121,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
}}
|
||||
/>
|
||||
{children}
|
||||
<NodeSelectionOverlay isSelected={selected} isHovered={isMouseOverNode} />
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,12 +2,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSave } from 'react-icons/fa';
|
||||
import { FaDownload } from 'react-icons/fa';
|
||||
|
||||
const SaveWorkflowButton = () => {
|
||||
const DownloadWorkflowButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const workflow = useWorkflow();
|
||||
const handleSave = useCallback(() => {
|
||||
const handleDownload = useCallback(() => {
|
||||
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(blob);
|
||||
@@ -18,12 +18,12 @@ const SaveWorkflowButton = () => {
|
||||
}, [workflow]);
|
||||
return (
|
||||
<IAIIconButton
|
||||
icon={<FaSave />}
|
||||
tooltip={t('nodes.saveWorkflow')}
|
||||
aria-label={t('nodes.saveWorkflow')}
|
||||
onClick={handleSave}
|
||||
icon={<FaDownload />}
|
||||
tooltip={t('nodes.downloadWorkflow')}
|
||||
aria-label={t('nodes.downloadWorkflow')}
|
||||
onClick={handleDownload}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SaveWorkflowButton);
|
||||
export default memo(DownloadWorkflowButton);
|
||||
@@ -2,7 +2,7 @@ import { Flex } from '@chakra-ui/layout';
|
||||
import { memo } from 'react';
|
||||
import LoadWorkflowButton from './LoadWorkflowButton';
|
||||
import ResetWorkflowButton from './ResetWorkflowButton';
|
||||
import SaveWorkflowButton from './SaveWorkflowButton';
|
||||
import DownloadWorkflowButton from './DownloadWorkflowButton';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
return (
|
||||
@@ -15,7 +15,7 @@ const TopCenterPanel = () => {
|
||||
transform: 'translate(-50%)',
|
||||
}}
|
||||
>
|
||||
<SaveWorkflowButton />
|
||||
<DownloadWorkflowButton />
|
||||
<LoadWorkflowButton />
|
||||
<ResetWorkflowButton />
|
||||
</Flex>
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { InvocationTemplate, NodeData } from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
|
||||
import NodeTitle from '../../flow/nodes/common/NodeTitle';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const lastSelectedNodeId =
|
||||
nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find(
|
||||
(node) => node.id === lastSelectedNodeId
|
||||
);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode
|
||||
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
data: lastSelectedNode?.data,
|
||||
template: lastSelectedNodeTemplate,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const InspectorDetailsTab = () => {
|
||||
const { data, template } = useAppSelector(selector);
|
||||
|
||||
if (!template || !data) {
|
||||
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||
}
|
||||
|
||||
return <Content data={data} template={template} />;
|
||||
};
|
||||
|
||||
export default memo(InspectorDetailsTab);
|
||||
|
||||
const Content = (props: { data: NodeData; template: InvocationTemplate }) => {
|
||||
const { data } = props;
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<ScrollableContent>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
position: 'relative',
|
||||
p: 1,
|
||||
gap: 2,
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
<NodeTitle nodeId={data.id} />
|
||||
<NotesTextarea nodeId={data.id} />
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -4,12 +4,13 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
import { ImageOutput } from 'services/api/types';
|
||||
import { AnyResult } from 'services/events/types';
|
||||
import StringOutputPreview from './outputs/StringOutputPreview';
|
||||
import NumberOutputPreview from './outputs/NumberOutputPreview';
|
||||
import ScrollableContent from '../ScrollableContent';
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@@ -21,11 +22,16 @@ const selector = createSelector(
|
||||
(node) => node.id === lastSelectedNodeId
|
||||
);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode
|
||||
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||
: undefined;
|
||||
|
||||
const nes =
|
||||
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||
|
||||
return {
|
||||
node: lastSelectedNode,
|
||||
template: lastSelectedNodeTemplate,
|
||||
nes,
|
||||
};
|
||||
},
|
||||
@@ -33,9 +39,9 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
const InspectorOutputsTab = () => {
|
||||
const { node, nes } = useAppSelector(selector);
|
||||
const { node, template, nes } = useAppSelector(selector);
|
||||
|
||||
if (!node || !nes) {
|
||||
if (!node || !nes || !isInvocationNode(node)) {
|
||||
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||
}
|
||||
|
||||
@@ -63,33 +69,16 @@ const InspectorOutputsTab = () => {
|
||||
w: 'full',
|
||||
}}
|
||||
>
|
||||
{nes.outputs.map((result, i) => {
|
||||
if (result.type === 'string_output') {
|
||||
return (
|
||||
<StringOutputPreview key={getKey(result, i)} output={result} />
|
||||
);
|
||||
}
|
||||
if (result.type === 'float_output') {
|
||||
return (
|
||||
<NumberOutputPreview key={getKey(result, i)} output={result} />
|
||||
);
|
||||
}
|
||||
if (result.type === 'integer_output') {
|
||||
return (
|
||||
<NumberOutputPreview key={getKey(result, i)} output={result} />
|
||||
);
|
||||
}
|
||||
if (result.type === 'image_output') {
|
||||
return (
|
||||
<ImageOutputPreview key={getKey(result, i)} output={result} />
|
||||
);
|
||||
}
|
||||
return (
|
||||
<pre key={getKey(result, i)}>
|
||||
{JSON.stringify(result, null, 2)}
|
||||
</pre>
|
||||
);
|
||||
})}
|
||||
{template?.outputType === 'image_output' ? (
|
||||
nes.outputs.map((result, i) => (
|
||||
<ImageOutputPreview
|
||||
key={getKey(result, i)}
|
||||
output={result as ImageOutput}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<DataViewer data={nes.outputs} label="Node Outputs" />
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
||||
|
||||
@@ -10,6 +10,7 @@ import { memo } from 'react';
|
||||
import InspectorDataTab from './InspectorDataTab';
|
||||
import InspectorOutputsTab from './InspectorOutputsTab';
|
||||
import InspectorTemplateTab from './InspectorTemplateTab';
|
||||
// import InspectorDetailsTab from './InspectorDetailsTab';
|
||||
|
||||
const InspectorPanel = () => {
|
||||
return (
|
||||
@@ -29,12 +30,16 @@ const InspectorPanel = () => {
|
||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||
>
|
||||
<TabList>
|
||||
{/* <Tab>Details</Tab> */}
|
||||
<Tab>Outputs</Tab>
|
||||
<Tab>Data</Tab>
|
||||
<Tab>Template</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels>
|
||||
{/* <TabPanel>
|
||||
<InspectorDetailsTab />
|
||||
</TabPanel> */}
|
||||
<TabPanel>
|
||||
<InspectorOutputsTab />
|
||||
</TabPanel>
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import { Text } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FloatOutput, IntegerOutput } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
output: IntegerOutput | FloatOutput;
|
||||
};
|
||||
|
||||
const NumberOutputPreview = ({ output }: Props) => {
|
||||
return <Text>{output.value}</Text>;
|
||||
};
|
||||
|
||||
export default memo(NumberOutputPreview);
|
||||
@@ -1,13 +0,0 @@
|
||||
import { Text } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { StringOutput } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
output: StringOutput;
|
||||
};
|
||||
|
||||
const StringOutputPreview = ({ output }: Props) => {
|
||||
return <Text>{output.value}</Text>;
|
||||
};
|
||||
|
||||
export default memo(StringOutputPreview);
|
||||
@@ -22,6 +22,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
}
|
||||
return map(nodeTemplate.inputs)
|
||||
.filter((field) => ['any', 'direct'].includes(field.input))
|
||||
.filter((field) => !field.ui_hidden)
|
||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||
.map((field) => field.name)
|
||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||
|
||||
@@ -138,11 +138,14 @@ export const useBuildNodeData = () => {
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
inputs,
|
||||
outputs,
|
||||
isOpen: true,
|
||||
version: template.version,
|
||||
label: '',
|
||||
notes: '',
|
||||
isOpen: true,
|
||||
embedWorkflow: false,
|
||||
isIntermediate: true,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
}
|
||||
return map(nodeTemplate.inputs)
|
||||
.filter((field) => field.input === 'connection')
|
||||
.filter((field) => !field.ui_hidden)
|
||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||
.map((field) => field.name)
|
||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate?.version || !node.data?.version) {
|
||||
return false;
|
||||
}
|
||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeTemplate = useAppSelector(selector);
|
||||
|
||||
return nodeTemplate;
|
||||
};
|
||||
@@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return Boolean(node?.data.inputs[fieldName]?.value);
|
||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useEmbedWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.embedWorkflow;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const embedWorkflow = useAppSelector(selector);
|
||||
return embedWorkflow;
|
||||
};
|
||||
@@ -15,7 +15,7 @@ export const useIsIntermediate = (nodeId: string) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return Boolean(node.data.inputs.is_intermediate?.value);
|
||||
return node.data.isIntermediate;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
||||
@@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||
import { COLLECTION_TYPES } from '../types/constants';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from '../types/constants';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
const shouldValidateGraph = useAppSelector(
|
||||
@@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
edges
|
||||
.filter((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
.find((edge) => {
|
||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
@@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection types must be the same for a connection
|
||||
if (
|
||||
sourceType !== targetType &&
|
||||
sourceType !== 'CollectionItem' &&
|
||||
targetType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetType) &&
|
||||
COLLECTION_TYPES.includes(sourceType)
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
);
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
|
||||
@@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { zWorkflow } from 'features/nodes/types/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
|
||||
export const useLoadWorkflowFromFile = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -27,49 +27,38 @@ export const useLoadWorkflowFromFile = () => {
|
||||
const result = zWorkflow.safeParse(parsedJSON);
|
||||
|
||||
if (!result.success) {
|
||||
const message = fromZodError(result.error, {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: 'Workflow Validation Error',
|
||||
}).toString();
|
||||
});
|
||||
|
||||
logger.error({ error: parseify(result.error) }, message);
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Unable to Validate Workflow',
|
||||
description: (
|
||||
<WorkflowValidationErrorContent error={result.error} />
|
||||
),
|
||||
status: 'error',
|
||||
duration: 5000,
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(workflowLoaded(result.data));
|
||||
dispatch(workflowLoadRequested(result.data));
|
||||
|
||||
reader.abort();
|
||||
} catch {
|
||||
// file reader error
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
title: 'Unable to Load Workflow',
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
} catch (error) {
|
||||
// file reader error
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Unable to Load Workflow',
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { mouseOverNodeChanged } from '../store/nodesSlice';
|
||||
|
||||
export const useMouseOverNode = (nodeId: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => nodes.mouseOverNode === nodeId,
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const isMouseOverNode = useAppSelector(selector);
|
||||
|
||||
const handleMouseOver = useCallback(() => {
|
||||
!isMouseOverNode && dispatch(mouseOverNodeChanged(nodeId));
|
||||
}, [dispatch, nodeId, isMouseOverNode]);
|
||||
|
||||
const handleMouseOut = useCallback(() => {
|
||||
isMouseOverNode && dispatch(mouseOverNodeChanged(null));
|
||||
}, [dispatch, isMouseOverNode]);
|
||||
|
||||
return { isMouseOverNode, handleMouseOver, handleMouseOut };
|
||||
};
|
||||
@@ -21,6 +21,7 @@ export const useOutputFieldNames = (nodeId: string) => {
|
||||
return [];
|
||||
}
|
||||
return map(nodeTemplate.outputs)
|
||||
.filter((field) => !field.ui_hidden)
|
||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||
.map((field) => field.name)
|
||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { Workflow } from '../types/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>(
|
||||
'nodes/textToImageGraphBuilt'
|
||||
@@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
canvasGraphBuilt,
|
||||
nodesGraphBuilt
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<Workflow>(
|
||||
'nodes/workflowLoadRequested'
|
||||
);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
|
||||
import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
Viewport,
|
||||
} from 'reactflow';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { sessionInvoked } from 'services/api/thunks/session';
|
||||
import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
|
||||
import { ImageField } from 'services/api/types';
|
||||
import {
|
||||
appSocketGeneratorProgress,
|
||||
@@ -102,6 +102,7 @@ export const initialNodesState: NodesState = {
|
||||
nodeExecutionStates: {},
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
mouseOverField: null,
|
||||
mouseOverNode: null,
|
||||
nodesToCopy: [],
|
||||
edgesToCopy: [],
|
||||
selectionMode: SelectionMode.Partial,
|
||||
@@ -245,6 +246,34 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
field.label = label;
|
||||
},
|
||||
nodeEmbedWorkflowChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
|
||||
) => {
|
||||
const { nodeId, embedWorkflow } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.embedWorkflow = embedWorkflow;
|
||||
},
|
||||
nodeIsIntermediateChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
|
||||
) => {
|
||||
const { nodeId, isIntermediate } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.isIntermediate = isIntermediate;
|
||||
},
|
||||
nodeIsOpenChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; isOpen: boolean }>
|
||||
@@ -414,6 +443,17 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = notes;
|
||||
},
|
||||
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
|
||||
const nodeId = action.payload;
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({
|
||||
id: n.id,
|
||||
type: 'select',
|
||||
selected: n.id === nodeId ? true : false,
|
||||
})),
|
||||
state.nodes
|
||||
);
|
||||
},
|
||||
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.selectedNodes = action.payload;
|
||||
},
|
||||
@@ -561,7 +601,7 @@ const nodesSlice = createSlice({
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
state.workflow.exposedFields = [];
|
||||
state.workflow = cloneDeep(initialWorkflow);
|
||||
},
|
||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldValidateGraph = action.payload;
|
||||
@@ -637,6 +677,9 @@ const nodesSlice = createSlice({
|
||||
) => {
|
||||
state.mouseOverField = action.payload;
|
||||
},
|
||||
mouseOverNodeChanged: (state, action: PayloadAction<string | null>) => {
|
||||
state.mouseOverNode = action.payload;
|
||||
},
|
||||
selectedAll: (state) => {
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
|
||||
@@ -790,6 +833,13 @@ const nodesSlice = createSlice({
|
||||
nes.outputs = [];
|
||||
});
|
||||
});
|
||||
builder.addCase(sessionCanceled.fulfilled, (state) => {
|
||||
map(state.nodeExecutionStates, (nes) => {
|
||||
if (nes.status === NodeStatus.IN_PROGRESS) {
|
||||
nes.status = NodeStatus.PENDING;
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@@ -850,6 +900,10 @@ export const {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverToggled,
|
||||
selectionModeChanged,
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeExclusivelySelected,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
import { atom } from 'nanostores';
|
||||
import { ReactFlowInstance } from 'reactflow';
|
||||
|
||||
export const $flow = atom<ReactFlowInstance | null>(null);
|
||||
@@ -35,6 +35,7 @@ export type NodesState = {
|
||||
viewport: Viewport;
|
||||
isReady: boolean;
|
||||
mouseOverField: FieldIdentifier | null;
|
||||
mouseOverNode: string | null;
|
||||
nodesToCopy: Node<NodeData>[];
|
||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||
isAddNodePopoverOpen: boolean;
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import { HandleType } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
@@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
if (!state.nodes.shouldValidateGraph) {
|
||||
// manual override!
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||
return 'No connection in progress';
|
||||
}
|
||||
@@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
|
||||
return 'No connection data';
|
||||
}
|
||||
|
||||
const targetFieldType =
|
||||
const targetType =
|
||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||
const sourceFieldType =
|
||||
const sourceType =
|
||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
@@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
if (
|
||||
fieldType !== currentConnectionFieldType &&
|
||||
fieldType !== 'CollectionItem' &&
|
||||
currentConnectionFieldType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetFieldType) &&
|
||||
COLLECTION_TYPES.includes(sourceFieldType)
|
||||
)
|
||||
) {
|
||||
// except for collection items, field types must match
|
||||
return 'Field types must match';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
handleType === 'target' &&
|
||||
edges.find((edge) => {
|
||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetFieldType !== 'CollectionItem'
|
||||
targetType !== 'CollectionItem'
|
||||
) {
|
||||
return 'Inputs may only have one connection';
|
||||
return 'Input may only have one connection';
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
if (
|
||||
!(
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
)
|
||||
) {
|
||||
return 'Field types must match';
|
||||
}
|
||||
}
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(
|
||||
|
||||
@@ -17,176 +17,297 @@ export const KIND_MAP = {
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
'BooleanCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
'ImageCollection',
|
||||
'LatentsCollection',
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES = [
|
||||
'IntegerPolymorphic',
|
||||
'BooleanPolymorphic',
|
||||
'FloatPolymorphic',
|
||||
'StringPolymorphic',
|
||||
'ImagePolymorphic',
|
||||
'LatentsPolymorphic',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
];
|
||||
|
||||
export const COLLECTION_MAP = {
|
||||
integer: 'IntegerCollection',
|
||||
boolean: 'BooleanCollection',
|
||||
number: 'FloatCollection',
|
||||
float: 'FloatCollection',
|
||||
string: 'StringCollection',
|
||||
ImageField: 'ImageCollection',
|
||||
LatentsField: 'LatentsCollection',
|
||||
ConditioningField: 'ConditioningCollection',
|
||||
ControlField: 'ControlCollection',
|
||||
ColorField: 'ColorCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof COLLECTION_MAP =>
|
||||
Boolean(itemType && itemType in COLLECTION_MAP);
|
||||
|
||||
export const SINGLE_TO_POLYMORPHIC_MAP = {
|
||||
integer: 'IntegerPolymorphic',
|
||||
boolean: 'BooleanPolymorphic',
|
||||
number: 'FloatPolymorphic',
|
||||
float: 'FloatPolymorphic',
|
||||
string: 'StringPolymorphic',
|
||||
ImageField: 'ImagePolymorphic',
|
||||
LatentsField: 'LatentsPolymorphic',
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP = {
|
||||
IntegerPolymorphic: 'integer',
|
||||
BooleanPolymorphic: 'boolean',
|
||||
FloatPolymorphic: 'float',
|
||||
StringPolymorphic: 'string',
|
||||
ImagePolymorphic: 'ImageField',
|
||||
LatentsPolymorphic: 'LatentsField',
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
};
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
color: 'red.500',
|
||||
},
|
||||
float: {
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
color: 'orange.500',
|
||||
},
|
||||
string: {
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
color: 'yellow.500',
|
||||
},
|
||||
boolean: {
|
||||
title: 'Boolean',
|
||||
color: 'green.500',
|
||||
description: 'Booleans are true or false.',
|
||||
title: 'Boolean',
|
||||
},
|
||||
enum: {
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'blue.500',
|
||||
BooleanCollection: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Collection',
|
||||
},
|
||||
array: {
|
||||
title: 'Array',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'base.500',
|
||||
},
|
||||
ImageField: {
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
color: 'purple.500',
|
||||
},
|
||||
LatentsField: {
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
LatentsCollection: {
|
||||
title: 'Latents Collection',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
title: 'Conditioning',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
title: 'Conditioning Collection',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
ImageCollection: {
|
||||
title: 'Image Collection',
|
||||
description: 'A collection of images.',
|
||||
color: 'base.300',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
title: 'UNet',
|
||||
description: 'UNet submodel.',
|
||||
BooleanPolymorphic: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Polymorphic',
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
title: 'Clip',
|
||||
description: 'Tokenizer and text_encoder submodels.',
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
title: 'Vae',
|
||||
description: 'Vae submodel.',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'cyan.500',
|
||||
title: 'Control',
|
||||
description: 'Control info passed between nodes.',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Refiner Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'VAE',
|
||||
description: 'TODO',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'LoRA',
|
||||
description: 'TODO',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'ControlNet',
|
||||
description: 'TODO',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
title: 'Scheduler',
|
||||
description: 'TODO',
|
||||
title: 'Clip',
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
title: 'Collection',
|
||||
description: 'TODO',
|
||||
title: 'Collection',
|
||||
},
|
||||
CollectionItem: {
|
||||
color: 'base.500',
|
||||
title: 'Collection Item',
|
||||
description: 'TODO',
|
||||
title: 'Collection Item',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Collection',
|
||||
},
|
||||
ColorField: {
|
||||
title: 'Color',
|
||||
color: 'pink.300',
|
||||
description: 'A RGBA color.',
|
||||
color: 'base.500',
|
||||
title: 'Color',
|
||||
},
|
||||
BooleanCollection: {
|
||||
title: 'Boolean Collection',
|
||||
description: 'A collection of booleans.',
|
||||
color: 'green.500',
|
||||
ColorPolymorphic: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Polymorphic',
|
||||
},
|
||||
IntegerCollection: {
|
||||
title: 'Integer Collection',
|
||||
description: 'A collection of integers.',
|
||||
color: 'red.500',
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Collection',
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning',
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Polymorphic',
|
||||
},
|
||||
ControlCollection: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Collection',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'ControlNet',
|
||||
},
|
||||
ControlPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Polymorphic',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
color: 'blue.300',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
title: 'Denoise Mask',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue.500',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
title: 'Enum',
|
||||
},
|
||||
float: {
|
||||
color: 'orange.500',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
title: 'Float',
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
title: 'Float Collection',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Collection',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'base.500',
|
||||
title: 'Color Collection',
|
||||
description: 'A collection of colors.',
|
||||
FloatPolymorphic: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Polymorphic',
|
||||
},
|
||||
FilePath: {
|
||||
color: 'base.500',
|
||||
title: 'File Path',
|
||||
description: 'A path to a file.',
|
||||
ImageCollection: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Collection',
|
||||
},
|
||||
ImageField: {
|
||||
color: 'purple.500',
|
||||
description: 'Images may be passed between nodes.',
|
||||
title: 'Image',
|
||||
},
|
||||
ImagePolymorphic: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Polymorphic',
|
||||
},
|
||||
integer: {
|
||||
color: 'red.500',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
title: 'Integer',
|
||||
},
|
||||
IntegerCollection: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Collection',
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Polymorphic',
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Collection',
|
||||
},
|
||||
LatentsField: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents',
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Polymorphic',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'LoRA',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Model',
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'base.500',
|
||||
title: 'ONNX Model',
|
||||
color: 'teal.500',
|
||||
description: 'ONNX model field.',
|
||||
title: 'ONNX Model',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Scheduler',
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'base.500',
|
||||
title: 'SDXL Model',
|
||||
color: 'teal.500',
|
||||
description: 'SDXL model field.',
|
||||
title: 'SDXL Model',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Refiner Model',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow.500',
|
||||
description: 'Strings are text.',
|
||||
title: 'String',
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
title: 'String Collection',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Collection',
|
||||
},
|
||||
StringPolymorphic: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Polymorphic',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: 'UNet submodel.',
|
||||
title: 'UNet',
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
description: 'Vae submodel.',
|
||||
title: 'Vae',
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'VAE',
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
import {
|
||||
SchedulerParam,
|
||||
zBaseModel,
|
||||
zMainModel,
|
||||
zMainOrOnnxModel,
|
||||
zOnnxModel,
|
||||
zSDXLRefinerModel,
|
||||
zScheduler,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { Node } from 'reactflow';
|
||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
||||
import {
|
||||
AnyInvocationType,
|
||||
AnyResult,
|
||||
ProgressImage,
|
||||
} from 'services/events/types';
|
||||
import { O } from 'ts-toolbelt';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { z } from 'zod';
|
||||
|
||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||
@@ -47,6 +52,10 @@ export type InvocationTemplate = {
|
||||
* The type of this node's output
|
||||
*/
|
||||
outputType: string; // TODO: generate a union of output types
|
||||
/**
|
||||
* The invocation's version.
|
||||
*/
|
||||
version?: string;
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
@@ -57,87 +66,63 @@ export type FieldUIConfig = {
|
||||
|
||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||
export const zFieldType = z.enum([
|
||||
// region Primitives
|
||||
'integer',
|
||||
'float',
|
||||
'boolean',
|
||||
'string',
|
||||
'array',
|
||||
'ImageField',
|
||||
'LatentsField',
|
||||
'ConditioningField',
|
||||
'ControlField',
|
||||
'ColorField',
|
||||
'ImageCollection',
|
||||
'ConditioningCollection',
|
||||
'ColorCollection',
|
||||
'LatentsCollection',
|
||||
'IntegerCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
// endregion
|
||||
|
||||
// region Models
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'ONNXModelField',
|
||||
'VaeModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'BooleanPolymorphic',
|
||||
'ClipField',
|
||||
// endregion
|
||||
|
||||
// region Iterate/Collect
|
||||
'Collection',
|
||||
'CollectionItem',
|
||||
// endregion
|
||||
|
||||
// region Misc
|
||||
'FilePath',
|
||||
'ColorCollection',
|
||||
'ColorField',
|
||||
'ColorPolymorphic',
|
||||
'ConditioningCollection',
|
||||
'ConditioningField',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlCollection',
|
||||
'ControlField',
|
||||
'ControlNetModelField',
|
||||
'ControlPolymorphic',
|
||||
'DenoiseMaskField',
|
||||
'enum',
|
||||
'float',
|
||||
'FloatCollection',
|
||||
'FloatPolymorphic',
|
||||
'ImageCollection',
|
||||
'ImageField',
|
||||
'ImagePolymorphic',
|
||||
'integer',
|
||||
'IntegerCollection',
|
||||
'IntegerPolymorphic',
|
||||
'LatentsCollection',
|
||||
'LatentsField',
|
||||
'LatentsPolymorphic',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
// endregion
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'string',
|
||||
'StringCollection',
|
||||
'StringPolymorphic',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'VaeModelField',
|
||||
]);
|
||||
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
|
||||
export const isFieldType = (value: unknown): value is FieldType =>
|
||||
zFieldType.safeParse(value).success;
|
||||
export const zReservedFieldType = z.enum([
|
||||
'WorkflowField',
|
||||
'IsIntermediate',
|
||||
'MetadataField',
|
||||
]);
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| IntegerInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate;
|
||||
export type ReservedFieldType = z.infer<typeof zReservedFieldType>;
|
||||
|
||||
export const isFieldType = (value: unknown): value is FieldType =>
|
||||
zFieldType.safeParse(value).success ||
|
||||
zReservedFieldType.safeParse(value).success;
|
||||
|
||||
/**
|
||||
* Indicates the kind of input(s) this field may have.
|
||||
@@ -205,30 +190,100 @@ export const zConditioningField = z.object({
|
||||
});
|
||||
export type ConditioningField = z.infer<typeof zConditioningField>;
|
||||
|
||||
export const zDenoiseMaskField = z.object({
|
||||
mask_name: z.string().trim().min(1),
|
||||
masked_latents_name: z.string().trim().min(1).optional(),
|
||||
});
|
||||
export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
|
||||
|
||||
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('integer'),
|
||||
value: z.number().optional(),
|
||||
value: z.number().int().optional(),
|
||||
});
|
||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||
|
||||
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerCollection'),
|
||||
value: z.array(z.number().int()).optional(),
|
||||
});
|
||||
export type IntegerCollectionInputFieldValue = z.infer<
|
||||
typeof zIntegerCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerPolymorphic'),
|
||||
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
|
||||
});
|
||||
export type IntegerPolymorphicInputFieldValue = z.infer<
|
||||
typeof zIntegerPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('float'),
|
||||
value: z.number().optional(),
|
||||
});
|
||||
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
|
||||
|
||||
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatCollection'),
|
||||
value: z.array(z.number()).optional(),
|
||||
});
|
||||
export type FloatCollectionInputFieldValue = z.infer<
|
||||
typeof zFloatCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatPolymorphic'),
|
||||
value: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
});
|
||||
export type FloatPolymorphicInputFieldValue = z.infer<
|
||||
typeof zFloatPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('string'),
|
||||
value: z.string().optional(),
|
||||
});
|
||||
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
|
||||
|
||||
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringCollection'),
|
||||
value: z.array(z.string()).optional(),
|
||||
});
|
||||
export type StringCollectionInputFieldValue = z.infer<
|
||||
typeof zStringCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringPolymorphic'),
|
||||
value: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
});
|
||||
export type StringPolymorphicInputFieldValue = z.infer<
|
||||
typeof zStringPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('boolean'),
|
||||
value: z.boolean().optional(),
|
||||
});
|
||||
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
|
||||
|
||||
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanCollection'),
|
||||
value: z.array(z.boolean()).optional(),
|
||||
});
|
||||
export type BooleanCollectionInputFieldValue = z.infer<
|
||||
typeof zBooleanCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanPolymorphic'),
|
||||
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
|
||||
});
|
||||
export type BooleanPolymorphicInputFieldValue = z.infer<
|
||||
typeof zBooleanPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('enum'),
|
||||
value: z.union([z.string(), z.number()]).optional(),
|
||||
@@ -241,6 +296,30 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
||||
|
||||
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsCollection'),
|
||||
value: z.array(zLatentsField).optional(),
|
||||
});
|
||||
export type LatentsCollectionInputFieldValue = z.infer<
|
||||
typeof zLatentsCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsPolymorphic'),
|
||||
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
|
||||
});
|
||||
export type LatentsPolymorphicInputFieldValue = z.infer<
|
||||
typeof zLatentsPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('DenoiseMaskField'),
|
||||
value: zDenoiseMaskField.optional(),
|
||||
});
|
||||
export type DenoiseMaskInputFieldValue = z.infer<
|
||||
typeof zDenoiseMaskInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningField'),
|
||||
value: zConditioningField.optional(),
|
||||
@@ -249,10 +328,30 @@ export type ConditioningInputFieldValue = z.infer<
|
||||
typeof zConditioningInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningCollectionInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningCollection'),
|
||||
value: z.array(zConditioningField).optional(),
|
||||
});
|
||||
export type ConditioningCollectionInputFieldValue = z.infer<
|
||||
typeof zConditioningCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningPolymorphic'),
|
||||
value: z
|
||||
.union([zConditioningField, z.array(zConditioningField)])
|
||||
.optional(),
|
||||
});
|
||||
export type ConditioningPolymorphicInputFieldValue = z.infer<
|
||||
typeof zConditioningPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlNetModel = zModelIdentifier;
|
||||
export type ControlNetModel = z.infer<typeof zControlNetModel>;
|
||||
|
||||
export const zControlField = zInputFieldValueBase.extend({
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zControlNetModel,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
@@ -267,11 +366,27 @@ export const zControlField = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zControlInputFieldTemplate = zInputFieldValueBase.extend({
|
||||
export const zControlInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlField'),
|
||||
value: zControlField.optional(),
|
||||
});
|
||||
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldTemplate>;
|
||||
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
|
||||
|
||||
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlPolymorphic'),
|
||||
value: z.union([zControlField, z.array(zControlField)]).optional(),
|
||||
});
|
||||
export type ControlPolymorphicInputFieldValue = z.infer<
|
||||
typeof zControlPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlCollection'),
|
||||
value: z.array(zControlField).optional(),
|
||||
});
|
||||
export type ControlCollectionInputFieldValue = z.infer<
|
||||
typeof zControlCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
@@ -352,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
||||
|
||||
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImagePolymorphic'),
|
||||
value: z.union([zImageField, z.array(zImageField)]).optional(),
|
||||
});
|
||||
export type ImagePolymorphicInputFieldValue = z.infer<
|
||||
typeof zImagePolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImageCollection'),
|
||||
value: z.array(zImageField).optional(),
|
||||
@@ -444,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
|
||||
|
||||
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorCollection'),
|
||||
value: z.array(zColorField).optional(),
|
||||
});
|
||||
export type ColorCollectionInputFieldValue = z.infer<
|
||||
typeof zColorCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorPolymorphic'),
|
||||
value: z.union([zColorField, z.array(zColorField)]).optional(),
|
||||
});
|
||||
export type ColorPolymorphicInputFieldValue = z.infer<
|
||||
typeof zColorPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Scheduler'),
|
||||
value: zScheduler.optional(),
|
||||
@@ -453,29 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
|
||||
>;
|
||||
|
||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zIntegerInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zBooleanCollectionInputFieldValue,
|
||||
zBooleanInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zBooleanPolymorphicInputFieldValue,
|
||||
zClipInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zControlInputFieldTemplate,
|
||||
zEnumInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zCollectionInputFieldValue,
|
||||
zCollectionItemInputFieldValue,
|
||||
zColorInputFieldValue,
|
||||
zColorCollectionInputFieldValue,
|
||||
zColorPolymorphicInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zConditioningCollectionInputFieldValue,
|
||||
zConditioningPolymorphicInputFieldValue,
|
||||
zControlInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zControlCollectionInputFieldValue,
|
||||
zControlPolymorphicInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zEnumInputFieldValue,
|
||||
zFloatCollectionInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zFloatPolymorphicInputFieldValue,
|
||||
zImageCollectionInputFieldValue,
|
||||
zImagePolymorphicInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zIntegerCollectionInputFieldValue,
|
||||
zIntegerPolymorphicInputFieldValue,
|
||||
zIntegerInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zLatentsCollectionInputFieldValue,
|
||||
zLatentsPolymorphicInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSchedulerInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zStringCollectionInputFieldValue,
|
||||
zStringPolymorphicInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
]);
|
||||
|
||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||
@@ -484,7 +641,6 @@ export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
description: string;
|
||||
type: FieldType;
|
||||
required: boolean;
|
||||
fieldKind: 'input';
|
||||
} & _InputField;
|
||||
@@ -499,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'IntegerCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type IntegerPolymorphicInputFieldTemplate = Omit<
|
||||
IntegerInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'IntegerPolymorphic';
|
||||
};
|
||||
|
||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'float';
|
||||
default: number;
|
||||
@@ -509,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'FloatCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type FloatPolymorphicInputFieldTemplate = Omit<
|
||||
FloatInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'FloatPolymorphic';
|
||||
};
|
||||
|
||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'string';
|
||||
default: string;
|
||||
@@ -517,31 +699,95 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
pattern?: string;
|
||||
};
|
||||
|
||||
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'StringCollection';
|
||||
default: string[];
|
||||
item_default?: string;
|
||||
};
|
||||
|
||||
export type StringPolymorphicInputFieldTemplate = Omit<
|
||||
StringInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'StringPolymorphic';
|
||||
};
|
||||
|
||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: boolean;
|
||||
type: 'boolean';
|
||||
};
|
||||
|
||||
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'BooleanCollection';
|
||||
default: boolean[];
|
||||
item_default?: boolean;
|
||||
};
|
||||
|
||||
export type BooleanPolymorphicInputFieldTemplate = Omit<
|
||||
BooleanInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'BooleanPolymorphic';
|
||||
};
|
||||
|
||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageDTO;
|
||||
default: ImageField;
|
||||
type: 'ImageField';
|
||||
};
|
||||
|
||||
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageField[];
|
||||
type: 'ImageCollection';
|
||||
item_default?: ImageField;
|
||||
};
|
||||
|
||||
export type ImagePolymorphicInputFieldTemplate = Omit<
|
||||
ImageInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ImagePolymorphic';
|
||||
};
|
||||
|
||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'DenoiseMaskField';
|
||||
};
|
||||
|
||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
default: LatentsField;
|
||||
type: 'LatentsField';
|
||||
};
|
||||
|
||||
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField[];
|
||||
type: 'LatentsCollection';
|
||||
item_default?: LatentsField;
|
||||
};
|
||||
|
||||
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField;
|
||||
type: 'LatentsPolymorphic';
|
||||
};
|
||||
|
||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ConditioningField';
|
||||
};
|
||||
|
||||
export type ConditioningCollectionInputFieldTemplate =
|
||||
InputFieldTemplateBase & {
|
||||
default: ConditioningField[];
|
||||
type: 'ConditioningCollection';
|
||||
item_default?: ConditioningField;
|
||||
};
|
||||
|
||||
export type ConditioningPolymorphicInputFieldTemplate = Omit<
|
||||
ConditioningInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ConditioningPolymorphic';
|
||||
};
|
||||
|
||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'UNetField';
|
||||
@@ -562,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ControlField';
|
||||
};
|
||||
|
||||
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ControlCollection';
|
||||
item_default?: ControlField;
|
||||
};
|
||||
|
||||
export type ControlPolymorphicInputFieldTemplate = Omit<
|
||||
ControlInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ControlPolymorphic';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string | number;
|
||||
type: 'enum';
|
||||
@@ -614,11 +873,77 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ColorField';
|
||||
};
|
||||
|
||||
export type ColorPolymorphicInputFieldTemplate = Omit<
|
||||
ColorInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ColorPolymorphic';
|
||||
};
|
||||
|
||||
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'ColorCollection';
|
||||
};
|
||||
|
||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: SchedulerParam;
|
||||
type: 'Scheduler';
|
||||
};
|
||||
|
||||
export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'WorkflowField';
|
||||
};
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| BooleanCollectionInputFieldTemplate
|
||||
| BooleanPolymorphicInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ColorCollectionInputFieldTemplate
|
||||
| ColorPolymorphicInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| ConditioningCollectionInputFieldTemplate
|
||||
| ConditioningPolymorphicInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| ControlCollectionInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ControlPolymorphicInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| FloatCollectionInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| ImagePolymorphicInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| IntegerCollectionInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| IntegerInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| LatentsCollectionInputFieldTemplate
|
||||
| LatentsPolymorphicInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| StringCollectionInputFieldTemplate
|
||||
| StringPolymorphicInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||
@@ -639,7 +964,9 @@ export type TypeHints = {
|
||||
export type InvocationSchemaExtra = {
|
||||
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
||||
title: string;
|
||||
category?: string;
|
||||
tags?: string[];
|
||||
version?: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
||||
(_InputField | _OutputField),
|
||||
@@ -690,8 +1017,22 @@ export type InvocationSchemaObject = (
|
||||
) & { class: 'invocation' };
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
||||
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.NonArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||
|
||||
export const isRefObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj:
|
||||
@@ -715,6 +1056,73 @@ export const isInvocationFieldSchema = (
|
||||
|
||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish(),
|
||||
generation_mode: z.string().nullish(),
|
||||
created_by: z.string().nullish(),
|
||||
positive_prompt: z.string().nullish(),
|
||||
negative_prompt: z.string().nullish(),
|
||||
width: z.number().int().nullish(),
|
||||
height: z.number().int().nullish(),
|
||||
seed: z.number().int().nullish(),
|
||||
rand_device: z.string().nullish(),
|
||||
cfg_scale: z.number().nullish(),
|
||||
steps: z.number().int().nullish(),
|
||||
scheduler: z.string().nullish(),
|
||||
clip_skip: z.number().int().nullish(),
|
||||
model: z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
loras: z
|
||||
.array(
|
||||
z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
})
|
||||
)
|
||||
.nullish(),
|
||||
vae: zVaeModelField.nullish(),
|
||||
strength: z.number().nullish(),
|
||||
init_image: z.string().nullish(),
|
||||
positive_style_prompt: z.string().nullish(),
|
||||
negative_style_prompt: z.string().nullish(),
|
||||
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
|
||||
refiner_cfg_scale: z.number().nullish(),
|
||||
refiner_steps: z.number().int().nullish(),
|
||||
refiner_scheduler: z.string().nullish(),
|
||||
refiner_positive_aesthetic_score: z.number().nullish(),
|
||||
refiner_negative_aesthetic_score: z.number().nullish(),
|
||||
refiner_start: z.number().nullish(),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
minor !== undefined &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export const zParsedSemver = zSemVer.transform((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return {
|
||||
major: Number(major),
|
||||
minor: Number(minor),
|
||||
patch: Number(patch),
|
||||
};
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
||||
@@ -725,6 +1133,9 @@ export const zInvocationNodeData = z.object({
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
embedWorkflow: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
version: zSemVer.optional(),
|
||||
});
|
||||
|
||||
// Massage this to get better type safety while developing
|
||||
@@ -745,28 +1156,38 @@ export const zNotesNodeData = z.object({
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
|
||||
const zPosition = z
|
||||
.object({
|
||||
x: z.number(),
|
||||
y: z.number(),
|
||||
})
|
||||
.default({ x: 0, y: 0 });
|
||||
|
||||
const zDimension = z.number().gt(0).nullish();
|
||||
|
||||
export const zWorkflowInvocationNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
width: z.number().gt(0),
|
||||
height: z.number().gt(0),
|
||||
position: z.object({
|
||||
x: z.number(),
|
||||
y: z.number(),
|
||||
}),
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zPosition,
|
||||
});
|
||||
|
||||
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||
|
||||
export const isWorkflowInvocationNode = (
|
||||
val: unknown
|
||||
): val is WorkflowInvocationNode =>
|
||||
zWorkflowInvocationNode.safeParse(val).success;
|
||||
|
||||
export const zWorkflowNotesNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
width: z.number().gt(0),
|
||||
height: z.number().gt(0),
|
||||
position: z.object({
|
||||
x: z.number(),
|
||||
y: z.number(),
|
||||
}),
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zPosition,
|
||||
});
|
||||
|
||||
export const zWorkflowNode = z.discriminatedUnion('type', [
|
||||
@@ -776,14 +1197,25 @@ export const zWorkflowNode = z.discriminatedUnion('type', [
|
||||
|
||||
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
|
||||
|
||||
export const zWorkflowEdge = z.object({
|
||||
export const zDefaultWorkflowEdge = z.object({
|
||||
source: z.string().trim().min(1),
|
||||
sourceHandle: z.string().trim().min(1),
|
||||
target: z.string().trim().min(1),
|
||||
targetHandle: z.string().trim().min(1),
|
||||
id: z.string().trim().min(1),
|
||||
type: z.enum(['default', 'collapsed']),
|
||||
type: z.literal('default'),
|
||||
});
|
||||
export const zCollapsedWorkflowEdge = z.object({
|
||||
source: z.string().trim().min(1),
|
||||
target: z.string().trim().min(1),
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('collapsed'),
|
||||
});
|
||||
|
||||
export const zWorkflowEdge = z.union([
|
||||
zDefaultWorkflowEdge,
|
||||
zCollapsedWorkflowEdge,
|
||||
]);
|
||||
|
||||
export const zFieldIdentifier = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
@@ -792,35 +1224,80 @@ export const zFieldIdentifier = z.object({
|
||||
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
minor !== undefined &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
export type WorkflowWarning = {
|
||||
message: string;
|
||||
issues: string[];
|
||||
data: JsonObject;
|
||||
};
|
||||
|
||||
export const zWorkflow = z.object({
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
description: z.string(),
|
||||
version: z.string(),
|
||||
contact: z.string(),
|
||||
tags: z.string(),
|
||||
notes: z.string(),
|
||||
nodes: z.array(zWorkflowNode),
|
||||
edges: z.array(zWorkflowEdge),
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
name: z.string().default(''),
|
||||
author: z.string().default(''),
|
||||
description: z.string().default(''),
|
||||
version: z.string().default(''),
|
||||
contact: z.string().default(''),
|
||||
tags: z.string().default(''),
|
||||
notes: z.string().default(''),
|
||||
nodes: z.array(zWorkflowNode).default([]),
|
||||
edges: z.array(zWorkflowEdge).default([]),
|
||||
exposedFields: z.array(zFieldIdentifier).default([]),
|
||||
meta: z
|
||||
.object({
|
||||
version: zSemVer,
|
||||
})
|
||||
.default({ version: '1.0.0' }),
|
||||
});
|
||||
|
||||
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||
const { nodes, edges } = workflow;
|
||||
const warnings: WorkflowWarning[] = [];
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
edges.forEach((edge, i) => {
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const issues: string[] = [];
|
||||
if (!sourceNode) {
|
||||
issues.push(`Output node ${edge.source} does not exist`);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||
);
|
||||
}
|
||||
if (!targetNode) {
|
||||
issues.push(`Input node ${edge.target} does not exist`);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.targetHandle in targetNode.data.inputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||
);
|
||||
}
|
||||
if (issues.length) {
|
||||
delete edges[i];
|
||||
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||
warnings.push({
|
||||
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||
issues,
|
||||
data: edge,
|
||||
});
|
||||
}
|
||||
});
|
||||
return { workflow, warnings };
|
||||
});
|
||||
|
||||
export type Workflow = z.infer<typeof zWorkflow>;
|
||||
|
||||
export type ImageMetadataAndWorkflow = {
|
||||
metadata?: CoreMetadata;
|
||||
workflow?: Workflow;
|
||||
};
|
||||
|
||||
export type CurrentImageNodeData = {
|
||||
id: string;
|
||||
type: 'current_image';
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { NodesState } from '../store/types';
|
||||
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
|
||||
export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
const { workflow: workflowMeta, nodes, edges } = nodesState;
|
||||
@@ -11,17 +12,29 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
edges: [],
|
||||
};
|
||||
|
||||
nodes.forEach((node) => {
|
||||
const result = zWorkflowNode.safeParse(node);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
workflow.nodes.push(result.data);
|
||||
});
|
||||
nodes
|
||||
.filter((n) =>
|
||||
['invocation', 'notes'].includes(n.type ?? '__UNKNOWN_NODE_TYPE__')
|
||||
)
|
||||
.forEach((node) => {
|
||||
const result = zWorkflowNode.safeParse(node);
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: 'Unable to parse node',
|
||||
});
|
||||
logger('nodes').warn({ node: parseify(node) }, message);
|
||||
return;
|
||||
}
|
||||
workflow.nodes.push(result.data);
|
||||
});
|
||||
|
||||
edges.forEach((edge) => {
|
||||
const result = zWorkflowEdge.safeParse(edge);
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: 'Unable to parse edge',
|
||||
});
|
||||
logger('nodes').warn({ edge: parseify(edge) }, message);
|
||||
return;
|
||||
}
|
||||
workflow.edges.push(result.data);
|
||||
@@ -29,7 +42,3 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
|
||||
return workflow;
|
||||
};
|
||||
|
||||
export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
|
||||
buildWorkflow(nodes)
|
||||
);
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
SINGLE_TO_POLYMORPHIC_MAP,
|
||||
isCollectionItemType,
|
||||
isPolymorphicItemType,
|
||||
} from '../types/constants';
|
||||
import {
|
||||
BooleanCollectionInputFieldTemplate,
|
||||
BooleanInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
CollectionInputFieldTemplate,
|
||||
@@ -8,12 +17,16 @@ import {
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerCollectionInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
InvocationFieldSchema,
|
||||
InvocationSchemaObject,
|
||||
@@ -23,12 +36,32 @@ import {
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
StringCollectionInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
isFieldType,
|
||||
isArraySchemaObject,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ColorPolymorphicInputFieldTemplate,
|
||||
ColorCollectionInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
LatentsCollectionInputFieldTemplate,
|
||||
ConditioningPolymorphicInputFieldTemplate,
|
||||
ConditioningCollectionInputFieldTemplate,
|
||||
ControlCollectionInputFieldTemplate,
|
||||
ImageField,
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
} from '../types/types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
@@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
|
||||
* @example
|
||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||
*/
|
||||
export const refObjectToFieldType = (
|
||||
refObject: OpenAPIV3.ReferenceObject
|
||||
): FieldType => {
|
||||
const name = refObject.$ref.split('/').slice(-1)[0];
|
||||
if (!name) {
|
||||
throw `Unknown field type: ${name}`;
|
||||
}
|
||||
return name as FieldType;
|
||||
};
|
||||
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
|
||||
refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
const buildIntegerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
@@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
|
||||
const template: IntegerPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: IntegerCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
|
||||
const template: FloatPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
|
||||
const item_default = isNumber(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: FloatCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
|
||||
const template: StringPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringPolymorphic',
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.pattern !== undefined) {
|
||||
template.pattern = schemaObject.pattern;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
|
||||
const item_default = isString(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: '';
|
||||
const template: StringCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
|
||||
const template: BooleanPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanPolymorphic',
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
schemaObject.item_default && isBoolean(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: false;
|
||||
const template: BooleanCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMainModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagePolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
|
||||
const template: ImagePolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImagePolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -257,6 +468,20 @@ const buildImageCollectionInputFieldTemplate = ({
|
||||
const template: ImageCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImageCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ImageField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildDenoiseMaskInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => {
|
||||
const template: DenoiseMaskInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'DenoiseMaskField',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@@ -276,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
|
||||
const template: LatentsPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
|
||||
const template: LatentsCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -289,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
|
||||
const template: ConditioningPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
|
||||
const template: ConditioningCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildUNetInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -342,6 +621,33 @@ const buildControlInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
|
||||
const template: ControlPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
|
||||
const template: ControlCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ControlField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -395,6 +701,32 @@ const buildColorInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
|
||||
const template: ColorPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorPolymorphic',
|
||||
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
|
||||
const template: ColorCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSchedulerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -410,49 +742,136 @@ const buildSchedulerInputFieldTemplate = ({
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: InvocationFieldSchema
|
||||
): FieldType => {
|
||||
let fieldType = '';
|
||||
|
||||
const { ui_type } = schemaObject;
|
||||
if (ui_type) {
|
||||
fieldType = ui_type;
|
||||
): string | undefined => {
|
||||
if (schemaObject?.ui_type) {
|
||||
return schemaObject.ui_type;
|
||||
} else if (!schemaObject.type) {
|
||||
// console.log('refObject', schemaObject);
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
|
||||
if (schemaObject.allOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
const allOf = schemaObject.allOf;
|
||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||
return refObjectToSchemaName(allOf[0]);
|
||||
}
|
||||
} else if (schemaObject.anyOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.oneOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
const anyOf = schemaObject.anyOf;
|
||||
/**
|
||||
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||
* - an `anyOf` with two items
|
||||
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||
*
|
||||
* Any other cases we ignore.
|
||||
*/
|
||||
|
||||
let firstType: string | undefined;
|
||||
let secondType: string | undefined;
|
||||
|
||||
if (isArraySchemaObject(anyOf[0])) {
|
||||
// first is array, second is not
|
||||
const first = anyOf[0].items;
|
||||
const second = anyOf[1];
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
} else if (isArraySchemaObject(anyOf[1])) {
|
||||
// first is not array, second is
|
||||
const first = anyOf[0];
|
||||
const second = anyOf[1].items;
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
}
|
||||
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||
}
|
||||
}
|
||||
} else if (schemaObject.enum) {
|
||||
fieldType = 'enum';
|
||||
return 'enum';
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'number') {
|
||||
// floats are "number" in OpenAPI, while ints are "integer"
|
||||
fieldType = 'float';
|
||||
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||
return 'float';
|
||||
} else if (schemaObject.type === 'array') {
|
||||
const itemType = isSchemaObject(schemaObject.items)
|
||||
? schemaObject.items.type
|
||||
: refObjectToSchemaName(schemaObject.items);
|
||||
|
||||
if (isCollectionItemType(itemType)) {
|
||||
return COLLECTION_MAP[itemType];
|
||||
}
|
||||
|
||||
return;
|
||||
} else {
|
||||
fieldType = schemaObject.type;
|
||||
return schemaObject.type;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
throw `Field type "${fieldType}" is unknown!`;
|
||||
}
|
||||
|
||||
return fieldType;
|
||||
return;
|
||||
};
|
||||
|
||||
const TEMPLATE_BUILDER_MAP = {
|
||||
boolean: buildBooleanInputFieldTemplate,
|
||||
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||
ClipField: buildClipInputFieldTemplate,
|
||||
Collection: buildCollectionInputFieldTemplate,
|
||||
CollectionItem: buildCollectionItemInputFieldTemplate,
|
||||
ColorCollection: buildColorCollectionInputFieldTemplate,
|
||||
ColorField: buildColorInputFieldTemplate,
|
||||
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
|
||||
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
|
||||
ConditioningField: buildConditioningInputFieldTemplate,
|
||||
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
|
||||
ControlCollection: buildControlCollectionInputFieldTemplate,
|
||||
ControlField: buildControlInputFieldTemplate,
|
||||
ControlNetModelField: buildControlNetModelInputFieldTemplate,
|
||||
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
|
||||
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
|
||||
enum: buildEnumInputFieldTemplate,
|
||||
float: buildFloatInputFieldTemplate,
|
||||
FloatCollection: buildFloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
|
||||
ImageCollection: buildImageCollectionInputFieldTemplate,
|
||||
ImageField: buildImageInputFieldTemplate,
|
||||
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
|
||||
integer: buildIntegerInputFieldTemplate,
|
||||
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||
LatentsField: buildLatentsInputFieldTemplate,
|
||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||
MainModelField: buildMainModelInputFieldTemplate,
|
||||
Scheduler: buildSchedulerInputFieldTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
|
||||
string: buildStringInputFieldTemplate,
|
||||
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||
UNetField: buildUNetInputFieldTemplate,
|
||||
VaeField: buildVaeInputFieldTemplate,
|
||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||
};
|
||||
|
||||
const isTemplatedFieldType = (
|
||||
fieldType: string | undefined
|
||||
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
|
||||
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
|
||||
|
||||
/**
|
||||
* Builds an input field from an invocation schema property.
|
||||
* @param fieldSchema The schema object
|
||||
@@ -461,16 +880,14 @@ export const getFieldType = (
|
||||
export const buildInputFieldTemplate = (
|
||||
nodeSchema: InvocationSchemaObject,
|
||||
fieldSchema: InvocationFieldSchema,
|
||||
name: string
|
||||
name: string,
|
||||
fieldType: FieldType
|
||||
) => {
|
||||
// console.log('input', schemaObject);
|
||||
const fieldType = getFieldType(fieldSchema);
|
||||
// console.log('input fieldType', fieldType);
|
||||
|
||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
||||
|
||||
const extra = {
|
||||
input,
|
||||
// TODO: Can we support polymorphic inputs in the UI?
|
||||
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
@@ -486,140 +903,12 @@ export const buildInputFieldTemplate = (
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (fieldType === 'ImageField') {
|
||||
return buildImageInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
if (!isTemplatedFieldType(fieldType)) {
|
||||
return;
|
||||
}
|
||||
if (fieldType === 'ImageCollection') {
|
||||
return buildImageCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LatentsField') {
|
||||
return buildLatentsInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ConditioningField') {
|
||||
return buildConditioningInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'UNetField') {
|
||||
return buildUNetInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ClipField') {
|
||||
return buildClipInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeField') {
|
||||
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
|
||||
}
|
||||
if (fieldType === 'ControlField') {
|
||||
return buildControlInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'MainModelField') {
|
||||
return buildMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLRefinerModelField') {
|
||||
return buildRefinerModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLMainModelField') {
|
||||
return buildSDXLMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeModelField') {
|
||||
return buildVaeModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LoRAModelField') {
|
||||
return buildLoRAModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ControlNetModelField') {
|
||||
return buildControlNetModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'enum') {
|
||||
return buildEnumInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'integer') {
|
||||
return buildIntegerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'float') {
|
||||
return buildFloatInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'string') {
|
||||
return buildStringInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'boolean') {
|
||||
return buildBooleanInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Collection') {
|
||||
return buildCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'CollectionItem') {
|
||||
return buildCollectionItemInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ColorField') {
|
||||
return buildColorInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Scheduler') {
|
||||
return buildSchedulerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
return;
|
||||
|
||||
return TEMPLATE_BUILDER_MAP[fieldType]({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,100 +1,79 @@
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
|
||||
const FIELD_VALUE_FALLBACK_MAP = {
|
||||
'enum.number': 0,
|
||||
'enum.string': '',
|
||||
boolean: false,
|
||||
BooleanCollection: [],
|
||||
BooleanPolymorphic: false,
|
||||
ClipField: undefined,
|
||||
Collection: [],
|
||||
CollectionItem: undefined,
|
||||
ColorCollection: [],
|
||||
ColorField: undefined,
|
||||
ColorPolymorphic: undefined,
|
||||
ConditioningCollection: [],
|
||||
ConditioningField: undefined,
|
||||
ConditioningPolymorphic: undefined,
|
||||
ControlCollection: [],
|
||||
ControlField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
ControlPolymorphic: undefined,
|
||||
DenoiseMaskField: undefined,
|
||||
float: 0,
|
||||
FloatCollection: [],
|
||||
FloatPolymorphic: 0,
|
||||
ImageCollection: [],
|
||||
ImageField: undefined,
|
||||
ImagePolymorphic: undefined,
|
||||
integer: 0,
|
||||
IntegerCollection: [],
|
||||
IntegerPolymorphic: 0,
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
Scheduler: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
string: '',
|
||||
StringCollection: [],
|
||||
StringPolymorphic: '',
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
id: string,
|
||||
template: InputFieldTemplate
|
||||
): InputFieldValue => {
|
||||
const fieldValue: InputFieldValue = {
|
||||
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
|
||||
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
|
||||
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
|
||||
// `InputFieldValue` union, but TS doesn't seem to like it...
|
||||
const fieldValue = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input',
|
||||
};
|
||||
|
||||
if (template.type === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
|
||||
if (template.type === 'integer') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'float') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'boolean') {
|
||||
fieldValue.value = template.default ?? false;
|
||||
}
|
||||
} as InputFieldValue;
|
||||
|
||||
if (template.type === 'enum') {
|
||||
if (template.enumType === 'number') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
|
||||
}
|
||||
if (template.enumType === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
|
||||
}
|
||||
}
|
||||
|
||||
if (template.type === 'Collection') {
|
||||
fieldValue.value = template.default ?? 1;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageCollection') {
|
||||
fieldValue.value = [];
|
||||
}
|
||||
|
||||
if (template.type === 'LatentsField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ConditioningField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'UNetField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ClipField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'MainModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'SDXLRefinerModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'LoRAModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlNetModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'Scheduler') {
|
||||
fieldValue.value = 'euler';
|
||||
} else {
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import * as png from '@stevebel/png';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
zWorkflow,
|
||||
} from 'features/nodes/types/types';
|
||||
import { get } from 'lodash-es';
|
||||
|
||||
export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
image: Blob
|
||||
): Promise<ImageMetadataAndWorkflow> => {
|
||||
const data: ImageMetadataAndWorkflow = {};
|
||||
const buffer = await image.arrayBuffer();
|
||||
const text = png.decode(buffer).text;
|
||||
|
||||
const rawMetadata = get(text, 'invokeai_metadata');
|
||||
if (rawMetadata) {
|
||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||
if (metadataResult.success) {
|
||||
data.metadata = metadataResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(metadataResult.error) },
|
||||
'Problem reading metadata from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const rawWorkflow = get(text, 'invokeai_workflow');
|
||||
if (rawWorkflow) {
|
||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||
if (workflowResult.success) {
|
||||
data.workflow = workflowResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(workflowResult.error) },
|
||||
'Problem reading workflow from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user