mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 06:18:02 -05:00
Compare commits
4 Commits
feat/invoc
...
feat/ui/wo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f505ec64ba | ||
|
|
f22eb368a3 | ||
|
|
96ae22c7e0 | ||
|
|
f5447cdc23 |
@@ -1,12 +0,0 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import type { PluginOption, UserConfig } from 'vite';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
|
||||
export const commonPlugins: UserConfig['plugins'] = [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
];
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { UserConfig } from 'vite';
|
||||
|
||||
import { commonPlugins } from './common.mjs';
|
||||
|
||||
export const appConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [...commonPlugins],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
},
|
||||
server: {
|
||||
// Proxy HTTP requests to the flask server
|
||||
proxy: {
|
||||
// Proxy socket.io to the nodes socketio server
|
||||
'/ws/socket.io': {
|
||||
target: 'ws://127.0.0.1:9090',
|
||||
ws: true,
|
||||
},
|
||||
// Proxy openapi schema definiton
|
||||
'/openapi.json': {
|
||||
target: 'http://127.0.0.1:9090/openapi.json',
|
||||
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
// proxy nodes api
|
||||
'/api/v1': {
|
||||
target: 'http://127.0.0.1:9090/api/v1',
|
||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -1,46 +0,0 @@
|
||||
import path from 'path';
|
||||
import type { UserConfig } from 'vite';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
import dts from 'vite-plugin-dts';
|
||||
|
||||
import { commonPlugins } from './common.mjs';
|
||||
|
||||
export const packageConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [
|
||||
...commonPlugins,
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
cssInjectedByJsPlugin(),
|
||||
],
|
||||
build: {
|
||||
cssCodeSplit: true,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||
name: 'InvokeAIUI',
|
||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
||||
},
|
||||
rollupOptions: {
|
||||
external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'],
|
||||
output: {
|
||||
globals: {
|
||||
react: 'React',
|
||||
'react-dom': 'ReactDOM',
|
||||
'@emotion/react': 'EmotionReact',
|
||||
'@invoke-ai/ui-library': 'UiLibrary',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
app: path.resolve(__dirname, '../src/app'),
|
||||
assets: path.resolve(__dirname, '../src/assets'),
|
||||
common: path.resolve(__dirname, '../src/common'),
|
||||
features: path.resolve(__dirname, '../src/features'),
|
||||
services: path.resolve(__dirname, '../src/services'),
|
||||
theme: path.resolve(__dirname, '../src/theme'),
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -33,7 +33,9 @@
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
"unimported": "npx unimported"
|
||||
"unimported": "npx unimported",
|
||||
"test": "vitest",
|
||||
"test:no-watch": "vitest --no-watch"
|
||||
},
|
||||
"madge": {
|
||||
"excludeRegExp": [
|
||||
@@ -157,7 +159,8 @@
|
||||
"vite-plugin-css-injected-by-js": "^3.3.1",
|
||||
"vite-plugin-dts": "^3.7.1",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.3.1"
|
||||
"vite-tsconfig-paths": "^4.3.1",
|
||||
"vitest": "^1.2.2"
|
||||
},
|
||||
"pnpm": {
|
||||
"patchedDependencies": {
|
||||
|
||||
222
invokeai/frontend/web/pnpm-lock.yaml
generated
222
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -215,7 +215,7 @@ devDependencies:
|
||||
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12)
|
||||
'@storybook/test':
|
||||
specifier: ^7.6.10
|
||||
version: 7.6.10
|
||||
version: 7.6.10(vitest@1.2.2)
|
||||
'@storybook/theming':
|
||||
specifier: ^7.6.10
|
||||
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)
|
||||
@@ -318,6 +318,9 @@ devDependencies:
|
||||
vite-tsconfig-paths:
|
||||
specifier: ^4.3.1
|
||||
version: 4.3.1(typescript@5.3.3)(vite@5.0.12)
|
||||
vitest:
|
||||
specifier: ^1.2.2
|
||||
version: 1.2.2(@types/node@20.11.5)
|
||||
|
||||
packages:
|
||||
|
||||
@@ -5464,7 +5467,7 @@ packages:
|
||||
- supports-color
|
||||
dev: true
|
||||
|
||||
/@storybook/test@7.6.10:
|
||||
/@storybook/test@7.6.10(vitest@1.2.2):
|
||||
resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==}
|
||||
dependencies:
|
||||
'@storybook/client-logger': 7.6.10
|
||||
@@ -5472,7 +5475,7 @@ packages:
|
||||
'@storybook/instrumenter': 7.6.10
|
||||
'@storybook/preview-api': 7.6.10
|
||||
'@testing-library/dom': 9.3.4
|
||||
'@testing-library/jest-dom': 6.2.0
|
||||
'@testing-library/jest-dom': 6.2.0(vitest@1.2.2)
|
||||
'@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4)
|
||||
'@types/chai': 4.3.11
|
||||
'@vitest/expect': 0.34.7
|
||||
@@ -5652,7 +5655,7 @@ packages:
|
||||
pretty-format: 27.5.1
|
||||
dev: true
|
||||
|
||||
/@testing-library/jest-dom@6.2.0:
|
||||
/@testing-library/jest-dom@6.2.0(vitest@1.2.2):
|
||||
resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==}
|
||||
engines: {node: '>=14', npm: '>=6', yarn: '>=1'}
|
||||
peerDependencies:
|
||||
@@ -5678,6 +5681,7 @@ packages:
|
||||
dom-accessibility-api: 0.6.3
|
||||
lodash: 4.17.21
|
||||
redent: 3.0.0
|
||||
vitest: 1.2.2(@types/node@20.11.5)
|
||||
dev: true
|
||||
|
||||
/@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4):
|
||||
@@ -6490,12 +6494,42 @@ packages:
|
||||
chai: 4.4.1
|
||||
dev: true
|
||||
|
||||
/@vitest/expect@1.2.2:
|
||||
resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==}
|
||||
dependencies:
|
||||
'@vitest/spy': 1.2.2
|
||||
'@vitest/utils': 1.2.2
|
||||
chai: 4.4.1
|
||||
dev: true
|
||||
|
||||
/@vitest/runner@1.2.2:
|
||||
resolution: {integrity: sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==}
|
||||
dependencies:
|
||||
'@vitest/utils': 1.2.2
|
||||
p-limit: 5.0.0
|
||||
pathe: 1.1.2
|
||||
dev: true
|
||||
|
||||
/@vitest/snapshot@1.2.2:
|
||||
resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==}
|
||||
dependencies:
|
||||
magic-string: 0.30.5
|
||||
pathe: 1.1.2
|
||||
pretty-format: 29.7.0
|
||||
dev: true
|
||||
|
||||
/@vitest/spy@0.34.7:
|
||||
resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==}
|
||||
dependencies:
|
||||
tinyspy: 2.2.0
|
||||
dev: true
|
||||
|
||||
/@vitest/spy@1.2.2:
|
||||
resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==}
|
||||
dependencies:
|
||||
tinyspy: 2.2.0
|
||||
dev: true
|
||||
|
||||
/@vitest/utils@0.34.7:
|
||||
resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==}
|
||||
dependencies:
|
||||
@@ -6504,6 +6538,15 @@ packages:
|
||||
pretty-format: 29.7.0
|
||||
dev: true
|
||||
|
||||
/@vitest/utils@1.2.2:
|
||||
resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==}
|
||||
dependencies:
|
||||
diff-sequences: 29.6.3
|
||||
estree-walker: 3.0.3
|
||||
loupe: 2.3.7
|
||||
pretty-format: 29.7.0
|
||||
dev: true
|
||||
|
||||
/@volar/language-core@1.11.1:
|
||||
resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==}
|
||||
dependencies:
|
||||
@@ -7184,6 +7227,11 @@ packages:
|
||||
engines: {node: '>=0.4.0'}
|
||||
dev: true
|
||||
|
||||
/acorn-walk@8.3.2:
|
||||
resolution: {integrity: sha512-cjkyv4OtNCIeqhHrfS81QWXoCBPExR/J62oyEqepVw8WaQeSqpW2uhuLPh1m9eWhDuOo/jUXVTlifvesOWp/4A==}
|
||||
engines: {node: '>=0.4.0'}
|
||||
dev: true
|
||||
|
||||
/acorn@7.4.1:
|
||||
resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==}
|
||||
engines: {node: '>=0.4.0'}
|
||||
@@ -7661,6 +7709,11 @@ packages:
|
||||
engines: {node: '>= 0.8'}
|
||||
dev: true
|
||||
|
||||
/cac@6.7.14:
|
||||
resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==}
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/call-bind@1.0.5:
|
||||
resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==}
|
||||
dependencies:
|
||||
@@ -9173,6 +9226,12 @@ packages:
|
||||
resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==}
|
||||
dev: true
|
||||
|
||||
/estree-walker@3.0.3:
|
||||
resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==}
|
||||
dependencies:
|
||||
'@types/estree': 1.0.5
|
||||
dev: true
|
||||
|
||||
/esutils@2.0.3:
|
||||
resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
@@ -10547,6 +10606,10 @@ packages:
|
||||
hasBin: true
|
||||
dev: true
|
||||
|
||||
/jsonc-parser@3.2.1:
|
||||
resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==}
|
||||
dev: true
|
||||
|
||||
/jsondiffpatch@0.6.0:
|
||||
resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==}
|
||||
engines: {node: ^18.0.0 || >=20.0.0}
|
||||
@@ -10648,6 +10711,14 @@ packages:
|
||||
engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0}
|
||||
dev: true
|
||||
|
||||
/local-pkg@0.5.0:
|
||||
resolution: {integrity: sha512-ok6z3qlYyCDS4ZEU27HaU6x/xZa9Whf8jD4ptH5UZTQYZVYeb9bnZ3ojVhiJNLiXK1Hfc0GNbLXcmZ5plLDDBg==}
|
||||
engines: {node: '>=14'}
|
||||
dependencies:
|
||||
mlly: 1.5.0
|
||||
pkg-types: 1.0.3
|
||||
dev: true
|
||||
|
||||
/locate-path@3.0.0:
|
||||
resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==}
|
||||
engines: {node: '>=6'}
|
||||
@@ -10986,6 +11057,15 @@ packages:
|
||||
hasBin: true
|
||||
dev: true
|
||||
|
||||
/mlly@1.5.0:
|
||||
resolution: {integrity: sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==}
|
||||
dependencies:
|
||||
acorn: 8.11.3
|
||||
pathe: 1.1.2
|
||||
pkg-types: 1.0.3
|
||||
ufo: 1.3.2
|
||||
dev: true
|
||||
|
||||
/module-definition@3.4.0:
|
||||
resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==}
|
||||
engines: {node: '>=6.0'}
|
||||
@@ -11380,6 +11460,13 @@ packages:
|
||||
yocto-queue: 0.1.0
|
||||
dev: true
|
||||
|
||||
/p-limit@5.0.0:
|
||||
resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==}
|
||||
engines: {node: '>=18'}
|
||||
dependencies:
|
||||
yocto-queue: 1.0.0
|
||||
dev: true
|
||||
|
||||
/p-locate@3.0.0:
|
||||
resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==}
|
||||
engines: {node: '>=6'}
|
||||
@@ -11550,6 +11637,14 @@ packages:
|
||||
find-up: 5.0.0
|
||||
dev: true
|
||||
|
||||
/pkg-types@1.0.3:
|
||||
resolution: {integrity: sha512-nN7pYi0AQqJnoLPC9eHFQ8AcyaixBUOwvqc5TDnIKCMEE6I0y8P7OKA7fPexsXGCGxQDl/cmrLAp26LhcwxZ4A==}
|
||||
dependencies:
|
||||
jsonc-parser: 3.2.1
|
||||
mlly: 1.5.0
|
||||
pathe: 1.1.2
|
||||
dev: true
|
||||
|
||||
/pluralize@8.0.0:
|
||||
resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
|
||||
engines: {node: '>=4'}
|
||||
@@ -12850,6 +12945,10 @@ packages:
|
||||
object-inspect: 1.13.1
|
||||
dev: true
|
||||
|
||||
/siginfo@2.0.0:
|
||||
resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==}
|
||||
dev: true
|
||||
|
||||
/signal-exit@3.0.7:
|
||||
resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==}
|
||||
dev: true
|
||||
@@ -12968,6 +13067,10 @@ packages:
|
||||
stackframe: 1.3.4
|
||||
dev: false
|
||||
|
||||
/stackback@0.0.2:
|
||||
resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==}
|
||||
dev: true
|
||||
|
||||
/stackframe@1.3.4:
|
||||
resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==}
|
||||
dev: false
|
||||
@@ -12992,6 +13095,10 @@ packages:
|
||||
engines: {node: '>= 0.8'}
|
||||
dev: true
|
||||
|
||||
/std-env@3.7.0:
|
||||
resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==}
|
||||
dev: true
|
||||
|
||||
/stop-iteration-iterator@1.0.0:
|
||||
resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==}
|
||||
engines: {node: '>= 0.4'}
|
||||
@@ -13161,6 +13268,12 @@ packages:
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/strip-literal@1.3.0:
|
||||
resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==}
|
||||
dependencies:
|
||||
acorn: 8.11.3
|
||||
dev: true
|
||||
|
||||
/stylis@4.2.0:
|
||||
resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==}
|
||||
dev: false
|
||||
@@ -13311,6 +13424,15 @@ packages:
|
||||
/tiny-invariant@1.3.1:
|
||||
resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==}
|
||||
|
||||
/tinybench@2.6.0:
|
||||
resolution: {integrity: sha512-N8hW3PG/3aOoZAN5V/NSAEDz0ZixDSSt5b/a05iqtpgfLWMSVuCo7w0k2vVvEjdrIoeGqZzweX2WlyioNIHchA==}
|
||||
dev: true
|
||||
|
||||
/tinypool@0.8.2:
|
||||
resolution: {integrity: sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==}
|
||||
engines: {node: '>=14.0.0'}
|
||||
dev: true
|
||||
|
||||
/tinyspy@2.2.0:
|
||||
resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==}
|
||||
engines: {node: '>=14.0.0'}
|
||||
@@ -13828,6 +13950,27 @@ packages:
|
||||
engines: {node: '>= 0.8'}
|
||||
dev: true
|
||||
|
||||
/vite-node@1.2.2(@types/node@20.11.5):
|
||||
resolution: {integrity: sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==}
|
||||
engines: {node: ^18.0.0 || >=20.0.0}
|
||||
hasBin: true
|
||||
dependencies:
|
||||
cac: 6.7.14
|
||||
debug: 4.3.4
|
||||
pathe: 1.1.2
|
||||
picocolors: 1.0.0
|
||||
vite: 5.0.12(@types/node@20.11.5)
|
||||
transitivePeerDependencies:
|
||||
- '@types/node'
|
||||
- less
|
||||
- lightningcss
|
||||
- sass
|
||||
- stylus
|
||||
- sugarss
|
||||
- supports-color
|
||||
- terser
|
||||
dev: true
|
||||
|
||||
/vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12):
|
||||
resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==}
|
||||
peerDependencies:
|
||||
@@ -13926,6 +14069,63 @@ packages:
|
||||
fsevents: 2.3.3
|
||||
dev: true
|
||||
|
||||
/vitest@1.2.2(@types/node@20.11.5):
|
||||
resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==}
|
||||
engines: {node: ^18.0.0 || >=20.0.0}
|
||||
hasBin: true
|
||||
peerDependencies:
|
||||
'@edge-runtime/vm': '*'
|
||||
'@types/node': ^18.0.0 || >=20.0.0
|
||||
'@vitest/browser': ^1.0.0
|
||||
'@vitest/ui': ^1.0.0
|
||||
happy-dom: '*'
|
||||
jsdom: '*'
|
||||
peerDependenciesMeta:
|
||||
'@edge-runtime/vm':
|
||||
optional: true
|
||||
'@types/node':
|
||||
optional: true
|
||||
'@vitest/browser':
|
||||
optional: true
|
||||
'@vitest/ui':
|
||||
optional: true
|
||||
happy-dom:
|
||||
optional: true
|
||||
jsdom:
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/node': 20.11.5
|
||||
'@vitest/expect': 1.2.2
|
||||
'@vitest/runner': 1.2.2
|
||||
'@vitest/snapshot': 1.2.2
|
||||
'@vitest/spy': 1.2.2
|
||||
'@vitest/utils': 1.2.2
|
||||
acorn-walk: 8.3.2
|
||||
cac: 6.7.14
|
||||
chai: 4.4.1
|
||||
debug: 4.3.4
|
||||
execa: 8.0.1
|
||||
local-pkg: 0.5.0
|
||||
magic-string: 0.30.5
|
||||
pathe: 1.1.2
|
||||
picocolors: 1.0.0
|
||||
std-env: 3.7.0
|
||||
strip-literal: 1.3.0
|
||||
tinybench: 2.6.0
|
||||
tinypool: 0.8.2
|
||||
vite: 5.0.12(@types/node@20.11.5)
|
||||
vite-node: 1.2.2(@types/node@20.11.5)
|
||||
why-is-node-running: 2.2.2
|
||||
transitivePeerDependencies:
|
||||
- less
|
||||
- lightningcss
|
||||
- sass
|
||||
- stylus
|
||||
- sugarss
|
||||
- supports-color
|
||||
- terser
|
||||
dev: true
|
||||
|
||||
/void-elements@3.1.0:
|
||||
resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
@@ -14049,6 +14249,15 @@ packages:
|
||||
isexe: 2.0.0
|
||||
dev: true
|
||||
|
||||
/why-is-node-running@2.2.2:
|
||||
resolution: {integrity: sha512-6tSwToZxTOcotxHeA+qGCq1mVzKR3CwcJGmVcY+QE8SHy6TnpFnh8PAvPNHYr7EcuVeG0QSMxtYCuO1ta/G/oA==}
|
||||
engines: {node: '>=8'}
|
||||
hasBin: true
|
||||
dependencies:
|
||||
siginfo: 2.0.0
|
||||
stackback: 0.0.2
|
||||
dev: true
|
||||
|
||||
/wordwrap@1.0.0:
|
||||
resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==}
|
||||
dev: true
|
||||
@@ -14189,6 +14398,11 @@ packages:
|
||||
engines: {node: '>=10'}
|
||||
dev: true
|
||||
|
||||
/yocto-queue@1.0.0:
|
||||
resolution: {integrity: sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==}
|
||||
engines: {node: '>=12.20'}
|
||||
dev: true
|
||||
|
||||
/z-schema@5.0.5:
|
||||
resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==}
|
||||
engines: {node: '>=8.0.0'}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { Graph } from 'services/api/types';
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
|
||||
@@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => {
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const nodes = getState().nodes.nodes;
|
||||
const templates = getState().nodeTemplates.templates;
|
||||
const { nodes, templates } = getState().nodes;
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = getState().nodeTemplates.templates;
|
||||
const nodeTemplates = getState().nodes.templates;
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
|
||||
@@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
@@ -46,7 +45,6 @@ const allReducers = {
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: nodesSlice.reducer,
|
||||
[nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer,
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import type { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import type { TypedUseSelectorHook } from 'react-redux';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
export const useAppStore = () => useStore<RootState>();
|
||||
|
||||
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
||||
@@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
@@ -23,11 +22,10 @@ const selector = createMemoizedSelector(
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectNodeTemplatesSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => {
|
||||
(controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => {
|
||||
const { initialImage, model, positivePrompt } = generation;
|
||||
|
||||
const { isConnected } = system;
|
||||
@@ -54,7 +52,7 @@ const selector = createMemoizedSelector(
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
const nodeTemplate = nodes.templates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
|
||||
@@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import {
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
nodeAdded,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import type { KeyboardEventHandler } from 'react';
|
||||
@@ -54,10 +58,10 @@ const AddNodePopover = () => {
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
||||
|
||||
const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const filteredNodeTemplates = fieldFilter
|
||||
? filter(nodeTemplates.templates, (template) => {
|
||||
? filter(nodes.templates, (template) => {
|
||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
return some(handles, (handle) => {
|
||||
@@ -67,7 +71,7 @@ const AddNodePopover = () => {
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
})
|
||||
: map(nodeTemplates.templates);
|
||||
: map(nodes.templates);
|
||||
|
||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
||||
return {
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
const defaultReturnValue = {
|
||||
isSelected: false,
|
||||
shouldAnimate: false,
|
||||
stroke: colorTokenToCssVar('base.500'),
|
||||
};
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
source: string,
|
||||
sourceHandleId: string | null | undefined,
|
||||
@@ -12,14 +19,19 @@ export const makeEdgeSelector = (
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
||||
const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined;
|
||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||
if (!sourceNode || !sourceHandleId) {
|
||||
return defaultReturnValue;
|
||||
}
|
||||
|
||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { map } from 'lodash-es';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
@@ -13,7 +12,7 @@ interface Props {
|
||||
const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' };
|
||||
|
||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const { base600 } = useChakraThemeTokens();
|
||||
|
||||
const dummyHandleStyles: CSSProperties = useMemo(
|
||||
@@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
[dummyHandleStyles]
|
||||
);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
<>
|
||||
<Handle
|
||||
type="target"
|
||||
id={`${data.id}-collapsed-target`}
|
||||
id={`${nodeId}-collapsed-target`}
|
||||
isConnectable={false}
|
||||
position={Position.Left}
|
||||
style={collapsedTargetStyles}
|
||||
/>
|
||||
{map(data.inputs, (input) => (
|
||||
{map(template.inputs, (input) => (
|
||||
<Handle
|
||||
key={`${data.id}-${input.name}-collapsed-input-handle`}
|
||||
key={`${nodeId}-${input.name}-collapsed-input-handle`}
|
||||
type="target"
|
||||
id={input.name}
|
||||
isConnectable={false}
|
||||
@@ -62,14 +61,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||
))}
|
||||
<Handle
|
||||
type="source"
|
||||
id={`${data.id}-collapsed-source`}
|
||||
id={`${nodeId}-collapsed-source`}
|
||||
isConnectable={false}
|
||||
position={Position.Right}
|
||||
style={collapsedSourceStyles}
|
||||
/>
|
||||
{map(data.outputs, (output) => (
|
||||
{map(template.outputs, (output) => (
|
||||
<Handle
|
||||
key={`${data.id}-${output.name}-collapsed-output-handle`}
|
||||
key={`${nodeId}-${output.name}-collapsed-output-handle`}
|
||||
type="source"
|
||||
id={output.name}
|
||||
isConnectable={false}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
@@ -13,7 +13,7 @@ const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
|
||||
const { id: nodeId, type, isOpen, label } = data;
|
||||
|
||||
const hasTemplateSelector = useMemo(
|
||||
() => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])),
|
||||
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
|
||||
[type]
|
||||
);
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
isMissingInput?: boolean;
|
||||
withTooltip?: boolean;
|
||||
}
|
||||
@@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" /> : undefined}
|
||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Editable
|
||||
|
||||
@@ -6,7 +6,7 @@ import { memo } from 'react';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
isMissingInput?: boolean;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { isFieldInputInstance, isFieldInputTemplate } from 'features/nodes/types/field';
|
||||
@@ -9,11 +9,11 @@ import { useTranslation } from 'react-i18next';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
}
|
||||
|
||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||
const field = useFieldInstance(nodeId, fieldName);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
||||
|
||||
@@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'input' });
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
if (!fieldTemplate) {
|
||||
@@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
<EditableFieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
kind="inputs"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
/>
|
||||
@@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
<EditableFieldTitle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
kind="input"
|
||||
kind="inputs"
|
||||
isMissingInput={isMissingInput}
|
||||
withTooltip
|
||||
/>
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Box, Text } from '@invoke-ai/ui-library';
|
||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
@@ -38,7 +37,6 @@ import {
|
||||
isVAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
@@ -63,17 +61,8 @@ type InputFieldProps = {
|
||||
};
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldInstance = useFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
|
||||
if (fieldTemplate?.fieldKind === 'output') {
|
||||
return (
|
||||
<Box p={2}>
|
||||
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
@@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (fieldInstance && fieldTemplate) {
|
||||
if (fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text fontSize="sm" fontWeight="semibold" color="error.300">
|
||||
{t('nodes.unknownFieldType', { type: fieldInstance?.type.name })}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InputFieldRenderer);
|
||||
|
||||
@@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
/>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex alignItems="center">
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" />
|
||||
<Spacer />
|
||||
{isValueChanged && (
|
||||
<IconButton
|
||||
@@ -75,7 +75,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||
/>
|
||||
)}
|
||||
<Tooltip
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" />}
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" />}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
>
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance';
|
||||
import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
@@ -18,18 +17,17 @@ interface Props {
|
||||
const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||
const fieldInstance = useFieldOutputInstance(nodeId, fieldName);
|
||||
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'output' });
|
||||
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||
|
||||
if (!fieldTemplate || !fieldInstance) {
|
||||
if (!fieldTemplate) {
|
||||
return (
|
||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||
<FormControl alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
||||
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
||||
{t('nodes.unknownOutput', {
|
||||
name: fieldTemplate?.title ?? fieldName,
|
||||
name: fieldName,
|
||||
})}
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
@@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||
return (
|
||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||
<Tooltip
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="output" />}
|
||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="outputs" />}
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
placement="top"
|
||||
shouldWrapChildren
|
||||
|
||||
@@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
||||
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import EditableNodeTitle from './details/EditableNodeTitle';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
|
||||
@@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types';
|
||||
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||
|
||||
|
||||
@@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||
|
||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
return {
|
||||
template: lastSelectedNodeTemplate,
|
||||
|
||||
@@ -1,26 +1,22 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
||||
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
}
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
|
||||
@@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
};
|
||||
|
||||
export const useBuildNode = () => {
|
||||
const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates);
|
||||
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
||||
|
||||
const flow = useReactFlow();
|
||||
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||
import { keys, map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
||||
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
// get the visible fields
|
||||
const fields = map(nodeTemplate.inputs).filter(
|
||||
const fields = map(template.inputs).filter(
|
||||
(field) =>
|
||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||
|
||||
@@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector(
|
||||
export type UseConnectionStateProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
kind: 'input' | 'output';
|
||||
kind: 'inputs' | 'outputs';
|
||||
};
|
||||
|
||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||
@@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
Boolean(
|
||||
nodes.edges.filter((edge) => {
|
||||
return (
|
||||
(kind === 'input' ? edge.target : edge.source) === nodeId &&
|
||||
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
||||
(kind === 'inputs' ? edge.target : edge.source) === nodeId &&
|
||||
(kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
||||
);
|
||||
}).length
|
||||
)
|
||||
@@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
);
|
||||
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType),
|
||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||
[nodeId, fieldName, kind, fieldType]
|
||||
);
|
||||
|
||||
@@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
||||
Boolean(
|
||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||
nodes.connectionStartParams?.handleId === fieldName &&
|
||||
nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind]
|
||||
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
|
||||
)
|
||||
),
|
||||
[fieldName, kind, nodeId]
|
||||
|
||||
@@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
export const useDoNodeVersionsMatch = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const data = selectNodeData(nodes, nodeId);
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template?.version || !data?.version) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate?.version || !node.data?.version) {
|
||||
return false;
|
||||
}
|
||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||
return compareVersions(template.version, data.version) === 0;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
const data = selectNodeData(nodes, nodeId);
|
||||
if (!data) {
|
||||
return false;
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||
return data.inputs[fieldName]?.value !== undefined;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName];
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldData = useAppSelector(selector);
|
||||
|
||||
return fieldData;
|
||||
};
|
||||
@@ -1,23 +1,20 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputInstance } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputInstance = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.inputs[fieldName];
|
||||
return selectFieldInputInstance(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
const fieldData = useAppSelector(selector);
|
||||
|
||||
return fieldTemplate;
|
||||
return fieldData;
|
||||
};
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInput } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
const fieldTemplate = nodeTemplate?.inputs[fieldName];
|
||||
return fieldTemplate?.input;
|
||||
createSelector(selectNodesSlice, (nodes): FieldInput | null => {
|
||||
const template = selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
return template?.input ?? null;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.inputs[fieldName];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputInstance } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldLabel = (nodeId: string, fieldName: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.label;
|
||||
return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null;
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldOutputInstance = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.outputs[fieldName];
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const fieldTemplate = useAppSelector(selector);
|
||||
|
||||
return fieldTemplate;
|
||||
};
|
||||
@@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => {
|
||||
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.outputs[fieldName];
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldTemplate = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
kind: 'inputs' | 'outputs'
|
||||
): FieldInputTemplate | FieldOutputTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.[KIND_MAP[kind]][fieldName];
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title;
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { KIND_MAP } from 'features/nodes/types/constants';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
|
||||
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
if (kind === 'inputs') {
|
||||
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}
|
||||
const field = node.data[KIND_MAP[kind]][fieldName];
|
||||
return field?.type;
|
||||
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
|
||||
}),
|
||||
[fieldName, kind, nodeId]
|
||||
);
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
|
||||
const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) =>
|
||||
const selector = createSelector(selectNodesSlice, (nodes) =>
|
||||
nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||
const template = nodeTemplates.templates[node.data.type];
|
||||
const template = nodes.templates[node.data.type];
|
||||
if (!template) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { some } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useHasImageOutput = (nodeId: string) => {
|
||||
export const useHasImageOutput = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
return some(
|
||||
node.data.outputs,
|
||||
template?.outputs,
|
||||
(output) =>
|
||||
output.type.name === 'ImageField' &&
|
||||
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
|
||||
node.data.type !== 'image'
|
||||
template?.type !== 'image'
|
||||
);
|
||||
}),
|
||||
[nodeId]
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useIsIntermediate = (nodeId: string) => {
|
||||
export const useIsIntermediate = (nodeId: string): boolean => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.isIntermediate;
|
||||
return selectNodeData(nodes, nodeId)?.isIntermediate ?? false;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
// TODO: enable this at some point
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useCallback } from 'react';
|
||||
import type { Connection, Node } from 'reactflow';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
@@ -13,39 +12,34 @@ import { useReactFlow } from 'reactflow';
|
||||
*/
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
const store = useAppStore();
|
||||
const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
const nodes = flow.getNodes();
|
||||
// Connection must have valid targets
|
||||
if (!(source && sourceHandle && target && targetHandle)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = flow.getNode(source) as Node<InvocationNodeData>;
|
||||
const targetNode = flow.getNode(target) as Node<InvocationNodeData>;
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sourceField = sourceNode.data.outputs[sourceHandle];
|
||||
const targetField = targetNode.data.inputs[targetHandle];
|
||||
|
||||
if (!sourceField || !targetField) {
|
||||
// something has gone terribly awry
|
||||
return false;
|
||||
}
|
||||
|
||||
if (source === target) {
|
||||
// Don't allow nodes to connect to themselves, even if validation is disabled
|
||||
return false;
|
||||
}
|
||||
|
||||
const state = store.getState();
|
||||
const { nodes, edges, templates } = state.nodes;
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
|
||||
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
|
||||
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
|
||||
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceFieldTemplate && targetFieldTemplate)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shouldValidateGraph) {
|
||||
// manual override!
|
||||
return true;
|
||||
@@ -69,20 +63,20 @@ export const useIsValidConnection = () => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetField.type.name !== 'CollectionItemField'
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must use the originalType here if it exists
|
||||
if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
|
||||
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
[flow, shouldValidateGraph]
|
||||
[shouldValidateGraph, store]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import type { Classification } from 'features/nodes/types/common';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeClassification = (nodeId: string) => {
|
||||
export const useNodeClassification = (nodeId: string): Classification | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate?.classification;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeData = (nodeId: string) => {
|
||||
export const useNodeData = (nodeId: string): InvocationNodeData | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
return node?.data;
|
||||
return selectNodeData(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeLabel = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node.data.label;
|
||||
return selectNodeData(nodes, nodeId)?.label ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeNeedsUpdate = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const template = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
if (isInvocationNode(node) && template) {
|
||||
return getNeedsUpdate(node, template);
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
const node = selectInvocationNode(nodes, nodeId);
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!node || !template) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
return getNeedsUpdate(node, template);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodePack = (nodeId: string) => {
|
||||
export const useNodePack = (nodeId: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.nodePack;
|
||||
return selectNodeData(nodes, nodeId)?.nodePack ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplate = (nodeId: string) => {
|
||||
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
||||
return nodeTemplate;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateByType = (type: string) => {
|
||||
export const useNodeTemplateByType = (type: string): InvocationTemplate | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => {
|
||||
return nodeTemplates.templates[type];
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return nodes.templates[type] ?? null;
|
||||
}),
|
||||
[type]
|
||||
);
|
||||
|
||||
@@ -1,21 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useNodeTemplateTitle = (nodeId: string) => {
|
||||
export const useNodeTemplateTitle = (nodeId: string): string | null => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined;
|
||||
|
||||
return nodeTemplate?.title;
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||
import { map } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
@@ -10,17 +10,13 @@ import { useMemo } from 'react';
|
||||
export const useOutputFieldNames = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return [];
|
||||
}
|
||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
return [];
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const template = selectNodeTemplate(nodes, nodeId);
|
||||
if (!template) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return getSortedFilteredFieldNames(map(nodeTemplate.outputs));
|
||||
return getSortedFilteredFieldNames(map(template.outputs));
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useUseCache = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
return node.data.useCache;
|
||||
return selectNodeData(nodes, nodeId)?.useCache ?? false;
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
@@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import { useEffect } from 'react';
|
||||
|
||||
export const $builtWorkflow = atom<WorkflowV2 | null>(null);
|
||||
export const $builtWorkflow = atom<WorkflowV3 | null>(null);
|
||||
|
||||
const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => {
|
||||
$builtWorkflow.set(buildWorkflowFast(arg));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { Graph } from 'services/api/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
|
||||
@@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{
|
||||
|
||||
export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested');
|
||||
|
||||
export const workflowLoaded = createAction<WorkflowV2>('workflow/workflowLoaded');
|
||||
export const workflowLoaded = createAction<WorkflowV3>('workflow/workflowLoaded');
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
|
||||
import type { NodeTemplatesState } from './types';
|
||||
|
||||
export const initialNodeTemplatesState: NodeTemplatesState = {
|
||||
templates: {},
|
||||
};
|
||||
|
||||
export const nodesTemplatesSlice = createSlice({
|
||||
name: 'nodeTemplates',
|
||||
initialState: initialNodeTemplatesState,
|
||||
reducers: {
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions;
|
||||
|
||||
export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates;
|
||||
@@ -42,7 +42,7 @@ import {
|
||||
zT2IAdapterModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import type {
|
||||
@@ -92,6 +92,7 @@ export const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
templates: {},
|
||||
connectionStartParams: null,
|
||||
connectionStartFieldType: null,
|
||||
connectionMade: false,
|
||||
@@ -190,6 +191,7 @@ export const nodesSlice = createSlice({
|
||||
node,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
@@ -224,12 +226,12 @@ export const nodesSlice = createSlice({
|
||||
if (!nodeId || !handleId) {
|
||||
return;
|
||||
}
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId];
|
||||
const template = state.templates[node.data.type];
|
||||
const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
|
||||
state.connectionStartFieldType = field?.type ?? null;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
@@ -260,6 +262,7 @@ export const nodesSlice = createSlice({
|
||||
mouseOverNode,
|
||||
state.nodes,
|
||||
state.edges,
|
||||
state.templates,
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
@@ -677,6 +680,9 @@ export const nodesSlice = createSlice({
|
||||
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
|
||||
},
|
||||
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
|
||||
state.templates = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
@@ -808,6 +814,7 @@ export const {
|
||||
shouldValidateGraphChanged,
|
||||
viewportChanged,
|
||||
edgeAdded,
|
||||
nodeTemplatesBuilt,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
// This is used for tracking `state.workflow.isTouched`
|
||||
|
||||
51
invokeai/frontend/web/src/features/nodes/store/selectors.ts
Normal file
51
invokeai/frontend/web/src/features/nodes/store/selectors.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
|
||||
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
|
||||
};
|
||||
|
||||
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
return nodesSlice.templates[node.data.type] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldInputInstance = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldInputInstance | null => {
|
||||
const data = selectNodeData(nodesSlice, nodeId);
|
||||
return data?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldInputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldInputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.inputs[fieldName] ?? null;
|
||||
};
|
||||
|
||||
export const selectFieldOutputTemplate = (
|
||||
nodesSlice: NodesState,
|
||||
nodeId: string,
|
||||
fieldName: string
|
||||
): FieldOutputTemplate | null => {
|
||||
const template = selectNodeTemplate(nodesSlice, nodeId);
|
||||
return template?.outputs[fieldName] ?? null;
|
||||
};
|
||||
@@ -5,13 +5,14 @@ import type {
|
||||
InvocationTemplate,
|
||||
NodeExecutionState,
|
||||
} from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
|
||||
|
||||
export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: InvocationNodeEdge[];
|
||||
templates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
connectionStartFieldType: FieldType | null;
|
||||
connectionMade: boolean;
|
||||
@@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & {
|
||||
value: StatefulFieldValue;
|
||||
};
|
||||
|
||||
export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & {
|
||||
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
|
||||
_version: 1;
|
||||
isTouched: boolean;
|
||||
mode: WorkflowMode;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection, Edge, HandleType, Node } from 'reactflow';
|
||||
|
||||
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
|
||||
@@ -9,7 +11,7 @@ const isValidConnection = (
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType,
|
||||
node: Node,
|
||||
handle: FieldInputInstance | FieldOutputInstance
|
||||
handle: FieldInputTemplate | FieldOutputTemplate
|
||||
) => {
|
||||
let isValidConnection = true;
|
||||
if (handleCurrentType === 'source') {
|
||||
@@ -38,24 +40,31 @@ const isValidConnection = (
|
||||
};
|
||||
|
||||
export const findConnectionToValidHandle = (
|
||||
node: Node,
|
||||
nodes: Node[],
|
||||
edges: Edge[],
|
||||
node: AnyNode,
|
||||
nodes: AnyNode[],
|
||||
edges: InvocationNodeEdge[],
|
||||
templates: Record<string, InvocationTemplate>,
|
||||
handleCurrentNodeId: string,
|
||||
handleCurrentName: string,
|
||||
handleCurrentType: HandleType,
|
||||
handleCurrentFieldType: FieldType
|
||||
): Connection | null => {
|
||||
if (node.id === handleCurrentNodeId) {
|
||||
if (node.id === handleCurrentNodeId || !isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs;
|
||||
const template = templates[node.data.type];
|
||||
|
||||
if (!template) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handles = handleCurrentType === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
//Prioritize handles whos name matches the node we're coming from
|
||||
if (handles[handleCurrentName]) {
|
||||
const handle = handles[handleCurrentName];
|
||||
const handle = handles[handleCurrentName];
|
||||
|
||||
if (handle) {
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name;
|
||||
@@ -77,6 +86,9 @@ export const findConnectionToValidHandle = (
|
||||
|
||||
for (const handleName in handles) {
|
||||
const handle = handles[handleName];
|
||||
if (!handle) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id;
|
||||
const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId;
|
||||
|
||||
@@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
handleType: HandleType,
|
||||
fieldType?: FieldType
|
||||
fieldType?: FieldType | null
|
||||
) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
if (!fieldType) {
|
||||
|
||||
@@ -10,10 +10,10 @@ import type {
|
||||
} from 'features/nodes/store/types';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';
|
||||
|
||||
export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
|
||||
export const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
@@ -22,7 +22,7 @@ export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = {
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
meta: { version: '2.0.0', category: 'user' },
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
id: undefined,
|
||||
};
|
||||
|
||||
|
||||
@@ -56,3 +56,8 @@ export class FieldParseError extends Error {
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
export class UnableToExtractSchemaNameFromRefError extends FieldParseError {}
|
||||
export class UnsupportedArrayItemType extends FieldParseError {}
|
||||
export class UnsupportedUnionError extends FieldParseError {}
|
||||
export class UnsupportedPrimitiveTypeError extends FieldParseError {}
|
||||
@@ -46,20 +46,11 @@ export type FieldInput = z.infer<typeof zFieldInput>;
|
||||
export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']);
|
||||
export type FieldUIComponent = z.infer<typeof zFieldUIComponent>;
|
||||
|
||||
export const zFieldInstanceBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
export const zFieldInputInstanceBase = z.object({
|
||||
name: z.string().trim().min(1),
|
||||
});
|
||||
export const zFieldInputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
label: z.string().nullish(),
|
||||
});
|
||||
export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldInstanceBase = z.infer<typeof zFieldInstanceBase>;
|
||||
export type FieldInputInstanceBase = z.infer<typeof zFieldInputInstanceBase>;
|
||||
export type FieldOutputInstanceBase = z.infer<typeof zFieldOutputInstanceBase>;
|
||||
|
||||
export const zFieldTemplateBase = z.object({
|
||||
name: z.string().min(1),
|
||||
@@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
default: zIntegerFieldValue,
|
||||
@@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zFloatFieldValue = z.number();
|
||||
export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
default: zFloatFieldValue,
|
||||
@@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type FloatFieldType = z.infer<typeof zFloatFieldType>;
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldOutputInstance = z.infer<typeof zFloatFieldOutputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export type FloatFieldOutputTemplate = z.infer<typeof zFloatFieldOutputTemplate>;
|
||||
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
|
||||
@@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zStringFieldValue = z.string();
|
||||
export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
default: zStringFieldValue,
|
||||
@@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type StringFieldType = z.infer<typeof zStringFieldType>;
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldOutputInstance = z.infer<typeof zStringFieldOutputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export type StringFieldOutputTemplate = z.infer<typeof zStringFieldOutputTemplate>;
|
||||
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
|
||||
@@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
default: zBooleanFieldValue,
|
||||
@@ -222,7 +195,6 @@ export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BooleanFieldType = z.infer<typeof zBooleanFieldType>;
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
export type BooleanFieldOutputInstance = z.infer<typeof zBooleanFieldOutputInstance>;
|
||||
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
|
||||
export type BooleanFieldOutputTemplate = z.infer<typeof zBooleanFieldOutputTemplate>;
|
||||
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
|
||||
@@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zEnumFieldValue = z.string();
|
||||
export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
default: zEnumFieldValue,
|
||||
@@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type EnumFieldType = z.infer<typeof zEnumFieldType>;
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
export type EnumFieldOutputInstance = z.infer<typeof zEnumFieldOutputInstance>;
|
||||
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
|
||||
export type EnumFieldOutputTemplate = z.infer<typeof zEnumFieldOutputTemplate>;
|
||||
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
|
||||
@@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
default: zImageFieldValue,
|
||||
@@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ImageFieldType = z.infer<typeof zImageFieldType>;
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldOutputInstance = z.infer<typeof zImageFieldOutputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export type ImageFieldOutputTemplate = z.infer<typeof zImageFieldOutputTemplate>;
|
||||
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
|
||||
@@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
default: zBoardFieldValue,
|
||||
@@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BoardFieldType = z.infer<typeof zBoardFieldType>;
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
export type BoardFieldOutputInstance = z.infer<typeof zBoardFieldOutputInstance>;
|
||||
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
|
||||
export type BoardFieldOutputTemplate = z.infer<typeof zBoardFieldOutputTemplate>;
|
||||
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
|
||||
@@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
default: zColorFieldValue,
|
||||
@@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ColorFieldType = z.infer<typeof zColorFieldType>;
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
export type ColorFieldOutputInstance = z.infer<typeof zColorFieldOutputInstance>;
|
||||
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
|
||||
export type ColorFieldOutputTemplate = z.infer<typeof zColorFieldOutputTemplate>;
|
||||
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
|
||||
@@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zMainModelFieldValue = zMainModelField.optional();
|
||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
default: zMainModelFieldValue,
|
||||
@@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type MainModelFieldType = z.infer<typeof zMainModelFieldType>;
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldOutputInstance = z.infer<typeof zMainModelFieldOutputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export type MainModelFieldOutputTemplate = z.infer<typeof zMainModelFieldOutputTemplate>;
|
||||
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
|
||||
@@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
default: zSDXLMainModelFieldValue,
|
||||
@@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend
|
||||
export type SDXLMainModelFieldType = z.infer<typeof zSDXLMainModelFieldType>;
|
||||
export type SDXLMainModelFieldValue = z.infer<typeof zSDXLMainModelFieldValue>;
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldOutputInstance = z.infer<typeof zSDXLMainModelFieldOutputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export type SDXLMainModelFieldOutputTemplate = z.infer<typeof zSDXLMainModelFieldOutputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
|
||||
@@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
value: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
@@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext
|
||||
export type SDXLRefinerModelFieldType = z.infer<typeof zSDXLRefinerModelFieldType>;
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldOutputInstance = z.infer<typeof zSDXLRefinerModelFieldOutputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export type SDXLRefinerModelFieldOutputTemplate = z.infer<typeof zSDXLRefinerModelFieldOutputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
|
||||
@@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zVAEModelFieldValue = zVAEModelField.optional();
|
||||
export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
default: zVAEModelFieldValue,
|
||||
@@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type VAEModelFieldType = z.infer<typeof zVAEModelFieldType>;
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldOutputInstance = z.infer<typeof zVAEModelFieldOutputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export type VAEModelFieldOutputTemplate = z.infer<typeof zVAEModelFieldOutputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
|
||||
@@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zLoRAModelFieldValue = zLoRAModelField.optional();
|
||||
export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
default: zLoRAModelFieldValue,
|
||||
@@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type LoRAModelFieldType = z.infer<typeof zLoRAModelFieldType>;
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldOutputInstance = z.infer<typeof zLoRAModelFieldOutputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export type LoRAModelFieldOutputTemplate = z.infer<typeof zLoRAModelFieldOutputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
|
||||
@@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zControlNetModelFieldValue = zControlNetModelField.optional();
|
||||
export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
default: zControlNetModelFieldValue,
|
||||
@@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte
|
||||
export type ControlNetModelFieldType = z.infer<typeof zControlNetModelFieldType>;
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldOutputInstance = z.infer<typeof zControlNetModelFieldOutputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export type ControlNetModelFieldOutputTemplate = z.infer<typeof zControlNetModelFieldOutputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
|
||||
@@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional();
|
||||
export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
default: zIPAdapterModelFieldValue,
|
||||
@@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten
|
||||
export type IPAdapterModelFieldType = z.infer<typeof zIPAdapterModelFieldType>;
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldOutputInstance = z.infer<typeof zIPAdapterModelFieldOutputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export type IPAdapterModelFieldOutputTemplate = z.infer<typeof zIPAdapterModelFieldOutputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
|
||||
@@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional();
|
||||
export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
@@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte
|
||||
export type T2IAdapterModelFieldType = z.infer<typeof zT2IAdapterModelFieldType>;
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldOutputInstance = z.infer<typeof zT2IAdapterModelFieldOutputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export type T2IAdapterModelFieldOutputTemplate = z.infer<typeof zT2IAdapterModelFieldOutputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
|
||||
@@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
default: zSchedulerFieldValue,
|
||||
@@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
|
||||
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
|
||||
export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
|
||||
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
|
||||
@@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
});
|
||||
export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
default: zStatelessFieldValue,
|
||||
@@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
|
||||
export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
|
||||
export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
|
||||
export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
|
||||
// #endregion
|
||||
@@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
|
||||
zFieldInputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputInstance & FieldOutputInstance
|
||||
export const zStatefulFieldOutputInstance = z.union([
|
||||
zIntegerFieldOutputInstance,
|
||||
zFloatFieldOutputInstance,
|
||||
zStringFieldOutputInstance,
|
||||
zBooleanFieldOutputInstance,
|
||||
zEnumFieldOutputInstance,
|
||||
zImageFieldOutputInstance,
|
||||
zBoardFieldOutputInstance,
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
zControlNetModelFieldOutputInstance,
|
||||
zIPAdapterModelFieldOutputInstance,
|
||||
zT2IAdapterModelFieldOutputInstance,
|
||||
zColorFieldOutputInstance,
|
||||
zSchedulerFieldOutputInstance,
|
||||
]);
|
||||
export type StatefulFieldOutputInstance = z.infer<typeof zStatefulFieldOutputInstance>;
|
||||
export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance =>
|
||||
zStatefulFieldOutputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]);
|
||||
export type FieldOutputInstance = z.infer<typeof zFieldOutputInstance>;
|
||||
export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance =>
|
||||
zFieldOutputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputTemplate & FieldInputTemplate
|
||||
export const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
// #region InvocationTemplate
|
||||
@@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||
// #region NodeData
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.string().trim().min(1),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
version: zSemVer,
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
label: z.string(),
|
||||
notes: z.string(),
|
||||
type: z.string().trim().min(1),
|
||||
inputs: z.record(zFieldInputInstance),
|
||||
outputs: z.record(zFieldOutputInstance),
|
||||
isOpen: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
});
|
||||
|
||||
export const zNotesNodeData = z.object({
|
||||
@@ -62,11 +61,12 @@ export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode =>
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData =>
|
||||
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
|
||||
188
invokeai/frontend/web/src/features/nodes/types/v2/common.ts
Normal file
188
invokeai/frontend/web/src/features/nodes/types/v2/common.ts
Normal file
@@ -0,0 +1,188 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// #region Field data schemas
|
||||
export const zImageField = z.object({
|
||||
image_name: z.string().trim().min(1),
|
||||
});
|
||||
export type ImageField = z.infer<typeof zImageField>;
|
||||
|
||||
export const zBoardField = z.object({
|
||||
board_id: z.string().trim().min(1),
|
||||
});
|
||||
export type BoardField = z.infer<typeof zBoardField>;
|
||||
|
||||
export const zColorField = z.object({
|
||||
r: z.number().int().min(0).max(255),
|
||||
g: z.number().int().min(0).max(255),
|
||||
b: z.number().int().min(0).max(255),
|
||||
a: z.number().int().min(0).max(255),
|
||||
});
|
||||
export type ColorField = z.infer<typeof zColorField>;
|
||||
|
||||
export const zClassification = z.enum(['stable', 'beta', 'prototype']);
|
||||
export type Classification = z.infer<typeof zClassification>;
|
||||
|
||||
export const zSchedulerField = z.enum([
|
||||
'euler',
|
||||
'deis',
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
'pndm',
|
||||
'unipc',
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
]);
|
||||
export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']);
|
||||
export const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
export type ModelType = z.infer<typeof zModelType>;
|
||||
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
||||
|
||||
export const zMainModelField = z.object({
|
||||
model_name: zModelName,
|
||||
base_model: zBaseModel,
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export const zSDXLRefinerModelField = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: z.literal('sdxl-refiner'),
|
||||
model_type: z.literal('main'),
|
||||
});
|
||||
export type MainModelField = z.infer<typeof zMainModelField>;
|
||||
export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;
|
||||
|
||||
export const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
export type SubModelType = z.infer<typeof zSubModelType>;
|
||||
|
||||
export const zVAEModelField = zModelIdentifier;
|
||||
|
||||
export const zModelInfo = zModelIdentifier.extend({
|
||||
model_type: zModelType,
|
||||
submodel: zSubModelType.optional(),
|
||||
});
|
||||
export type ModelInfo = z.infer<typeof zModelInfo>;
|
||||
|
||||
export const zLoRAModelField = zModelIdentifier;
|
||||
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
|
||||
|
||||
export const zControlNetModelField = zModelIdentifier;
|
||||
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
|
||||
|
||||
export const zIPAdapterModelField = zModelIdentifier;
|
||||
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||
|
||||
export const zT2IAdapterModelField = zModelIdentifier;
|
||||
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||
|
||||
export const zLoraInfo = zModelInfo.extend({
|
||||
weight: z.number().optional(),
|
||||
});
|
||||
export type LoraInfo = z.infer<typeof zLoraInfo>;
|
||||
|
||||
export const zUNetField = z.object({
|
||||
unet: zModelInfo,
|
||||
scheduler: zModelInfo,
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type UNetField = z.infer<typeof zUNetField>;
|
||||
|
||||
export const zCLIPField = z.object({
|
||||
tokenizer: zModelInfo,
|
||||
text_encoder: zModelInfo,
|
||||
skipped_layers: z.number(),
|
||||
loras: z.array(zLoraInfo),
|
||||
});
|
||||
export type CLIPField = z.infer<typeof zCLIPField>;
|
||||
|
||||
export const zVAEField = z.object({
|
||||
vae: zModelInfo,
|
||||
});
|
||||
export type VAEField = z.infer<typeof zVAEField>;
|
||||
// #endregion
|
||||
|
||||
// #region Control Adapters
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zControlNetModelField,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(),
|
||||
resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(),
|
||||
});
|
||||
export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zIPAdapterModelField,
|
||||
weight: z.number(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
});
|
||||
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zT2IAdapterModelField,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(),
|
||||
});
|
||||
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
|
||||
// #endregion
|
||||
|
||||
// #region ProgressImage
|
||||
export const zProgressImage = z.object({
|
||||
dataURL: z.string(),
|
||||
width: z.number().int(),
|
||||
height: z.number().int(),
|
||||
});
|
||||
export type ProgressImage = z.infer<typeof zProgressImage>;
|
||||
// #endregion
|
||||
|
||||
// #region ImageOutput
|
||||
export const zImageOutput = z.object({
|
||||
image: zImageField,
|
||||
width: z.number().int().gt(0),
|
||||
height: z.number().int().gt(0),
|
||||
type: z.literal('image_output'),
|
||||
});
|
||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
|
||||
// #endregion
|
||||
@@ -0,0 +1,80 @@
|
||||
import type { Node } from 'reactflow';
|
||||
|
||||
/**
|
||||
* How long to wait before showing a tooltip when hovering a field handle.
|
||||
*/
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
|
||||
/**
|
||||
* The width of a node in the UI in pixels.
|
||||
*/
|
||||
export const NODE_WIDTH = 320;
|
||||
|
||||
/**
|
||||
* This class name is special - reactflow uses it to identify the drag handle of a node,
|
||||
* applying the appropriate listeners to it.
|
||||
*/
|
||||
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||
|
||||
/**
|
||||
* reactflow-specifc properties shared between all node types.
|
||||
*/
|
||||
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper for getting the kind of a field.
|
||||
*/
|
||||
export const KIND_MAP = {
|
||||
input: 'inputs' as const,
|
||||
output: 'outputs' as const,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model types' handles are rendered as squares in the UI.
|
||||
*/
|
||||
export const MODEL_TYPES = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
'T2IAdapterModelField',
|
||||
'IPAdapterModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
* Colors for each field type - applies to their handles and edges.
|
||||
*/
|
||||
export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
ClipField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
ControlNetModelField: 'teal.500',
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
LatentsField: 'pink.500',
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
UNetField: 'red.500',
|
||||
VaeField: 'blue.500',
|
||||
VaeModelField: 'teal.500',
|
||||
};
|
||||
58
invokeai/frontend/web/src/features/nodes/types/v2/error.ts
Normal file
58
invokeai/frontend/web/src/features/nodes/types/v2/error.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Invalid Workflow Version Error
|
||||
* Raised when a workflow version is not recognized.
|
||||
*/
|
||||
export class WorkflowVersionError extends Error {
|
||||
/**
|
||||
* Create WorkflowVersionError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Workflow Migration Error
|
||||
* Raised when a workflow migration fails.
|
||||
*/
|
||||
export class WorkflowMigrationError extends Error {
|
||||
/**
|
||||
* Create WorkflowMigrationError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unable to Update Node Error
|
||||
* Raised when a node cannot be updated.
|
||||
*/
|
||||
export class NodeUpdateError extends Error {
|
||||
/**
|
||||
* Create NodeUpdateError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldParseError
|
||||
* Raised when a field cannot be parsed from a field schema.
|
||||
*/
|
||||
export class FieldParseError extends Error {
|
||||
/**
|
||||
* Create FieldTypeParseError
|
||||
* @param {String} message
|
||||
*/
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
875
invokeai/frontend/web/src/features/nodes/types/v2/field.ts
Normal file
875
invokeai/frontend/web/src/features/nodes/types/v2/field.ts
Normal file
@@ -0,0 +1,875 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
zBoardField,
|
||||
zColorField,
|
||||
zControlNetModelField,
|
||||
zImageField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zSchedulerField,
|
||||
zT2IAdapterModelField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
|
||||
/**
|
||||
* zod schemas & inferred types for fields.
|
||||
*
|
||||
* These schemas and types are only required for stateful field - fields that have UI components
|
||||
* and allow the user to directly provide values.
|
||||
*
|
||||
* This includes primitive values (numbers, strings, booleans), models, scheduler, etc.
|
||||
*
|
||||
* If a field type does not have a UI component, then it does not need to be included here, because
|
||||
* we never store its value. Such field types will be handled via the "StatelessField" logic.
|
||||
*
|
||||
* Fields require:
|
||||
* - z<TypeName>FieldType - zod schema for the field type
|
||||
* - z<TypeName>FieldValue - zod schema for the field value
|
||||
* - z<TypeName>FieldInputInstance - zod schema for the field's input instance
|
||||
* - z<TypeName>FieldOutputInstance - zod schema for the field's output instance
|
||||
* - z<TypeName>FieldInputTemplate - zod schema for the field's input template
|
||||
* - z<TypeName>FieldOutputTemplate - zod schema for the field's output template
|
||||
* - inferred types for each schema
|
||||
* - type guards for InputInstance and InputTemplate
|
||||
*
|
||||
* These then must be added to the unions at the bottom of this file.
|
||||
*/
|
||||
|
||||
/** */
|
||||
|
||||
// #region Base schemas & misc
|
||||
export const zFieldInput = z.enum(['connection', 'direct', 'any']);
|
||||
export type FieldInput = z.infer<typeof zFieldInput>;
|
||||
|
||||
export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']);
|
||||
export type FieldUIComponent = z.infer<typeof zFieldUIComponent>;
|
||||
|
||||
export const zFieldInstanceBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
name: z.string().trim().min(1),
|
||||
});
|
||||
export const zFieldInputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
label: z.string().nullish(),
|
||||
});
|
||||
export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldInstanceBase = z.infer<typeof zFieldInstanceBase>;
|
||||
export type FieldInputInstanceBase = z.infer<typeof zFieldInputInstanceBase>;
|
||||
export type FieldOutputInstanceBase = z.infer<typeof zFieldOutputInstanceBase>;
|
||||
|
||||
export const zFieldTemplateBase = z.object({
|
||||
name: z.string().min(1),
|
||||
title: z.string().min(1),
|
||||
description: z.string().nullish(),
|
||||
ui_hidden: z.boolean(),
|
||||
ui_type: z.string().nullish(),
|
||||
ui_order: z.number().int().nullish(),
|
||||
});
|
||||
export const zFieldInputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('input'),
|
||||
input: zFieldInput,
|
||||
required: z.boolean(),
|
||||
ui_component: zFieldUIComponent.nullish(),
|
||||
ui_choice_labels: z.record(z.string()).nullish(),
|
||||
});
|
||||
export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({
|
||||
fieldKind: z.literal('output'),
|
||||
});
|
||||
export type FieldTemplateBase = z.infer<typeof zFieldTemplateBase>;
|
||||
export type FieldInputTemplateBase = z.infer<typeof zFieldInputTemplateBase>;
|
||||
export type FieldOutputTemplateBase = z.infer<typeof zFieldOutputTemplateBase>;
|
||||
|
||||
export const zFieldTypeBase = z.object({
|
||||
isCollection: z.boolean(),
|
||||
isCollectionOrScalar: z.boolean(),
|
||||
});
|
||||
|
||||
export const zFieldIdentifier = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
fieldName: z.string().trim().min(1),
|
||||
});
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField
|
||||
export const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerField'),
|
||||
});
|
||||
export const zIntegerFieldValue = z.number().int();
|
||||
export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
value: zIntegerFieldValue,
|
||||
});
|
||||
export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
default: zIntegerFieldValue,
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().int().optional(),
|
||||
exclusiveMaximum: z.number().int().optional(),
|
||||
minimum: z.number().int().optional(),
|
||||
exclusiveMinimum: z.number().int().optional(),
|
||||
});
|
||||
export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIntegerFieldType,
|
||||
});
|
||||
export type IntegerFieldType = z.infer<typeof zIntegerFieldType>;
|
||||
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
|
||||
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
|
||||
zIntegerFieldInputInstance.safeParse(val).success;
|
||||
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
|
||||
zIntegerFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region FloatField
|
||||
export const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
});
|
||||
export const zFloatFieldValue = z.number();
|
||||
export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
value: zFloatFieldValue,
|
||||
});
|
||||
export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
default: zFloatFieldValue,
|
||||
multipleOf: z.number().optional(),
|
||||
maximum: z.number().optional(),
|
||||
exclusiveMaximum: z.number().optional(),
|
||||
minimum: z.number().optional(),
|
||||
exclusiveMinimum: z.number().optional(),
|
||||
});
|
||||
export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatFieldType,
|
||||
});
|
||||
export type FloatFieldType = z.infer<typeof zFloatFieldType>;
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldOutputInstance = z.infer<typeof zFloatFieldOutputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export type FloatFieldOutputTemplate = z.infer<typeof zFloatFieldOutputTemplate>;
|
||||
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
|
||||
zFloatFieldInputInstance.safeParse(val).success;
|
||||
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
|
||||
zFloatFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
export const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
});
|
||||
export const zStringFieldValue = z.string();
|
||||
export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
value: zStringFieldValue,
|
||||
});
|
||||
export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
default: zStringFieldValue,
|
||||
maxLength: z.number().int().optional(),
|
||||
minLength: z.number().int().optional(),
|
||||
});
|
||||
export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
|
||||
export type StringFieldType = z.infer<typeof zStringFieldType>;
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldOutputInstance = z.infer<typeof zStringFieldOutputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export type StringFieldOutputTemplate = z.infer<typeof zStringFieldOutputTemplate>;
|
||||
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
|
||||
zStringFieldInputInstance.safeParse(val).success;
|
||||
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
|
||||
zStringFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region BooleanField
|
||||
export const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
});
|
||||
export const zBooleanFieldValue = z.boolean();
|
||||
export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
value: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
default: zBooleanFieldValue,
|
||||
});
|
||||
export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBooleanFieldType,
|
||||
});
|
||||
export type BooleanFieldType = z.infer<typeof zBooleanFieldType>;
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
export type BooleanFieldOutputInstance = z.infer<typeof zBooleanFieldOutputInstance>;
|
||||
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
|
||||
export type BooleanFieldOutputTemplate = z.infer<typeof zBooleanFieldOutputTemplate>;
|
||||
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
|
||||
zBooleanFieldInputInstance.safeParse(val).success;
|
||||
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
|
||||
zBooleanFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region EnumField
|
||||
export const zEnumFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('EnumField'),
|
||||
});
|
||||
export const zEnumFieldValue = z.string();
|
||||
export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
value: zEnumFieldValue,
|
||||
});
|
||||
export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
default: zEnumFieldValue,
|
||||
options: z.array(z.string()),
|
||||
labels: z.record(z.string()).optional(),
|
||||
});
|
||||
export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zEnumFieldType,
|
||||
});
|
||||
export type EnumFieldType = z.infer<typeof zEnumFieldType>;
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
export type EnumFieldOutputInstance = z.infer<typeof zEnumFieldOutputInstance>;
|
||||
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
|
||||
export type EnumFieldOutputTemplate = z.infer<typeof zEnumFieldOutputTemplate>;
|
||||
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
|
||||
zEnumFieldInputInstance.safeParse(val).success;
|
||||
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
|
||||
zEnumFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
export const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
});
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
value: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
default: zImageFieldValue,
|
||||
});
|
||||
export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageFieldType,
|
||||
});
|
||||
export type ImageFieldType = z.infer<typeof zImageFieldType>;
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldOutputInstance = z.infer<typeof zImageFieldOutputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export type ImageFieldOutputTemplate = z.infer<typeof zImageFieldOutputTemplate>;
|
||||
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
|
||||
zImageFieldInputInstance.safeParse(val).success;
|
||||
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
export const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
});
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
value: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
default: zBoardFieldValue,
|
||||
});
|
||||
export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBoardFieldType,
|
||||
});
|
||||
export type BoardFieldType = z.infer<typeof zBoardFieldType>;
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
export type BoardFieldOutputInstance = z.infer<typeof zBoardFieldOutputInstance>;
|
||||
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
|
||||
export type BoardFieldOutputTemplate = z.infer<typeof zBoardFieldOutputTemplate>;
|
||||
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
|
||||
zBoardFieldInputInstance.safeParse(val).success;
|
||||
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
|
||||
zBoardFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
export const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
});
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
value: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
default: zColorFieldValue,
|
||||
});
|
||||
export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zColorFieldType,
|
||||
});
|
||||
export type ColorFieldType = z.infer<typeof zColorFieldType>;
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
export type ColorFieldOutputInstance = z.infer<typeof zColorFieldOutputInstance>;
|
||||
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
|
||||
export type ColorFieldOutputTemplate = z.infer<typeof zColorFieldOutputTemplate>;
|
||||
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
|
||||
zColorFieldInputInstance.safeParse(val).success;
|
||||
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
|
||||
zColorFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
export const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
export const zMainModelFieldValue = zMainModelField.optional();
|
||||
export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
default: zMainModelFieldValue,
|
||||
});
|
||||
export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zMainModelFieldType,
|
||||
});
|
||||
export type MainModelFieldType = z.infer<typeof zMainModelFieldType>;
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldOutputInstance = z.infer<typeof zMainModelFieldOutputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export type MainModelFieldOutputTemplate = z.infer<typeof zMainModelFieldOutputTemplate>;
|
||||
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
|
||||
zMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
|
||||
zMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
export const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLMainModelField'),
|
||||
});
|
||||
export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
value: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
default: zSDXLMainModelFieldValue,
|
||||
});
|
||||
export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLMainModelFieldType,
|
||||
});
|
||||
export type SDXLMainModelFieldType = z.infer<typeof zSDXLMainModelFieldType>;
|
||||
export type SDXLMainModelFieldValue = z.infer<typeof zSDXLMainModelFieldValue>;
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldOutputInstance = z.infer<typeof zSDXLMainModelFieldOutputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export type SDXLMainModelFieldOutputTemplate = z.infer<typeof zSDXLMainModelFieldOutputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
|
||||
zSDXLMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
|
||||
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
});
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
value: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
default: zSDXLRefinerModelFieldValue,
|
||||
});
|
||||
export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSDXLRefinerModelFieldType,
|
||||
});
|
||||
export type SDXLRefinerModelFieldType = z.infer<typeof zSDXLRefinerModelFieldType>;
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldOutputInstance = z.infer<typeof zSDXLRefinerModelFieldOutputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export type SDXLRefinerModelFieldOutputTemplate = z.infer<typeof zSDXLRefinerModelFieldOutputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
|
||||
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
|
||||
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
export const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
});
|
||||
export const zVAEModelFieldValue = zVAEModelField.optional();
|
||||
export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
default: zVAEModelFieldValue,
|
||||
});
|
||||
export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zVAEModelFieldType,
|
||||
});
|
||||
export type VAEModelFieldType = z.infer<typeof zVAEModelFieldType>;
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldOutputInstance = z.infer<typeof zVAEModelFieldOutputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export type VAEModelFieldOutputTemplate = z.infer<typeof zVAEModelFieldOutputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
|
||||
zVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
|
||||
zVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
export const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
});
|
||||
export const zLoRAModelFieldValue = zLoRAModelField.optional();
|
||||
export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
default: zLoRAModelFieldValue,
|
||||
});
|
||||
export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLoRAModelFieldType,
|
||||
});
|
||||
export type LoRAModelFieldType = z.infer<typeof zLoRAModelFieldType>;
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldOutputInstance = z.infer<typeof zLoRAModelFieldOutputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export type LoRAModelFieldOutputTemplate = z.infer<typeof zLoRAModelFieldOutputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
|
||||
zLoRAModelFieldInputInstance.safeParse(val).success;
|
||||
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
|
||||
zLoRAModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
export const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
});
|
||||
export const zControlNetModelFieldValue = zControlNetModelField.optional();
|
||||
export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
default: zControlNetModelFieldValue,
|
||||
});
|
||||
export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zControlNetModelFieldType,
|
||||
});
|
||||
export type ControlNetModelFieldType = z.infer<typeof zControlNetModelFieldType>;
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldOutputInstance = z.infer<typeof zControlNetModelFieldOutputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export type ControlNetModelFieldOutputTemplate = z.infer<typeof zControlNetModelFieldOutputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
|
||||
zControlNetModelFieldInputInstance.safeParse(val).success;
|
||||
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
|
||||
zControlNetModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue =>
|
||||
zControlNetModelFieldValue.safeParse(v).success;
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
export const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional();
|
||||
export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
default: zIPAdapterModelFieldValue,
|
||||
});
|
||||
export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIPAdapterModelFieldType,
|
||||
});
|
||||
export type IPAdapterModelFieldType = z.infer<typeof zIPAdapterModelFieldType>;
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldOutputInstance = z.infer<typeof zIPAdapterModelFieldOutputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export type IPAdapterModelFieldOutputTemplate = z.infer<typeof zIPAdapterModelFieldOutputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
|
||||
zIPAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
|
||||
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue =>
|
||||
zIPAdapterModelFieldValue.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional();
|
||||
export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
default: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zT2IAdapterModelFieldType,
|
||||
});
|
||||
export type T2IAdapterModelFieldType = z.infer<typeof zT2IAdapterModelFieldType>;
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldOutputInstance = z.infer<typeof zT2IAdapterModelFieldOutputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export type T2IAdapterModelFieldOutputTemplate = z.infer<typeof zT2IAdapterModelFieldOutputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
|
||||
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
|
||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
export const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
});
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
value: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
default: zSchedulerFieldValue,
|
||||
});
|
||||
export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSchedulerFieldType,
|
||||
});
|
||||
export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
|
||||
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
|
||||
export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
|
||||
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
|
||||
zSchedulerFieldInputInstance.safeParse(val).success;
|
||||
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
|
||||
zSchedulerFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatelessField
|
||||
/**
|
||||
* StatelessField is a catchall for stateless fields with no UI input components. They do not
|
||||
* do not support "direct" input, instead only accepting connections from other fields.
|
||||
*
|
||||
* This field type serves as a "generic" field type.
|
||||
*
|
||||
* Examples include:
|
||||
* - Fields like UNetField or LatentsField where we do not allow direct UI input
|
||||
* - Reserved fields like IsIntermediate
|
||||
* - Any other field we don't have full-on schemas for
|
||||
*/
|
||||
export const zStatelessFieldType = zFieldTypeBase.extend({
|
||||
name: z.string().min(1), // stateless --> we accept the field's name as the type
|
||||
});
|
||||
export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
|
||||
export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
value: zStatelessFieldValue,
|
||||
});
|
||||
export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
default: zStatelessFieldValue,
|
||||
input: z.literal('connection'), // stateless --> only accepts connection inputs
|
||||
});
|
||||
export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStatelessFieldType,
|
||||
});
|
||||
|
||||
export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
|
||||
export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
|
||||
export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
|
||||
export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
|
||||
export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
|
||||
export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
|
||||
// #endregion
|
||||
|
||||
/**
|
||||
* Here we define the main field unions:
|
||||
* - FieldType
|
||||
* - FieldValue
|
||||
* - FieldInputInstance
|
||||
* - FieldOutputInstance
|
||||
* - FieldInputTemplate
|
||||
* - FieldOutputTemplate
|
||||
*
|
||||
* All stateful fields are unioned together, and then that union is unioned with StatelessField.
|
||||
*
|
||||
* This allows us to interact with stateful fields without needing to worry about "generic" handling
|
||||
* for all other StatelessFields.
|
||||
*/
|
||||
|
||||
// #region StatefulFieldType & FieldType
|
||||
export const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
export const isStatefulFieldType = (val: unknown): val is StatefulFieldType =>
|
||||
zStatefulFieldType.safeParse(val).success;
|
||||
|
||||
export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]);
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldValue & FieldValue
|
||||
export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
zFloatFieldValue,
|
||||
zStringFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
zBoardFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
export type StatefulFieldValue = z.infer<typeof zStatefulFieldValue>;
|
||||
export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue =>
|
||||
zStatefulFieldValue.safeParse(val).success;
|
||||
|
||||
export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]);
|
||||
export type FieldValue = z.infer<typeof zFieldValue>;
|
||||
export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputInstance & FieldInputInstance
|
||||
export const zStatefulFieldInputInstance = z.union([
|
||||
zIntegerFieldInputInstance,
|
||||
zFloatFieldInputInstance,
|
||||
zStringFieldInputInstance,
|
||||
zBooleanFieldInputInstance,
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
export type StatefulFieldInputInstance = z.infer<typeof zStatefulFieldInputInstance>;
|
||||
export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance =>
|
||||
zStatefulFieldInputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
|
||||
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
|
||||
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
|
||||
zFieldInputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputInstance & FieldOutputInstance
|
||||
export const zStatefulFieldOutputInstance = z.union([
|
||||
zIntegerFieldOutputInstance,
|
||||
zFloatFieldOutputInstance,
|
||||
zStringFieldOutputInstance,
|
||||
zBooleanFieldOutputInstance,
|
||||
zEnumFieldOutputInstance,
|
||||
zImageFieldOutputInstance,
|
||||
zBoardFieldOutputInstance,
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
zControlNetModelFieldOutputInstance,
|
||||
zIPAdapterModelFieldOutputInstance,
|
||||
zT2IAdapterModelFieldOutputInstance,
|
||||
zColorFieldOutputInstance,
|
||||
zSchedulerFieldOutputInstance,
|
||||
]);
|
||||
export type StatefulFieldOutputInstance = z.infer<typeof zStatefulFieldOutputInstance>;
|
||||
export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance =>
|
||||
zStatefulFieldOutputInstance.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]);
|
||||
export type FieldOutputInstance = z.infer<typeof zFieldOutputInstance>;
|
||||
export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance =>
|
||||
zFieldOutputInstance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputTemplate & FieldInputTemplate
|
||||
export const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
zFloatFieldInputTemplate,
|
||||
zStringFieldInputTemplate,
|
||||
zBooleanFieldInputTemplate,
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
]);
|
||||
export type StatefulFieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
|
||||
export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate =>
|
||||
zStatefulFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
|
||||
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
|
||||
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
|
||||
zFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputTemplate & FieldOutputTemplate
|
||||
export const zStatefulFieldOutputTemplate = z.union([
|
||||
zIntegerFieldOutputTemplate,
|
||||
zFloatFieldOutputTemplate,
|
||||
zStringFieldOutputTemplate,
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
]);
|
||||
export type StatefulFieldOutputTemplate = z.infer<typeof zStatefulFieldOutputTemplate>;
|
||||
export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate =>
|
||||
zStatefulFieldOutputTemplate.safeParse(val).success;
|
||||
|
||||
export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]);
|
||||
export type FieldOutputTemplate = z.infer<typeof zFieldOutputTemplate>;
|
||||
export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate =>
|
||||
zFieldOutputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
@@ -0,0 +1,93 @@
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
// #region InvocationTemplate
|
||||
export const zInvocationTemplate = z.object({
|
||||
type: z.string(),
|
||||
title: z.string(),
|
||||
description: z.string(),
|
||||
tags: z.array(z.string().min(1)),
|
||||
inputs: z.record(zFieldInputTemplate),
|
||||
outputs: z.record(zFieldOutputTemplate),
|
||||
outputType: z.string().min(1),
|
||||
version: zSemVer,
|
||||
useCache: z.boolean(),
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
classification: zClassification,
|
||||
});
|
||||
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||
// #endregion
|
||||
|
||||
// #region NodeData
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.string().trim().min(1),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
isIntermediate: z.boolean(),
|
||||
useCache: z.boolean(),
|
||||
version: zSemVer,
|
||||
nodePack: z.string().min(1).nullish(),
|
||||
inputs: z.record(zFieldInputInstance),
|
||||
outputs: z.record(zFieldOutputInstance),
|
||||
});
|
||||
|
||||
export const zNotesNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
});
|
||||
export const zCurrentImageNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('current_image'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
export type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
export type AnyNodeData = z.infer<typeof zAnyNodeData>;
|
||||
|
||||
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
|
||||
export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode =>
|
||||
Boolean(node && node.type === 'current_image');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']);
|
||||
export const zNodeExecutionState = z.object({
|
||||
nodeId: z.string().trim().min(1),
|
||||
status: zNodeStatus,
|
||||
progress: z.number().nullable(),
|
||||
progressImage: zProgressImage.nullable(),
|
||||
error: z.string().nullable(),
|
||||
outputs: z.array(z.any()),
|
||||
});
|
||||
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
|
||||
export type NodeStatus = z.infer<typeof zNodeStatus>;
|
||||
// #endregion
|
||||
|
||||
// #region Edges
|
||||
export const zInvocationNodeEdgeExtra = z.object({
|
||||
type: z.union([z.literal('default'), z.literal('collapsed')]),
|
||||
});
|
||||
export type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
|
||||
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
|
||||
// #endregion
|
||||
@@ -0,0 +1,77 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterField,
|
||||
zVAEModelField,
|
||||
} from './common';
|
||||
|
||||
// #region Metadata-optimized versions of schemas
|
||||
// TODO: It's possible that `deepPartial` will be deprecated:
|
||||
// - https://github.com/colinhacks/zod/issues/2106
|
||||
// - https://github.com/colinhacks/zod/issues/2854
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
const zControlNetMetadataItem = zControlField.deepPartial();
|
||||
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
|
||||
const zModelMetadataItem = zMainModelField.deepPartial();
|
||||
const zVAEModelMetadataItem = zVAEModelField.deepPartial();
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
|
||||
export type SDXLRefinerModelMetadataItem = z.infer<typeof zSDXLRefinerModelMetadataItem>;
|
||||
export type ModelMetadataItem = z.infer<typeof zModelMetadataItem>;
|
||||
export type VAEModelMetadataItem = z.infer<typeof zVAEModelMetadataItem>;
|
||||
// #endregion
|
||||
|
||||
// #region CoreMetadata
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
generation_mode: z.string().nullish().catch(null),
|
||||
created_by: z.string().nullish().catch(null),
|
||||
positive_prompt: z.string().nullish().catch(null),
|
||||
negative_prompt: z.string().nullish().catch(null),
|
||||
width: z.number().int().nullish().catch(null),
|
||||
height: z.number().int().nullish().catch(null),
|
||||
seed: z.number().int().nullish().catch(null),
|
||||
rand_device: z.string().nullish().catch(null),
|
||||
cfg_scale: z.number().nullish().catch(null),
|
||||
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||
steps: z.number().int().nullish().catch(null),
|
||||
scheduler: z.string().nullish().catch(null),
|
||||
clip_skip: z.number().int().nullish().catch(null),
|
||||
model: zModelMetadataItem.nullish().catch(null),
|
||||
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
|
||||
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
|
||||
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
|
||||
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
|
||||
vae: zVAEModelMetadataItem.nullish().catch(null),
|
||||
strength: z.number().nullish().catch(null),
|
||||
hrf_enabled: z.boolean().nullish().catch(null),
|
||||
hrf_strength: z.number().nullish().catch(null),
|
||||
hrf_method: z.string().nullish().catch(null),
|
||||
init_image: z.string().nullish().catch(null),
|
||||
positive_style_prompt: z.string().nullish().catch(null),
|
||||
negative_style_prompt: z.string().nullish().catch(null),
|
||||
refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null),
|
||||
refiner_cfg_scale: z.number().nullish().catch(null),
|
||||
refiner_steps: z.number().int().nullish().catch(null),
|
||||
refiner_scheduler: z.string().nullish().catch(null),
|
||||
refiner_positive_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_negative_aesthetic_score: z.number().nullish().catch(null),
|
||||
refiner_start: z.number().nullish().catch(null),
|
||||
})
|
||||
.passthrough();
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
// #endregion
|
||||
86
invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
Normal file
86
invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import type { OpenAPIV3_1 } from 'openapi-types';
|
||||
import type {
|
||||
InputFieldJSONSchemaExtra,
|
||||
InvocationJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
} from 'services/api/types';
|
||||
|
||||
// Janky customization of OpenAPI Schema :/
|
||||
|
||||
export type InvocationSchemaExtra = InvocationJSONSchemaExtra & {
|
||||
output: OpenAPIV3_1.ReferenceObject; // the output of the invocation
|
||||
title: string;
|
||||
category?: string;
|
||||
tags?: string[];
|
||||
version: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3_1.SchemaObject['properties']> & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra),
|
||||
'type'
|
||||
> & {
|
||||
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: string;
|
||||
};
|
||||
use_cache: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: boolean;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationSchemaType = {
|
||||
default: string; // the type of the invocation
|
||||
};
|
||||
|
||||
export type InvocationBaseSchemaObject = Omit<OpenAPIV3_1.BaseSchemaObject, 'title' | 'type' | 'properties'> &
|
||||
InvocationSchemaExtra;
|
||||
|
||||
export type InvocationOutputSchemaObject = Omit<OpenAPIV3_1.SchemaObject, 'properties'> & {
|
||||
properties: OpenAPIV3_1.SchemaObject['properties'] & {
|
||||
type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
|
||||
default: string;
|
||||
};
|
||||
} & {
|
||||
class: 'output';
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra;
|
||||
|
||||
export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
|
||||
|
||||
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type: OpenAPIV3_1.ArraySchemaObjectType;
|
||||
items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
|
||||
}
|
||||
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type?: OpenAPIV3_1.NonArraySchemaObjectType;
|
||||
}
|
||||
|
||||
export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' };
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||
|
||||
export const isRefObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
|
||||
): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject
|
||||
): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation';
|
||||
|
||||
export const isInvocationOutputSchemaObject = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject
|
||||
): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output';
|
||||
|
||||
export const isInvocationFieldSchema = (
|
||||
obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
|
||||
): obj is InvocationFieldSchema => !('$ref' in obj);
|
||||
21
invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
Normal file
21
invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
// Schemas and types for working with semver
|
||||
|
||||
const zVersionInt = z.coerce.number().int().min(0);
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success
|
||||
);
|
||||
});
|
||||
|
||||
export const zParsedSemver = zSemVer.transform((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return {
|
||||
major: Number(major),
|
||||
minor: Number(minor),
|
||||
patch: Number(patch),
|
||||
};
|
||||
});
|
||||
@@ -0,0 +1,89 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zFieldIdentifier } from './field';
|
||||
import { zInvocationNodeData, zNotesNodeData } from './invocation';
|
||||
|
||||
// #region Workflow misc
|
||||
export const zXYPosition = z
|
||||
.object({
|
||||
x: z.number(),
|
||||
y: z.number(),
|
||||
})
|
||||
.default({ x: 0, y: 0 });
|
||||
export type XYPosition = z.infer<typeof zXYPosition>;
|
||||
|
||||
export const zDimension = z.number().gt(0).nullish();
|
||||
export type Dimension = z.infer<typeof zDimension>;
|
||||
|
||||
export const zWorkflowCategory = z.enum(['user', 'default', 'project']);
|
||||
export type WorkflowCategory = z.infer<typeof zWorkflowCategory>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow Nodes
|
||||
export const zWorkflowInvocationNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNotesNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
|
||||
|
||||
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||
export type WorkflowNotesNode = z.infer<typeof zWorkflowNotesNode>;
|
||||
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
|
||||
|
||||
export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode =>
|
||||
zWorkflowInvocationNode.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow Edges
|
||||
export const zWorkflowEdgeBase = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
source: z.string().trim().min(1),
|
||||
target: z.string().trim().min(1),
|
||||
});
|
||||
export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('default'),
|
||||
sourceHandle: z.string().trim().min(1),
|
||||
targetHandle: z.string().trim().min(1),
|
||||
});
|
||||
export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
|
||||
type: z.literal('collapsed'),
|
||||
});
|
||||
export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]);
|
||||
|
||||
export type WorkflowEdgeDefault = z.infer<typeof zWorkflowEdgeDefault>;
|
||||
export type WorkflowEdgeCollapsed = z.infer<typeof zWorkflowEdgeCollapsed>;
|
||||
export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow
|
||||
export const zWorkflowV2 = z.object({
|
||||
id: z.string().min(1).optional(),
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
description: z.string(),
|
||||
version: z.string(),
|
||||
contact: z.string(),
|
||||
tags: z.string(),
|
||||
notes: z.string(),
|
||||
nodes: z.array(zWorkflowNode),
|
||||
edges: z.array(zWorkflowEdge),
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('2.0.0'),
|
||||
}),
|
||||
});
|
||||
export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
|
||||
// #endregion
|
||||
@@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('invocation'),
|
||||
data: zInvocationNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNotesNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('notes'),
|
||||
data: zNotesNodeData,
|
||||
width: zDimension,
|
||||
height: zDimension,
|
||||
position: zXYPosition,
|
||||
});
|
||||
export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
|
||||
@@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
|
||||
// #endregion
|
||||
|
||||
// #region Workflow
|
||||
export const zWorkflowV2 = z.object({
|
||||
export const zWorkflowV3 = z.object({
|
||||
id: z.string().min(1).optional(),
|
||||
name: z.string(),
|
||||
author: z.string(),
|
||||
@@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('2.0.0'),
|
||||
version: z.literal('3.0.0'),
|
||||
}),
|
||||
});
|
||||
export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
|
||||
export type WorkflowV3 = z.infer<typeof zWorkflowV3>;
|
||||
// #endregion
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
|
||||
import { reduce } from 'lodash-es';
|
||||
@@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
|
||||
{} as Record<string, FieldInputInstance>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
template.outputs,
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: FieldOutputInstance = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
fieldKind: 'output',
|
||||
};
|
||||
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, FieldOutputInstance>
|
||||
);
|
||||
|
||||
const node: InvocationNode = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
@@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
|
||||
isIntermediate: type === 'save_image' ? false : true,
|
||||
useCache: template.useCache,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
|
||||
|
||||
// Remove any fields that are not in the template
|
||||
clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs));
|
||||
clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs));
|
||||
return clone;
|
||||
};
|
||||
|
||||
@@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
const fieldInstance: FieldInputInstance = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input' as const,
|
||||
value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name),
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
import {
|
||||
UnableToExtractSchemaNameFromRefError,
|
||||
UnsupportedArrayItemType,
|
||||
UnsupportedPrimitiveTypeError,
|
||||
UnsupportedUnionError,
|
||||
} from 'features/nodes/types/error';
|
||||
import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
|
||||
import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
type ParseFieldTypeTestCase = {
|
||||
name: string;
|
||||
schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema;
|
||||
expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean };
|
||||
};
|
||||
|
||||
const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
{
|
||||
name: 'Scalar IntegerField',
|
||||
schema: { type: 'integer' },
|
||||
expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Scalar FloatField',
|
||||
schema: { type: 'number' },
|
||||
expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Scalar StringField',
|
||||
schema: { type: 'string' },
|
||||
expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Scalar BooleanField',
|
||||
schema: { type: 'boolean' },
|
||||
expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Collection IntegerField',
|
||||
schema: { items: { type: 'integer' }, type: 'array' },
|
||||
expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Collection FloatField',
|
||||
schema: { items: { type: 'number' }, type: 'array' },
|
||||
expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Collection StringField',
|
||||
schema: { items: { type: 'string' }, type: 'array' },
|
||||
expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Collection BooleanField',
|
||||
schema: { items: { type: 'boolean' }, type: 'array' },
|
||||
expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'CollectionOrScalar IntegerField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
type: 'integer',
|
||||
},
|
||||
{
|
||||
items: {
|
||||
type: 'integer',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
{
|
||||
name: 'CollectionOrScalar FloatField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
type: 'number',
|
||||
},
|
||||
{
|
||||
items: {
|
||||
type: 'number',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
{
|
||||
name: 'CollectionOrScalar StringField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
type: 'string',
|
||||
},
|
||||
{
|
||||
items: {
|
||||
type: 'string',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
{
|
||||
name: 'CollectionOrScalar BooleanField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
type: 'boolean',
|
||||
},
|
||||
{
|
||||
items: {
|
||||
type: 'boolean',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
];
|
||||
|
||||
const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
{
|
||||
name: 'Scalar ConditioningField',
|
||||
schema: {
|
||||
allOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Nullable Scalar ConditioningField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Collection ConditioningField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
items: {
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Nullable Collection ConditioningField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
items: {
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'CollectionOrScalar ConditioningField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
items: {
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
{
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
{
|
||||
name: 'Nullable CollectionOrScalar ConditioningField',
|
||||
schema: {
|
||||
anyOf: [
|
||||
{
|
||||
items: {
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
type: 'array',
|
||||
},
|
||||
{
|
||||
$ref: '#/components/schemas/ConditioningField',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true },
|
||||
},
|
||||
];
|
||||
|
||||
const specialCases: ParseFieldTypeTestCase[] = [
|
||||
{
|
||||
name: 'String EnumField',
|
||||
schema: {
|
||||
type: 'string',
|
||||
enum: ['large', 'base', 'small'],
|
||||
},
|
||||
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'String EnumField with one value',
|
||||
schema: {
|
||||
const: 'Some Value',
|
||||
},
|
||||
expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (SchedulerField)',
|
||||
schema: {
|
||||
type: 'string',
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'SchedulerField',
|
||||
},
|
||||
expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (AnyField)',
|
||||
schema: {
|
||||
type: 'string',
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'AnyField',
|
||||
},
|
||||
expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (CollectionField)',
|
||||
schema: {
|
||||
type: 'string',
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'CollectionField',
|
||||
},
|
||||
expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false },
|
||||
},
|
||||
];
|
||||
|
||||
describe('refObjectToSchemaName', async () => {
|
||||
it('parses ref object 1', () => {
|
||||
expect(
|
||||
refObjectToSchemaName({
|
||||
$ref: '#/components/schemas/ImageField',
|
||||
})
|
||||
).toEqual('ImageField');
|
||||
});
|
||||
it('parses ref object 2', () => {
|
||||
expect(
|
||||
refObjectToSchemaName({
|
||||
$ref: '#/components/schemas/T2IAdapterModelField',
|
||||
})
|
||||
).toEqual('T2IAdapterModelField');
|
||||
});
|
||||
});
|
||||
|
||||
describe.concurrent('parseFieldType', async () => {
|
||||
it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }) => {
|
||||
expect(parseFieldType(schema)).toEqual(expected);
|
||||
});
|
||||
it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }) => {
|
||||
expect(parseFieldType(schema)).toEqual(expected);
|
||||
});
|
||||
it.each(specialCases)('parses special case types ($name)', ({ schema, expected }) => {
|
||||
expect(parseFieldType(schema)).toEqual(expected);
|
||||
});
|
||||
|
||||
it('raises if it cannot extract a schema name from a ref', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
allOf: [
|
||||
{
|
||||
$ref: '#/components/schemas/',
|
||||
},
|
||||
],
|
||||
})
|
||||
).toThrowError(UnableToExtractSchemaNameFromRefError);
|
||||
});
|
||||
|
||||
it('raises if it receives a union of mismatched types', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
anyOf: [
|
||||
{
|
||||
type: 'string',
|
||||
},
|
||||
{
|
||||
type: 'integer',
|
||||
},
|
||||
],
|
||||
})
|
||||
).toThrowError(UnsupportedUnionError);
|
||||
});
|
||||
|
||||
it('raises if it receives a union of mismatched types (excluding null)', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
anyOf: [
|
||||
{
|
||||
type: 'string',
|
||||
},
|
||||
{
|
||||
type: 'integer',
|
||||
},
|
||||
{
|
||||
type: 'null',
|
||||
},
|
||||
],
|
||||
})
|
||||
).toThrowError(UnsupportedUnionError);
|
||||
});
|
||||
|
||||
it('raises if it received an unsupported primitive type (object)', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
type: 'object',
|
||||
})
|
||||
).toThrowError(UnsupportedPrimitiveTypeError);
|
||||
});
|
||||
|
||||
it('raises if it received an unsupported primitive type (null)', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
type: 'null',
|
||||
})
|
||||
).toThrowError(UnsupportedPrimitiveTypeError);
|
||||
});
|
||||
|
||||
it('raises if it received an unsupported array item type (object)', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
items: {
|
||||
type: 'object',
|
||||
},
|
||||
type: 'array',
|
||||
})
|
||||
).toThrowError(UnsupportedArrayItemType);
|
||||
});
|
||||
|
||||
it('raises if it received an unsupported array item type (null)', () => {
|
||||
expect(() =>
|
||||
parseFieldType({
|
||||
items: {
|
||||
type: 'null',
|
||||
},
|
||||
type: 'array',
|
||||
})
|
||||
).toThrowError(UnsupportedArrayItemType);
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,12 @@
|
||||
import { FieldParseError } from 'features/nodes/types/error';
|
||||
import {
|
||||
FieldParseError,
|
||||
UnableToExtractSchemaNameFromRefError,
|
||||
UnsupportedArrayItemType,
|
||||
UnsupportedPrimitiveTypeError,
|
||||
UnsupportedUnionError,
|
||||
} from 'features/nodes/types/error';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
|
||||
import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi';
|
||||
import {
|
||||
isArraySchemaObject,
|
||||
isInvocationFieldSchema,
|
||||
@@ -42,7 +48,7 @@ const isCollectionFieldType = (fieldType: string) => {
|
||||
return false;
|
||||
};
|
||||
|
||||
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => {
|
||||
export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => {
|
||||
if (isInvocationFieldSchema(schemaObject)) {
|
||||
// Check if this field has an explicit type provided by the node schema
|
||||
const { ui_type } = schemaObject;
|
||||
@@ -72,7 +78,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
// This is a single ref type
|
||||
const name = refObjectToSchemaName(allOf[0]);
|
||||
if (!name) {
|
||||
throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
}
|
||||
return {
|
||||
name,
|
||||
@@ -95,7 +101,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
if (isRefObject(filteredAnyOf[0])) {
|
||||
const name = refObjectToSchemaName(filteredAnyOf[0]);
|
||||
if (!name) {
|
||||
throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -118,7 +124,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
|
||||
if (filteredAnyOf.length !== 2) {
|
||||
// This is a union of more than 2 types, which we don't support
|
||||
throw new FieldParseError(
|
||||
throw new UnsupportedUnionError(
|
||||
t('nodes.unsupportedAnyOfLength', {
|
||||
count: filteredAnyOf.length,
|
||||
})
|
||||
@@ -159,7 +165,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
};
|
||||
}
|
||||
|
||||
throw new FieldParseError(
|
||||
throw new UnsupportedUnionError(
|
||||
t('nodes.unsupportedMismatchedUnion', {
|
||||
firstType,
|
||||
secondType,
|
||||
@@ -178,7 +184,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
if (isSchemaObject(schemaObject.items)) {
|
||||
const itemType = schemaObject.items.type;
|
||||
if (!itemType || isArray(itemType)) {
|
||||
throw new FieldParseError(
|
||||
throw new UnsupportedArrayItemType(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: itemType,
|
||||
})
|
||||
@@ -188,7 +194,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
|
||||
if (!name) {
|
||||
// it's 'null', 'object', or 'array' - skip
|
||||
throw new FieldParseError(
|
||||
throw new UnsupportedArrayItemType(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: itemType,
|
||||
})
|
||||
@@ -204,7 +210,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
// This is a ref object, extract the type name
|
||||
const name = refObjectToSchemaName(schemaObject.items);
|
||||
if (!name) {
|
||||
throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
}
|
||||
return {
|
||||
name,
|
||||
@@ -216,7 +222,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
|
||||
if (!name) {
|
||||
// it's 'null', 'object', or 'array' - skip
|
||||
throw new FieldParseError(
|
||||
throw new UnsupportedPrimitiveTypeError(
|
||||
t('nodes.unsupportedArrayItemType', {
|
||||
type: schemaObject.type,
|
||||
})
|
||||
@@ -232,7 +238,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
} else if (isRefObject(schemaObject)) {
|
||||
const name = refObjectToSchemaName(schemaObject);
|
||||
if (!name) {
|
||||
throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||
}
|
||||
return {
|
||||
name,
|
||||
|
||||
@@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import i18n from 'i18n';
|
||||
import { cloneDeep, pick } from 'lodash-es';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
@@ -25,14 +25,14 @@ const workflowKeys = [
|
||||
'exposedFields',
|
||||
'meta',
|
||||
'id',
|
||||
] satisfies (keyof WorkflowV2)[];
|
||||
] satisfies (keyof WorkflowV3)[];
|
||||
|
||||
export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2;
|
||||
export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;
|
||||
|
||||
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => {
|
||||
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
|
||||
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
|
||||
|
||||
const newWorkflow: WorkflowV2 = {
|
||||
const newWorkflow: WorkflowV3 = {
|
||||
...clonedWorkflow,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
@@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
position: { ...node.position },
|
||||
width: node.width,
|
||||
height: node.height,
|
||||
});
|
||||
} else if (isNotesNode(node) && node.type) {
|
||||
newWorkflow.nodes.push({
|
||||
@@ -54,8 +52,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
type: node.type,
|
||||
data: cloneDeep(node.data),
|
||||
position: { ...node.position },
|
||||
width: node.width,
|
||||
height: node.height,
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
|
||||
return newWorkflow;
|
||||
};
|
||||
|
||||
export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => {
|
||||
export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => {
|
||||
// builds what really, really should be a valid workflow
|
||||
const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow });
|
||||
|
||||
// but bc we are storing this in the DB, let's be extra sure
|
||||
const result = zWorkflowV2.safeParse(workflowToValidate);
|
||||
const result = zWorkflowV3.safeParse(workflowToValidate);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
|
||||
@@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver';
|
||||
import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap';
|
||||
import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1';
|
||||
import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { z } from 'zod';
|
||||
@@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({
|
||||
* - Workflow schema version bumped to 2.0.0
|
||||
*/
|
||||
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
const invocationTemplates = $store.get()?.getState().nodeTemplates.templates;
|
||||
const invocationTemplates = $store.get()?.getState().nodes.templates;
|
||||
|
||||
if (!invocationTemplates) {
|
||||
throw new Error(t('app.storeNotInitialized'));
|
||||
@@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
return zWorkflowV2.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
|
||||
// Bump version
|
||||
(workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0';
|
||||
// Parsing strips out any extra properties not in the latest version
|
||||
return zWorkflowV3.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses a workflow and migrates it to the latest version if necessary.
|
||||
*/
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
|
||||
export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
|
||||
const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
|
||||
|
||||
if (!workflowVersionResult.success) {
|
||||
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
|
||||
}
|
||||
|
||||
const { version } = workflowVersionResult.data.meta;
|
||||
let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3;
|
||||
|
||||
if (version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(data);
|
||||
return migrateV1toV2(v1);
|
||||
if (workflow.meta.version === '1.0.0') {
|
||||
const v1 = zWorkflowV1.parse(workflow);
|
||||
workflow = migrateV1toV2(v1);
|
||||
}
|
||||
|
||||
if (version === '2.0.0') {
|
||||
return zWorkflowV2.parse(data);
|
||||
if (workflow.meta.version === '2.0.0') {
|
||||
const v2 = zWorkflowV2.parse(workflow);
|
||||
workflow = migrateV2toV3(v2);
|
||||
}
|
||||
|
||||
throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version }));
|
||||
return workflow as WorkflowV3;
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
|
||||
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
|
||||
import { t } from 'i18next';
|
||||
@@ -16,7 +16,7 @@ type WorkflowWarning = {
|
||||
};
|
||||
|
||||
type ValidateWorkflowResult = {
|
||||
workflow: WorkflowV2;
|
||||
workflow: WorkflowV3;
|
||||
warnings: WorkflowWarning[];
|
||||
};
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice';
|
||||
import type { WorkflowV2 } from 'features/nodes/types/workflow';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { workflowUpdated } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = {
|
||||
|
||||
type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn;
|
||||
|
||||
export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required<WorkflowV2, 'id'> =>
|
||||
export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required<WorkflowV3, 'id'> =>
|
||||
Boolean(workflow.id);
|
||||
|
||||
export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => {
|
||||
|
||||
@@ -1,12 +1,90 @@
|
||||
/// <reference types="vitest" />
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import path from 'path';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import type { PluginOption } from 'vite';
|
||||
import { defineConfig } from 'vite';
|
||||
|
||||
import { appConfig } from './config/vite.app.config.mjs';
|
||||
import { packageConfig } from './config/vite.package.config.mjs';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
|
||||
export default defineConfig(({ mode }) => {
|
||||
if (mode === 'package') {
|
||||
return packageConfig;
|
||||
return {
|
||||
base: './',
|
||||
plugins: [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
cssInjectedByJsPlugin(),
|
||||
],
|
||||
build: {
|
||||
cssCodeSplit: true,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||
name: 'InvokeAIUI',
|
||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
||||
},
|
||||
rollupOptions: {
|
||||
external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'],
|
||||
output: {
|
||||
globals: {
|
||||
react: 'React',
|
||||
'react-dom': 'ReactDOM',
|
||||
'@emotion/react': 'EmotionReact',
|
||||
'@invoke-ai/ui-library': 'UiLibrary',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
app: path.resolve(__dirname, '../src/app'),
|
||||
assets: path.resolve(__dirname, '../src/assets'),
|
||||
common: path.resolve(__dirname, '../src/common'),
|
||||
features: path.resolve(__dirname, '../src/features'),
|
||||
services: path.resolve(__dirname, '../src/services'),
|
||||
theme: path.resolve(__dirname, '../src/theme'),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return appConfig;
|
||||
return {
|
||||
base: './',
|
||||
plugins: [react(), mode !== 'test' && eslint(), tsconfigPaths(), visualizer() as unknown as PluginOption],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
},
|
||||
server: {
|
||||
// Proxy HTTP requests to the flask server
|
||||
proxy: {
|
||||
// Proxy socket.io to the nodes socketio server
|
||||
'/ws/socket.io': {
|
||||
target: 'ws://127.0.0.1:9090',
|
||||
ws: true,
|
||||
},
|
||||
// Proxy openapi schema definiton
|
||||
'/openapi.json': {
|
||||
target: 'http://127.0.0.1:9090/openapi.json',
|
||||
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
// proxy nodes api
|
||||
'/api/v1': {
|
||||
target: 'http://127.0.0.1:9090/api/v1',
|
||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
test: {
|
||||
//
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user